gym-multigrid
简介
gym-multigrid 旨在提供轻量级、多智能体的网格环境。它最初基于一个已有的 multigrid 网格环境,但此后经过了大量修改和扩展,已远超原始环境的范围。
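可以直接安装 gym-multigrid(注意:该库代码仍引用了 NumPy 2.0 中已移除的 `np.float_`,在 NumPy ≥ 2.0 下导入 `gym_multigrid.envs` 会抛出 `AttributeError`,需搭配 `numpy<2` 使用或自行补上该别名):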
pip install gym-multigrid
亦可克隆到本地:
git clone https://github.com/Tran-Research-Group/gym-multigrid.git
import sys
from pathlib import Path
from IPython import display
# Notebook working directory; the gym-multigrid checkout is expected three
# levels up under ``tests/gym-multigrid`` and is made importable here.
root_dir = Path(".").resolve()
sys.path.extend([str(root_dir.parents[2]/"tests/gym-multigrid")])

# Compatibility shim: the checked-out gym_multigrid code still references
# ``np.float_`` (e.g. in the ``MultiGridEnv.step`` return annotation), which
# NumPy 2.0 removed, so importing ``gym_multigrid.envs`` raises AttributeError.
# Restore the alias before any gym_multigrid import; no-op on NumPy 1.x.
import numpy as np
if not hasattr(np, "float_"):
    np.float_ = np.float64

# Output directory for generated animations/plots. Create it up front so the
# later ``mkdir`` calls on its subdirectories never hit a missing parent.
temp_dir = root_dir/"images"
temp_dir.mkdir(parents=True, exist_ok=True)
gym-multigrid
Capture-the-Flag (CtF) 环境
import numpy as np
import imageio
from gym_multigrid.envs.ctf import Ctf1v1Env, CtFMvNEnv
from gym_multigrid.policy.ctf.heuristic import (
FightPolicy,
CapturePolicy,
PatrolPolicy,
RwPolicy,
PatrolFightPolicy,
)
from gym_multigrid.utils.map import load_text_map
import matplotlib.pyplot as plt
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[2], line 4
1 import numpy as np
2 import imageio
----> 4 from gym_multigrid.envs.ctf import Ctf1v1Env, CtFMvNEnv
5 from gym_multigrid.policy.ctf.heuristic import (
6 FightPolicy,
7 CapturePolicy,
(...)
10 PatrolFightPolicy,
11 )
12 from gym_multigrid.utils.map import load_text_map
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
1 import numpy as np
2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
5 from gym_multigrid.core.world import CollectWorld
6 from gym_multigrid.core.agent import CollectActions, Agent
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
简单测试:
# Random-action rollout of the 1v1 CtF environment until the episode ends.
# Build the path as a real ``str`` (it is annotated ``str``), matching the
# other cells.
# NOTE(review): the board is resolved against the notebook's ``root_dir``,
# not the gym-multigrid checkout added to sys.path — confirm
# ``tests/assets/board.txt`` exists there.
map_path: str = f"{root_dir}/tests/assets/board.txt"

env = Ctf1v1Env(
    map_path=map_path, render_mode="human", observation_option="flattened"
)
obs, _ = env.reset()
env.render()
while True:
    # Sample one action uniformly from the environment's action set.
    action = np.random.choice(list(env.actions_set))
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        break
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[3], line 3
1 map_path: str = root_dir/"tests/assets/board.txt"
----> 3 env = Ctf1v1Env(
4 map_path=map_path, render_mode="human", observation_option="flattened"
5 )
6 obs, _ = env.reset()
7 env.render()
NameError: name 'Ctf1v1Env' is not defined
测试随机种子的可复现性(使用相同种子两次 `reset` 后,`env.np_random` 应产生相同的随机序列):
# TODO: might be good idea to include seeding test for other environments
def test_ctf_random_seeding(root_dir) -> None:
    """Assert that seeding ``reset`` makes ``env.np_random`` reproducible.

    The 1v1 CtF environment is reset twice with ``seed=1``; the ten floats
    drawn from ``env.np_random`` after each reset must match.

    Args:
        root_dir: Directory expected to contain ``tests/assets/board.txt``.
    """
    board_file: str = f"{root_dir}/tests/assets/board.txt"
    env = Ctf1v1Env(
        map_path=board_file,
        render_mode="human",
        observation_option="flattened",
    )

    def _draws_after_reset() -> np.ndarray:
        # Reseed, then pull a fixed-length sample from the env's generator.
        env.reset(seed=1)
        return env.np_random.random(10)

    # Tuple elements are evaluated left to right: reset/draw, then reset/draw.
    first, second = _draws_after_reset(), _draws_after_reset()
    np.testing.assert_allclose(first, second)
test_ctf_random_seeding(root_dir)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[4], line 14
10 array2 = env.np_random.random(10)
12 np.testing.assert_allclose(array1, array2)
---> 14 test_ctf_random_seeding(root_dir)
Cell In[4], line 4, in test_ctf_random_seeding(root_dir)
2 def test_ctf_random_seeding(root_dir) -> None:
3 map_path: str = f"{root_dir}/tests/assets/board.txt"
----> 4 env = Ctf1v1Env(
5 map_path=map_path, render_mode="human", observation_option="flattened"
6 )
7 env.reset(seed=1)
8 array1 = env.np_random.random(10)
NameError: name 'Ctf1v1Env' is not defined
多对多(M vs N)CtF 环境:
def test_ctf_mvn_human(root_dir) -> None:
    """Run one random-action 2v2 CtF episode in ``human`` render mode.

    Each step samples an action from ``env.action_space``, steps the
    environment and renders; the episode stops on ``terminated`` or
    ``truncated``.

    Args:
        root_dir: Directory expected to contain ``tests/assets/board.txt``.
    """
    board_file: str = f"{root_dir}/tests/assets/board.txt"
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=board_file,
        render_mode="human",
        observation_option="flattened",
    )
    _obs, _info = env.reset()
    env.render()

    terminated = truncated = False
    while not (terminated or truncated):
        sampled = env.action_space.sample()
        _obs, _reward, terminated, truncated, _info = env.step(sampled)
        env.render()

    assert terminated or truncated
test_ctf_mvn_human(root_dir)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[5], line 22
18 break
20 assert terminated or truncated
---> 22 test_ctf_mvn_human(root_dir)
Cell In[5], line 3, in test_ctf_mvn_human(root_dir)
1 def test_ctf_mvn_human(root_dir) -> None:
2 map_path: str = f"{root_dir}/tests/assets/board.txt"
----> 3 env = CtFMvNEnv(
4 num_blue_agents=2,
5 num_red_agents=2,
6 map_path=map_path,
7 render_mode="human",
8 observation_option="flattened",
9 )
10 obs, _ = env.reset()
11 env.render()
NameError: name 'CtFMvNEnv' is not defined
def test_ctf_mvn_rgb(root_dir, temp_dir) -> None:
    """Roll out a random 2v2 CtF episode in ``rgb_array`` mode and save a GIF.

    Args:
        root_dir: Directory expected to contain ``tests/assets/board.txt``.
        temp_dir: Output directory (``str`` or ``Path``); the GIF is written
            to ``<temp_dir>/animations/ctf_mvn.gif``.
    """
    map_path: str = f"{root_dir}/tests/assets/board.txt"
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="rgb_array",
        observation_option="flattened",
    )
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    animation_dir = Path(temp_dir)/"animations"
    # parents=True: temp_dir itself may not exist yet.
    animation_dir.mkdir(parents=True, exist_ok=True)
    gif_path = animation_dir/"ctf_mvn.gif"
    imageio.mimsave(gif_path, frames, duration=0.5)
    assert gif_path.exists()
test_ctf_mvn_rgb(root_dir, temp_dir)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[6], line 23
20 imageio.mimsave(temp_dir/f"animations/ctf_mvn.gif", frames, duration=0.5)
22 assert Path(temp_dir/f"animations/ctf_mvn.gif").exists()
---> 23 test_ctf_mvn_rgb(root_dir, temp_dir)
Cell In[6], line 3, in test_ctf_mvn_rgb(root_dir, temp_dir)
1 def test_ctf_mvn_rgb(root_dir, temp_dir) -> None:
2 map_path: str = f"{root_dir}/tests/assets/board.txt"
----> 3 env = CtFMvNEnv(
4 num_blue_agents=2,
5 num_red_agents=2,
6 map_path=map_path,
7 render_mode="rgb_array",
8 observation_option="flattened",
9 )
10 obs, _ = env.reset()
11 frames = [env.render()]
NameError: name 'CtFMvNEnv' is not defined
display.Image(temp_dir/f"animations/ctf_mvn.gif")
def test_fight_policy(root_dir, temp_dir) -> None:
    """Play a random 2v2 CtF episode against a scripted red team; save a GIF.

    ``enemy_policies`` is ``[FightPolicy(), RwPolicy()]`` (presumably one
    policy per red agent — confirm against ``CtFMvNEnv``). Frames go to
    ``<temp_dir>/animations/ctf_mvn_fight_policy.gif``.

    Args:
        root_dir: Directory expected to contain ``tests/assets/board.txt``.
        temp_dir: ``pathlib.Path`` output directory (joined with ``/``).
    """
    # parents=True: temp_dir itself may not have been created yet.
    (temp_dir/"animations").mkdir(parents=True, exist_ok=True)
    animation_path = temp_dir/"animations/ctf_mvn_fight_policy.gif"
    map_path = f"{root_dir}/tests/assets/board.txt"

    # Loaded but not used below: FightPolicy() is constructed without a map.
    _field_map = load_text_map(map_path)
    enemy_policy = FightPolicy()

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=[enemy_policy, RwPolicy()],
    )
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break
    imageio.mimsave(animation_path, frames, duration=0.5)
    assert animation_path.exists()
test_fight_policy(root_dir, temp_dir)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[8], line 30
27 imageio.mimsave(animation_path, frames, duration=0.5)
28 assert animation_path.exists()
---> 30 test_fight_policy(root_dir, temp_dir)
Cell In[8], line 6, in test_fight_policy(root_dir, temp_dir)
3 animation_path = temp_dir/"animations/ctf_mvn_fight_policy.gif"
4 map_path = f"{root_dir}/tests/assets/board.txt"
----> 6 _field_map = load_text_map(map_path)
7 enemy_policy = FightPolicy()
9 env = CtFMvNEnv(
10 num_blue_agents=2,
11 num_red_agents=2,
(...)
15 enemy_policies=[enemy_policy, RwPolicy()],
16 )
NameError: name 'load_text_map' is not defined
display.Image(temp_dir/"animations/ctf_mvn_fight_policy.gif")
def test_capture_policy(map_path, animation_path) -> None:
    """Play a random 2v2 CtF episode against a CapturePolicy red team; save a GIF.

    ``enemy_policies`` is ``[CapturePolicy(field_map), RwPolicy()]``
    (presumably one policy per red agent — confirm against ``CtFMvNEnv``).

    Args:
        map_path: Path to the text board file.
        animation_path: Destination GIF path; its parent directory is
            created if missing.
    """
    field_map = load_text_map(map_path)
    enemy_policy = CapturePolicy(field_map)
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=[enemy_policy, RwPolicy()],
    )
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    gif_path = Path(animation_path)
    gif_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)
    # Must *call* exists(): the bare bound method is always truthy.
    assert gif_path.exists()
animation_path: str = f"{temp_dir}/animations/ctf_mvn_capture_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_capture_policy(map_path, animation_path)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[10], line 29
27 animation_path: str = f"{temp_dir}/animations/ctf_mvn_capture_policy.gif"
28 map_path: str = f"{root_dir}/tests/assets/board.txt"
---> 29 test_capture_policy(map_path, animation_path)
Cell In[10], line 2, in test_capture_policy(map_path, animation_path)
1 def test_capture_policy(map_path, animation_path) -> None:
----> 2 field_map = load_text_map(map_path)
3 enemy_policy = CapturePolicy(field_map)
5 env = CtFMvNEnv(
6 num_blue_agents=2,
7 num_red_agents=2,
(...)
11 enemy_policies=[enemy_policy, RwPolicy()],
12 )
NameError: name 'load_text_map' is not defined
display.Image(f"{temp_dir}/animations/ctf_mvn_capture_policy.gif")
def test_patrol_policy(animation_path, map_path) -> None:
    """Play a random 2v2 CtF episode against a PatrolPolicy red team; save a GIF.

    Note the argument order is ``(animation_path, map_path)``.
    ``enemy_policies`` receives a single policy object rather than a list —
    NOTE(review): confirm ``CtFMvNEnv`` accepts both forms.

    Args:
        animation_path: Destination GIF path; its parent directory is
            created if missing.
        map_path: Path to the text board file.
    """
    field_map = load_text_map(map_path)
    enemy_policy = PatrolPolicy(field_map)
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=enemy_policy,
    )
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    gif_path = Path(animation_path)
    gif_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)
    # Must *call* exists(): the bare bound method is always truthy.
    assert gif_path.exists()
animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_patrol_policy(animation_path, map_path)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[12], line 29
27 animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_policy.gif"
28 map_path: str = f"{root_dir}/tests/assets/board.txt"
---> 29 test_patrol_policy(animation_path, map_path)
Cell In[12], line 2, in test_patrol_policy(animation_path, map_path)
1 def test_patrol_policy(animation_path, map_path) -> None:
----> 2 field_map = load_text_map(map_path)
3 enemy_policy = PatrolPolicy(field_map)
5 env = CtFMvNEnv(
6 num_blue_agents=2,
7 num_red_agents=2,
(...)
11 enemy_policies=enemy_policy,
12 )
NameError: name 'load_text_map' is not defined
display.Image(f"{temp_dir}/animations/ctf_mvn_patrol_policy.gif")
def test_patrol_fight_policy(map_path, animation_path) -> None:
    """Play a random 2v2 CtF episode against a PatrolFightPolicy red team; save a GIF.

    ``enemy_policies`` receives a single policy object rather than a list —
    NOTE(review): confirm ``CtFMvNEnv`` accepts both forms.

    Args:
        map_path: Path to the text board file.
        animation_path: Destination GIF path; its parent directory is
            created if missing.
    """
    field_map = load_text_map(map_path)
    enemy_policy = PatrolFightPolicy(field_map)
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=enemy_policy,
    )
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    gif_path = Path(animation_path)
    gif_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)
    # Must *call* exists(): the bare bound method is always truthy.
    assert gif_path.exists()
animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_fight_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_patrol_fight_policy(map_path, animation_path)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[14], line 29
27 animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_fight_policy.gif"
28 map_path: str = f"{root_dir}/tests/assets/board.txt"
---> 29 test_patrol_fight_policy(map_path, animation_path)
Cell In[14], line 2, in test_patrol_fight_policy(map_path, animation_path)
1 def test_patrol_fight_policy(map_path, animation_path) -> None:
----> 2 field_map = load_text_map(map_path)
3 enemy_policy = PatrolFightPolicy(field_map)
5 env = CtFMvNEnv(
6 num_blue_agents=2,
7 num_red_agents=2,
(...)
11 enemy_policies=enemy_policy,
12 )
NameError: name 'load_text_map' is not defined
display.Image(f"{temp_dir}/animations/ctf_mvn_patrol_fight_policy.gif")
def test_mvn_ctf_render(map_path, img_save_path) -> None:
    """Take one random step in 2v2 CtF and save the rendered frame as an image.

    Args:
        map_path: Path to the text board file.
        img_save_path: Destination image path passed to ``plt.imsave``
            (saved with ``dpi=600``).
    """
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
    )
    obs, _ = env.reset()
    for _ in range(1):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
    # NOTE(review): relies on render() returning an image array in "human"
    # mode, since plt.imsave needs array data.
    img = env.render()
    plt.imsave(img_save_path, img, dpi=600)
    # Must *call* exists(): the bare bound method is always truthy.
    assert Path(img_save_path).exists()
(temp_dir/"plots").mkdir(parents=True, exist_ok=True)
img_save_path: str = f"{temp_dir}/plots/mvn_ctf_render.png"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_mvn_ctf_render(map_path, img_save_path)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[16], line 23
21 img_save_path: str = f"{temp_dir}/plots/mvn_ctf_render.png"
22 map_path: str = f"{root_dir}/tests/assets/board.txt"
---> 23 test_mvn_ctf_render(map_path, img_save_path)
Cell In[16], line 2, in test_mvn_ctf_render(map_path, img_save_path)
1 def test_mvn_ctf_render(map_path, img_save_path) -> None:
----> 2 env = CtFMvNEnv(
3 num_blue_agents=2,
4 num_red_agents=2,
5 map_path=map_path,
6 render_mode="human",
7 observation_option="flattened",
8 )
9 obs, _ = env.reset()
11 for _ in range(1):
NameError: name 'CtFMvNEnv' is not defined
display.Image(f"{temp_dir}/plots/mvn_ctf_render.png")
from gym_multigrid.envs.maze import MazeSingleAgentEnv

# Random-action rollout of the single-agent maze, configured with
# max_steps=200 and step_penalty_ratio=0, rendering after every step.
map_path: str = f"{root_dir}/tests/assets/board_maze.txt"
env = MazeSingleAgentEnv(
    map_path=map_path,
    render_mode="human",
    max_steps=200,
    step_penalty_ratio=0,
)
obs, _ = env.reset()
env.render()

terminated = truncated = False
while not (terminated or truncated):
    action = np.random.choice(list(env.actions_set))
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[18], line 1
----> 1 from gym_multigrid.envs.maze import MazeSingleAgentEnv
4 map_path: str = f"{root_dir}/tests/assets/board_maze.txt"
6 env = MazeSingleAgentEnv(
7 map_path=map_path, render_mode="human", max_steps=200, step_penalty_ratio=0
8 )
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
1 import numpy as np
2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
5 from gym_multigrid.core.world import CollectWorld
6 from gym_multigrid.core.agent import CollectActions, Agent
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
Collect 游戏
| Attribute | Description |
|---|---|
| Action Space | |
| Observation Space | |
| Observation Encoding | |
| Reward | |
| Number of Agents | |
| Termination Condition | |
| Truncation Steps | |
| Creation | |
智能体在网格中移动以收集物体。是否在物体被收集后于随机位置重新生成由 `respawn` 参数控制;下例中 `respawn` 设为 `False`,即被收集的物体不会重新生成。
import gymnasium as gym
from gym_multigrid.envs.collect_game import CollectGameEnv

# Collect game: agents move around the grid picking up balls; the episode's
# frames are saved to <temp_dir>/animations/multigrid-collect.gif.
kwargs={
    "size": 10,
    "num_balls": [15,],
    "agents_index": [3, 5], # green, purple
    "balls_index": [0, 1, 2], # red, orange, yellow
    "balls_reward": [1, 1, 1],
    "respawn": False,
}
env = CollectGameEnv(**kwargs)
# Reset before the first render, as every other rollout in this notebook does,
# so frame 0 shows the initial state of the episode rather than a pre-reset env.
# NOTE(review): no render_mode is passed here — confirm render() returns RGB
# frames that imageio can write.
obs, info = env.reset()
frames = [env.render()]
while True:
    # One independently sampled action per agent.
    actions = [env.action_space.sample() for _ in env.agents]
    obs, reward, terminated, truncated, info = env.step(actions)
    frames.append(env.render())
    if terminated or truncated:
        print(f"episode ended after {env.step_count} steps")
        print(f"agents collected {env.collected_balls} objects")
        break
temp_dir = Path(temp_dir)
(temp_dir/"animations").mkdir(parents=True, exist_ok=True)
imageio.mimsave(temp_dir/"animations/multigrid-collect.gif", frames, duration=0.5)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[19], line 2
1 import gymnasium as gym
----> 2 from gym_multigrid.envs.collect_game import CollectGameEnv
5 kwargs={
6 "size": 10,
7 "num_balls": [15,],
(...)
11 "respawn": False,
12 }
13 env = CollectGameEnv(**kwargs)
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
1 import numpy as np
2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
5 from gym_multigrid.core.world import CollectWorld
6 from gym_multigrid.core.agent import CollectActions, Agent
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
# Show the collect-game animation written by the rollout cell.
display.Image(temp_dir/"animations"/"multigrid-collect.gif")