创建 Multigrid 新环境#
自定义环境。
import sys
from pathlib import Path
from IPython import display
root_dir = Path(".").resolve()
sys.path.extend([str(root_dir.parents[2]/"tests/gym-multigrid")])
temp_dir = root_dir/"images"
第一步:Multigrid 创建新世界#
core/agent.py
: 如果现有的动作类不满足需求,添加新世界的动作类。core/constants.py
:如果有对象特定的状态,定义
state_to_idx_{yourChosenName}
字典。如果你的环境需要尚未定义的对象,将新对象的条目添加到
OBJECT_TO_STR
字典中。
core/object.py
为新对象(如果有)定义类。
对象需要:类型、颜色、位置、编码、解码、渲染。
对象属性通过以下方式定义:can_overlap(可重叠)、can_pickup(可拾取)、can_contain(可容纳)、see_behind(可见背面)。
对象可以具有:contains(包含)、toggle(切换)。
core/world.py
如果现有世界不满足需求,添加一个新世界。
一个世界定义了对象、颜色和编码大小。
注意:编码层用于捕获不同的事物。第1层用于单元格中的对象类型,第2层用于颜色,第3层用于代理方向。我们还没有使用4、5和6层,但如果需要,它们可以用于更多特性。
core/grid.py
这是新环境的基础结构。
方法包括:copy(复制)、get(获取)、set(设置)、rotate(旋转)、slice(切片)、render(渲染)、encode(编码)。
还有:horz_wall(水平墙)、vert_wall(垂直墙)、wall_rect(矩形墙)。
第二步:Multigrid 创建环境#
在 gym_multigrid/gym_multigrid/envs/
目录下创建名为 {yourChosenName}.py
的文件。编写继承自 MultigridEnv
的环境类。
在调用父类的
__init__
方法时,您应该指定以下参数:代理列表
网格尺寸
是否使用完全或部分可观测性
每个时间步长的数量
上面定义的动作和世界类
from gym_multigrid.multigrid import MultiGridEnv
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 from gym_multigrid.multigrid import MultiGridEnv
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
MultiGridEnv?
Object `MultiGridEnv` not found.
你可能需要初始化/定义与你的环境相关的其他私有变量。例如,在收集游戏中,我们需要跟踪以下内容:
from gym_multigrid.envs.collect_game import CollectGameEnv
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 from gym_multigrid.envs.collect_game import CollectGameEnv
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
1 import numpy as np
2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
5 from gym_multigrid.core.world import CollectWorld
6 from gym_multigrid.core.agent import CollectActions, Agent
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
CollectGameEnv.__init__?
Object `CollectGameEnv.__init__` not found.
_gen_grid()
#
你必须实现这个方法,因为它没有在MultiGridEnv父类中定义。这个方法在env.reset()期间默认被调用。在这里,你需要放置所有存在于网格世界中的对象和代理。
例如,在收集游戏中,我们定义了四个边界墙,放置球体,然后放置代理。
CollectGameEnv._gen_grid??
Object `CollectGameEnv._gen_grid` not found.
place_obj()
方法由父类定义,具有以下参数:
MultiGridEnv.place_obj??
Object `MultiGridEnv.place_obj` not found.
默认情况下,该方法通过反复随机均匀地采样位置来尝试将对象放置在网格中,直到找到一个空闲的网格单元格。
如果你知道对象的坐标,你应该改用这个方法:
MultiGridEnv.put_obj??
Object `MultiGridEnv.put_obj` not found.
对于放置代理,根据需要使用上述两种方法调用此方法:
MultiGridEnv.place_agent??
Object `MultiGridEnv.place_agent` not found.
_reward()
#
MultiGridEnv._reward??
Object `MultiGridEnv._reward` not found.
当达到目标状态时,此方法会被调用。current_agent
指定哪个代理接收奖励。
如果你的环境有不同的奖励结构,你应该重写这个方法。
step()
#
该方法对你的环境动态至关重要。你应该定义这个方法,并且如果 MultiGridEnv
的 step()
方法可以按照你的环境所需的方式处理动作执行,你也可以调用它。
step()
方法的唯一必需参数是执行的动作列表。默认情况下,这些动作以随机顺序执行或者在达到最大时间步数时结束这一 episode:
MultiGridEnv.step??
Object `MultiGridEnv.step` not found.
reset()
#
与 step
方法一样,你应该为你的环境实现一个方法,并且也可以调用 MultiGridEnv
的 reset
方法,因为它重置了其他变量。
例如,在收集游戏中,我们重置了 collected_balls
的数量和 info
字典:
CollectGameEnv.reset??
Object `CollectGameEnv.reset` not found.
state
/obs
编码#
默认的网格编码是一个形状为高度 x 宽度 x encode_dim 的 numpy
数组。该方法还考虑了部分可观测性。你可能需要编写一个方法,将这个默认编码转换为最适合你的环境和代理算法的格式。
第三步:注册环境#
在 gym_multigrid/gym_multigrid/__init.py
中添加一行代码,以在 gymnasium
上注册新创建的环境。
# Collect game with 2 agents and 3 object types
# ----------------------------------------
register(
id="multigrid-collect-v0",
entry_point="gym_multigrid.envs:CollectGameEvenDist",
max_episode_steps=100,
kwargs={
"size": 10,
"num_balls": 15,
"agents_index": [3, 5], # green, purple
"balls_index": [0, 1, 2], # red, orange, yellow
"balls_reward": [1, 1, 1],
"respawn": False,
},
)