创建 Multigrid 新环境#

自定义环境。

import sys
from pathlib import Path
from IPython import display

root_dir = Path(".").resolve()
sys.path.extend([str(root_dir.parents[2]/"tests/gym-multigrid")])
temp_dir = root_dir/"images"

第一步:Multigrid 创建新世界#

  • core/agent.py: 如果现有的动作类不满足需求,添加新世界的动作类。

  • core/constants.py:

    • 如果有对象特定的状态,定义 state_to_idx_{yourChosenName} 字典。

    • 如果你的环境需要尚未定义的对象,将新对象的条目添加到 OBJECT_TO_STR 字典中。

  • core/object.py

    • 为新对象(如果有)定义类。

    • 对象需要:类型、颜色、位置、编码、解码、渲染。

    • 对象属性通过以下方式定义:can_overlap(可重叠)、can_pickup(可拾取)、can_contain(可容纳)、see_behind(可见背面)。

    • 对象可以具有:contains(包含)、toggle(切换)。

  • core/world.py

    • 如果现有世界不满足需求,添加一个新世界。

    • 一个世界定义了对象、颜色和编码大小。

    • 注意:编码层用于捕获不同的事物。第1层用于单元格中的对象类型,第2层用于颜色,第3层用于代理方向。我们还没有使用4、5和6层,但如果需要,它们可以用于更多特性。

  • core/grid.py

    • 这是新环境的基础结构。

    • 方法包括:copy(复制)、get(获取)、set(设置)、rotate(旋转)、slice(切片)、render(渲染)、encode(编码)。

    • 还有:horz_wall(水平墙)、vert_wall(垂直墙)、wall_rect(矩形墙)。

第二步:Multigrid 创建环境#

gym_multigrid/gym_multigrid/envs/ 目录下创建名为 {yourChosenName}.py 的文件。编写继承自 MultigridEnv 的环境类。

  • 在调用父类的 __init__ 方法时,您应该指定以下参数:

    • 代理列表

    • 网格尺寸

    • 是否使用完全或部分可观测性

    • 每个时间步长的数量

    • 上面定义的动作和世界类

from gym_multigrid.multigrid import MultiGridEnv
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 1
----> 1 from gym_multigrid.multigrid import MultiGridEnv

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
     15 from gym_multigrid.core.constants import *
     18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
     22     """
     23     2D grid world game environment
     24     """
     26     metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
    393     world_cell = self.grid.get(x, y)
    395     return obs_cell is not None and obs_cell.type == world_cell.type
    397 def step(
    398     self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
    400     self.step_count += 1
    402     order = np.random.permutation(len(actions))

File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
MultiGridEnv?
Object `MultiGridEnv` not found.

你可能需要初始化/定义与你的环境相关的其他私有变量。例如,在收集游戏中,我们需要跟踪以下内容:

from gym_multigrid.envs.collect_game import CollectGameEnv
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 from gym_multigrid.envs.collect_game import CollectGameEnv

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
      1 import numpy as np
      2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
      5 from gym_multigrid.core.world import CollectWorld
      6 from gym_multigrid.core.agent import CollectActions, Agent

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
     15 from gym_multigrid.core.constants import *
     18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
     22     """
     23     2D grid world game environment
     24     """
     26     metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
    393     world_cell = self.grid.get(x, y)
    395     return obs_cell is not None and obs_cell.type == world_cell.type
    397 def step(
    398     self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
    400     self.step_count += 1
    402     order = np.random.permutation(len(actions))

File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
CollectGameEnv.__init__?
Object `CollectGameEnv.__init__` not found.

_gen_grid()#

你必须实现这个方法,因为它没有在MultiGridEnv父类中定义。这个方法在env.reset()期间默认被调用。在这里,你需要放置所有存在于网格世界中的对象和代理。

例如,在收集游戏中,我们定义了四个边界墙,放置球体,然后放置代理。

CollectGameEnv._gen_grid??
Object `CollectGameEnv._gen_grid` not found.

place_obj() 方法由父类定义,具有以下参数:

MultiGridEnv.place_obj??
Object `MultiGridEnv.place_obj` not found.

默认情况下,该方法通过反复随机均匀地采样位置来尝试将对象放置在网格中,直到找到一个空闲的网格单元格。

如果你知道对象的坐标,你应该改用这个方法:

MultiGridEnv.put_obj??
Object `MultiGridEnv.put_obj` not found.

对于放置代理,根据需要使用上述两种方法调用此方法:

MultiGridEnv.place_agent??
Object `MultiGridEnv.place_agent` not found.

_reward()#

MultiGridEnv._reward??
Object `MultiGridEnv._reward` not found.

当达到目标状态时,此方法会被调用。current_agent 指定哪个代理接收奖励。

如果你的环境有不同的奖励结构,你应该重写这个方法。

step()#

该方法对你的环境动态至关重要。你应该定义这个方法,并且如果 MultiGridEnvstep() 方法可以按照你的环境所需的方式处理动作执行,你也可以调用它。

step() 方法的唯一必需参数是执行的动作列表。默认情况下,这些动作以随机顺序执行或者在达到最大时间步数时结束这一 episode:

MultiGridEnv.step??
Object `MultiGridEnv.step` not found.

reset()#

step 方法一样,你应该为你的环境实现一个方法,并且也可以调用 MultiGridEnvreset 方法,因为它重置了其他变量。

例如,在收集游戏中,我们重置了 collected_balls 的数量和 info 字典:

CollectGameEnv.reset??
Object `CollectGameEnv.reset` not found.

state/obs 编码#

默认的网格编码是一个形状为高度 x 宽度 x encode_dim 的 numpy数组。该方法还考虑了部分可观测性。你可能需要编写一个方法,将这个默认编码转换为最适合你的环境和代理算法的格式。

第三步:注册环境#

gym_multigrid/gym_multigrid/__init.py 中添加一行代码,以在 gymnasium 上注册新创建的环境。

# Collect game with 2 agents and 3 object types
# ----------------------------------------
register(
    id="multigrid-collect-v0",
    entry_point="gym_multigrid.envs:CollectGameEvenDist",
    max_episode_steps=100,
    kwargs={
        "size": 10,
        "num_balls": 15,
        "agents_index": [3, 5],  # green, purple
        "balls_index": [0, 1, 2],  # red, orange, yellow
        "balls_reward": [1, 1, 1],
        "respawn": False,
    },
)