训练智能体#

本页面简要介绍了如何为 Gymnasium 环境训练智能体(agent),特别是我们将使用基于表格的 Q-learning 来解决 Blackjack v1 环境。有关此教程的完整版本以及更多其他环境和算法的训练教程,请参阅此处。在阅读本页之前,请先阅读基本用法。在我们实现任何代码之前,这里是 Blackjack 和 Q-learning 的概述。

Blackjack 是最受欢迎的赌场纸牌游戏之一,也因在某些条件下可被击败而臭名昭著。这款游戏的版本使用无限副牌(我们替换抽牌),因此在我们模拟的游戏中数牌不是一个可行的策略。观察值是一个元组,包括玩家当前的总点数、庄家面朝上的牌的值以及一个布尔值,表示玩家是否持有可用的牌。智能体可以在两个动作之间选择:站立(0)意味着玩家不再拿牌,击打(1)意味着玩家将再拿一张牌。要赢,你的牌的总和应该大于庄家的牌的总和且不超过21。如果玩家选择站立或牌的总和超过21,游戏结束。完整的文档可以在 toy_text/blackjack 找到。

Q-learning 是一种由 Watkins 于 1989 年提出的无模型离策略学习算法,适用于具有离散动作空间的环境,并因其是第一个在一定条件下证明收敛到最优策略的强化学习算法而闻名。

执行动作#

在接收到第一个观察值之后,我们将使用 env.step(action) 函数与环境进行交互。该函数接受动作作为输入并在环境中执行它。因为这个动作会改变环境的状态,所以它会返回四个有用的变量给我们。这些是:

  • 下一个观测值:这是智能体在采取动作后将接收到的观察值。

  • 奖励:这是智能体在采取动作后将接收到的奖励。

  • 终止:这是一个布尔变量,指示环境是否已经因内部条件而终止(即结束)。

  • 截断:这是一个布尔变量,也指示情节是否由于提前截断而结束,即达到了时间限制。

  • 信息:这是一个可能包含有关环境的额外信息的字典。

下一个观测值(next observation)、奖励(reward)、终止(terminated)和截断(truncated)变量是不言自明的,但信息变量需要一些额外的解释。这个变量包含一个字典,可能有一些关于环境的额外信息,但在 Blackjack-v1 环境中你可以忽略它。例如,在雅达利(Atari)环境中,信息字典有 ale.lives 键,告诉我们智能体还剩下多少条命。如果智能体没有生命了,那么情节就结束了。

注意,在你的训练循环中调用 env.render() 不是好主意,因为渲染会大大减慢训练速度。相反,尝试构建一个额外的循环来评估和展示训练后的智能体。

构建智能体#

让我们来构建用于解决 Blackjack 的 Q-learning 智能体!我们需要一些函数来选择动作和更新智能体的动作值。为了确保智能体能探索环境,可能的解决方案是 epsilon-greedy 策略,在这种策略中,我们以 epsilon 的概率随机选择一个动作,以 1 - epsilon 的概率选择当前估值最高的贪婪动作。

from collections import defaultdict
import gymnasium as gym
import numpy as np


class BlackjackAgent:
    def __init__(
        self,
        env: gym.Env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            env: The training environment
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: The decay for epsilon
            final_epsilon: The final epsilon value
            discount_factor: The discount factor for computing the Q-value
        """
        self.env = env
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        # with probability (1 - epsilon) act greedily (exploit)
        else:
            return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: tuple[int, int, bool],
        action: int,
        reward: float,
        terminated: bool,
        next_obs: tuple[int, int, bool],
    ):
        """Updates the Q-value of an action."""
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
        

训练智能体#

为了训练智能体,我们将让智能体一次玩一个情节(一个完整的游戏称为一个情节),然后在每个情节之后更新其 Q 值。智能体将不得不经历很多情节来充分探索环境。

# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over time
final_epsilon = 0.1
env = gym.make('Blackjack-v1', natural=False, sab=False)
agent = BlackjackAgent(
    env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

备注

当前的超参数设置是为了快速训练一个像样的智能体。如果你想收敛到最优策略,可以尝试将 n_episodes 增加10倍,并降低学习率(例如,降到 0.001)。

from tqdm import tqdm

env = gym.make("Blackjack-v1", sab=False)
env = gym.wrappers.RecordEpisodeStatistics(env, buffer_length=n_episodes)

for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # play one episode
    while not done:
        action = agent.get_action(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # update the agent
        agent.update(obs, action, reward, terminated, next_obs)

        # update if the environment is done and the current obs
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()
0%|          | 0/100000 [00:00<?, ?it/s]
1%|          | 754/100000 [00:00<00:13, 7530.03it/s]
2%|▏         | 1508/100000 [00:00<00:17, 5739.34it/s]
2%|▏         | 2108/100000 [00:00<00:17, 5471.94it/s]
3%|▎         | 2796/100000 [00:00<00:16, 5953.03it/s]
3%|▎         | 3406/100000 [00:00<00:19, 4842.35it/s]
4%|▍         | 3924/100000 [00:00<00:20, 4632.97it/s]
4%|▍         | 4482/100000 [00:00<00:19, 4887.79it/s]
5%|▌         | 5164/100000 [00:00<00:17, 5423.42it/s]
6%|▌         | 5915/100000 [00:01<00:15, 6014.81it/s]
7%|▋         | 6547/100000 [00:01<00:15, 6101.59it/s]
7%|▋         | 7174/100000 [00:01<00:15, 6149.34it/s]
8%|▊         | 7921/100000 [00:01<00:14, 6537.31it/s]
9%|▊         | 8584/100000 [00:01<00:14, 6452.94it/s]
9%|▉         | 9324/100000 [00:01<00:13, 6729.67it/s]
10%|█         | 10083/100000 [00:01<00:12, 6981.93it/s]
11%|█         | 10786/100000 [00:01<00:14, 6047.63it/s]
11%|█▏        | 11416/100000 [00:01<00:14, 5953.18it/s]
12%|█▏        | 12029/100000 [00:02<00:14, 5915.96it/s]
13%|█▎        | 12665/100000 [00:02<00:14, 6034.00it/s]
13%|█▎        | 13421/100000 [00:02<00:13, 6465.77it/s]
14%|█▍        | 14183/100000 [00:02<00:12, 6797.93it/s]
15%|█▍        | 14871/100000 [00:02<00:13, 6268.91it/s]
16%|█▌        | 15634/100000 [00:02<00:12, 6643.78it/s]
16%|█▋        | 16405/100000 [00:02<00:12, 6945.23it/s]
17%|█▋        | 17110/100000 [00:02<00:12, 6706.25it/s]
18%|█▊        | 17789/100000 [00:02<00:15, 5377.48it/s]
18%|█▊        | 18482/100000 [00:03<00:14, 5755.47it/s]
19%|█▉        | 19170/100000 [00:03<00:13, 6046.13it/s]
20%|█▉        | 19808/100000 [00:03<00:13, 5901.42it/s]
21%|██        | 20537/100000 [00:03<00:12, 6275.74it/s]
21%|██▏       | 21272/100000 [00:03<00:11, 6573.98it/s]
22%|██▏       | 21946/100000 [00:03<00:12, 6272.40it/s]
23%|██▎       | 22674/100000 [00:03<00:11, 6551.46it/s]
23%|██▎       | 23431/100000 [00:03<00:11, 6838.60it/s]
24%|██▍       | 24125/100000 [00:03<00:11, 6852.54it/s]
25%|██▍       | 24866/100000 [00:04<00:10, 7014.75it/s]
26%|██▌       | 25608/100000 [00:04<00:10, 7132.97it/s]
26%|██▋       | 26326/100000 [00:04<00:10, 6988.51it/s]
27%|██▋       | 27064/100000 [00:04<00:10, 7101.83it/s]
28%|██▊       | 27822/100000 [00:04<00:09, 7240.43it/s]
29%|██▊       | 28549/100000 [00:04<00:09, 7199.42it/s]
29%|██▉       | 29302/100000 [00:04<00:09, 7293.35it/s]
30%|███       | 30052/100000 [00:04<00:09, 7352.89it/s]
31%|███       | 30789/100000 [00:04<00:10, 6656.38it/s]
32%|███▏      | 31508/100000 [00:04<00:10, 6803.22it/s]
32%|███▏      | 32275/100000 [00:05<00:09, 7046.59it/s]
33%|███▎      | 33037/100000 [00:05<00:09, 7209.26it/s]
34%|███▍      | 33765/100000 [00:05<00:10, 6195.42it/s]
34%|███▍      | 34414/100000 [00:05<00:12, 5292.06it/s]
35%|███▌      | 35087/100000 [00:05<00:11, 5637.23it/s]
36%|███▌      | 35852/100000 [00:05<00:10, 6150.81it/s]
37%|███▋      | 36592/100000 [00:05<00:09, 6485.10it/s]
37%|███▋      | 37290/100000 [00:05<00:09, 6621.01it/s]
38%|███▊      | 37973/100000 [00:05<00:09, 6442.97it/s]
39%|███▊      | 38633/100000 [00:06<00:11, 5539.65it/s]
39%|███▉      | 39371/100000 [00:06<00:10, 6007.77it/s]
40%|████      | 40052/100000 [00:06<00:09, 6221.71it/s]
41%|████      | 40698/100000 [00:06<00:09, 6181.18it/s]
41%|████▏     | 41450/100000 [00:06<00:08, 6554.35it/s]
42%|████▏     | 42158/100000 [00:06<00:08, 6703.73it/s]
43%|████▎     | 42839/100000 [00:06<00:09, 5927.34it/s]
44%|████▎     | 43598/100000 [00:06<00:08, 6368.63it/s]
44%|████▍     | 44257/100000 [00:07<00:09, 6111.76it/s]
45%|████▍     | 44974/100000 [00:07<00:08, 6396.07it/s]
46%|████▌     | 45705/100000 [00:07<00:08, 6651.41it/s]
46%|████▋     | 46469/100000 [00:07<00:07, 6931.94it/s]
47%|████▋     | 47172/100000 [00:07<00:07, 6958.35it/s]
48%|████▊     | 47875/100000 [00:07<00:08, 6236.98it/s]
49%|████▊     | 48517/100000 [00:07<00:08, 6230.37it/s]
49%|████▉     | 49282/100000 [00:07<00:07, 6623.89it/s]
50%|█████     | 50053/100000 [00:07<00:07, 6931.19it/s]
51%|█████     | 50757/100000 [00:07<00:07, 6807.03it/s]
52%|█████▏    | 51517/100000 [00:08<00:06, 7034.81it/s]
52%|█████▏    | 52288/100000 [00:08<00:06, 7229.72it/s]
53%|█████▎    | 53042/100000 [00:08<00:06, 7319.02it/s]
54%|█████▍    | 53778/100000 [00:08<00:06, 7183.98it/s]
55%|█████▍    | 54526/100000 [00:08<00:06, 7267.82it/s]
55%|█████▌    | 55290/100000 [00:08<00:06, 7375.36it/s]
56%|█████▌    | 56033/100000 [00:08<00:05, 7390.11it/s]
57%|█████▋    | 56774/100000 [00:08<00:06, 6769.90it/s]
57%|█████▋    | 57462/100000 [00:08<00:06, 6617.02it/s]
58%|█████▊    | 58132/100000 [00:09<00:06, 6290.21it/s]
59%|█████▉    | 58852/100000 [00:09<00:06, 6539.33it/s]
60%|█████▉    | 59578/100000 [00:09<00:05, 6741.77it/s]
60%|██████    | 60259/100000 [00:09<00:05, 6624.03it/s]
61%|██████    | 60926/100000 [00:09<00:06, 5605.35it/s]
62%|██████▏   | 61515/100000 [00:09<00:07, 5326.40it/s]
62%|██████▏   | 62068/100000 [00:09<00:07, 5054.04it/s]
63%|██████▎   | 62796/100000 [00:09<00:06, 5622.80it/s]
64%|██████▎   | 63557/100000 [00:09<00:05, 6154.24it/s]
64%|██████▍   | 64193/100000 [00:10<00:05, 6020.36it/s]
65%|██████▍   | 64879/100000 [00:10<00:05, 6252.46it/s]
66%|██████▌   | 65651/100000 [00:10<00:05, 6666.78it/s]
66%|██████▋   | 66378/100000 [00:10<00:04, 6839.59it/s]
67%|██████▋   | 67071/100000 [00:10<00:04, 6833.55it/s]
68%|██████▊   | 67761/100000 [00:10<00:05, 6129.90it/s]
68%|██████▊   | 68391/100000 [00:10<00:05, 6123.04it/s]
69%|██████▉   | 69016/100000 [00:10<00:05, 5989.17it/s]
70%|██████▉   | 69638/100000 [00:10<00:05, 6051.34it/s]
70%|███████   | 70404/100000 [00:11<00:04, 6509.21it/s]
71%|███████   | 71162/100000 [00:11<00:04, 6817.79it/s]
72%|███████▏  | 71850/100000 [00:11<00:04, 6557.41it/s]
73%|███████▎  | 72589/100000 [00:11<00:04, 6794.96it/s]
73%|███████▎  | 73353/100000 [00:11<00:03, 7039.74it/s]
74%|███████▍  | 74128/100000 [00:11<00:03, 7245.89it/s]
75%|███████▍  | 74857/100000 [00:11<00:03, 7208.67it/s]
76%|███████▌  | 75598/100000 [00:11<00:03, 7265.31it/s]
76%|███████▋  | 76327/100000 [00:11<00:03, 7091.88it/s]
77%|███████▋  | 77039/100000 [00:11<00:03, 6945.54it/s]
78%|███████▊  | 77736/100000 [00:12<00:03, 6466.13it/s]
78%|███████▊  | 78446/100000 [00:12<00:03, 6641.05it/s]
79%|███████▉  | 79117/100000 [00:12<00:03, 5734.15it/s]
80%|███████▉  | 79715/100000 [00:12<00:03, 5677.56it/s]
80%|████████  | 80323/100000 [00:12<00:03, 5783.74it/s]
81%|████████  | 80926/100000 [00:12<00:03, 5851.10it/s]
82%|████████▏ | 81691/100000 [00:12<00:02, 6358.88it/s]
82%|████████▏ | 82350/100000 [00:12<00:02, 6425.07it/s]
83%|████████▎ | 83009/100000 [00:12<00:02, 6471.31it/s]
84%|████████▎ | 83662/100000 [00:13<00:02, 6184.52it/s]
84%|████████▍ | 84405/100000 [00:13<00:02, 6538.21it/s]
85%|████████▌ | 85147/100000 [00:13<00:02, 6792.97it/s]
86%|████████▌ | 85899/100000 [00:13<00:02, 7004.47it/s]
87%|████████▋ | 86625/100000 [00:13<00:01, 7078.58it/s]
87%|████████▋ | 87336/100000 [00:13<00:01, 7079.05it/s]
88%|████████▊ | 88064/100000 [00:13<00:01, 7137.87it/s]
89%|████████▉ | 88780/100000 [00:13<00:01, 7117.55it/s]
89%|████████▉ | 89493/100000 [00:13<00:01, 6597.85it/s]
90%|█████████ | 90200/100000 [00:14<00:01, 6731.06it/s]
91%|█████████ | 90967/100000 [00:14<00:01, 7000.01it/s]
92%|█████████▏| 91722/100000 [00:14<00:01, 7159.89it/s]
92%|█████████▏| 92443/100000 [00:14<00:01, 7149.13it/s]
93%|█████████▎| 93168/100000 [00:14<00:00, 7174.86it/s]
94%|█████████▍| 93888/100000 [00:14<00:00, 6206.67it/s]
95%|█████████▍| 94533/100000 [00:14<00:00, 6243.15it/s]
95%|█████████▌| 95175/100000 [00:14<00:00, 5807.05it/s]
96%|█████████▌| 95773/100000 [00:14<00:00, 5325.68it/s]
96%|█████████▋| 96482/100000 [00:15<00:00, 5778.72it/s]
97%|█████████▋| 97142/100000 [00:15<00:00, 5998.32it/s]
98%|█████████▊| 97758/100000 [00:15<00:00, 5800.69it/s]
98%|█████████▊| 98412/100000 [00:15<00:00, 6001.54it/s]
99%|█████████▉| 99146/100000 [00:15<00:00, 6377.10it/s]
100%|█████████▉| 99793/100000 [00:15<00:00, 6112.05it/s]
100%|██████████| 100000/100000 [00:15<00:00, 6407.43it/s]