gym-multigrid 简介#

gym-multigrid 旨在提供轻量级的多智能体网格环境。它最初基于一个已有的多网格(multigrid)环境,但此后经过大量修改和扩展,已超出原始环境的范围。

可以直接安装 gym-multigrid:

pip install gym-multigrid

亦可克隆到本地:

git clone https://github.com/Tran-Research-Group/gym-multigrid.git
import sys
from pathlib import Path
from IPython import display

# Absolute path of the directory this notebook is executed from.
root_dir = Path(".").resolve()
# Put the gym-multigrid checkout (three directory levels up) on the import path.
sys.path.append(str(root_dir.parents[2] / "tests/gym-multigrid"))
# Directory, next to the notebook, that later cells write into.
temp_dir = root_dir / "images"

gym-multigrid Capture-the-Flag (CtF) 环境#

import numpy as np
import imageio

from gym_multigrid.envs.ctf import Ctf1v1Env, CtFMvNEnv
from gym_multigrid.policy.ctf.heuristic import (
    FightPolicy,
    CapturePolicy,
    PatrolPolicy,
    RwPolicy,
    PatrolFightPolicy,
)
from gym_multigrid.utils.map import load_text_map
import matplotlib.pyplot as plt

简单测试:

# Text map for the 1-vs-1 Capture-the-Flag demo. This is a Path (built with
# the ``/`` operator), so it is annotated as such rather than as ``str``.
map_path: Path = root_dir / "tests/assets/board.txt"

env = Ctf1v1Env(
    map_path=map_path, render_mode="human", observation_option="flattened"
)
obs, _ = env.reset()
env.render()

# Step with randomly chosen actions until the episode terminates or truncates.
done = False
while not done:
    action = np.random.choice(list(env.actions_set))
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    done = terminated or truncated
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png

设置随机种子:

# TODO: might be good idea to include seeding test for other environments
def test_ctf_random_seeding(root_dir) -> None:
    """Check that resetting with the same seed reproduces the env's random draws.

    Args:
        root_dir: Directory that contains ``tests/assets/board.txt``.

    Raises:
        AssertionError: If the two post-reset samples are not (almost) equal.
    """
    env = Ctf1v1Env(
        map_path=f"{root_dir}/tests/assets/board.txt",
        render_mode="human",
        observation_option="flattened",
    )

    def sample_after_reset():
        # Re-seed through reset(), then draw ten numbers from the env's RNG.
        env.reset(seed=1)
        return env.np_random.random(10)

    first_draw = sample_after_reset()
    second_draw = sample_after_reset()

    np.testing.assert_allclose(first_draw, second_draw)


test_ctf_random_seeding(root_dir)

MvN CtF:

def test_ctf_mvn_human(root_dir) -> None:
    """Run one 2-vs-2 CtF episode with sampled actions in "human" render mode.

    Args:
        root_dir: Directory that contains ``tests/assets/board.txt``.
    """
    env = CtFMvNEnv(
        map_path=f"{root_dir}/tests/assets/board.txt",
        num_blue_agents=2,
        num_red_agents=2,
        render_mode="human",
        observation_option="flattened",
    )
    env.reset()
    env.render()

    # Keep sampling actions until the env reports the end of the episode.
    terminated = truncated = False
    while not (terminated or truncated):
        sampled_action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(sampled_action)
        env.render()

    assert terminated or truncated


test_ctf_mvn_human(root_dir)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
def test_ctf_mvn_rgb(root_dir, temp_dir) -> None:
    """Roll out a 2-vs-2 CtF episode in "rgb_array" mode and save it as a GIF.

    Args:
        root_dir: Directory that contains ``tests/assets/board.txt``.
        temp_dir: Directory under which ``animations/ctf_mvn.gif`` is written.

    Raises:
        AssertionError: If the GIF does not exist after saving.
    """
    map_path: str = f"{root_dir}/tests/assets/board.txt"
    gif_path = Path(temp_dir) / "animations" / "ctf_mvn.gif"
    # parents=True: ``temp_dir`` itself may not exist yet, in which case a
    # plain mkdir(exist_ok=True) raises FileNotFoundError.
    gif_path.parent.mkdir(parents=True, exist_ok=True)

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="rgb_array",
        observation_option="flattened",
    )
    obs, _ = env.reset()
    frames = [env.render()]
    # Collect one rendered frame per step until the episode ends.
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    imageio.mimsave(gif_path, frames, duration=0.5)

    assert gif_path.exists()


test_ctf_mvn_rgb(root_dir, temp_dir)
display.Image(temp_dir/"animations/ctf_mvn.gif")
../../_images/bae109e2d1a257fc88d8ef7d782c6f0ff385487898ae065a1ef2eccb2a33cfd9.gif
def test_fight_policy(root_dir, temp_dir) -> None:
    """Record a 2-vs-2 CtF episode whose enemies use FightPolicy and RwPolicy.

    Args:
        root_dir: Directory that contains ``tests/assets/board.txt``.
        temp_dir: Directory under which
            ``animations/ctf_mvn_fight_policy.gif`` is written.

    Raises:
        AssertionError: If the GIF does not exist after saving.
    """
    # Accept both str and Path, like test_ctf_mvn_rgb does.
    animation_path = Path(temp_dir) / "animations" / "ctf_mvn_fight_policy.gif"
    # parents=True so this also works when ``temp_dir`` does not exist yet.
    animation_path.parent.mkdir(parents=True, exist_ok=True)
    map_path = f"{root_dir}/tests/assets/board.txt"

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=[FightPolicy(), RwPolicy()],
    )

    # NOTE(review): frames come from render() in "human" mode — confirm that
    # render() returns an image array in this mode.
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    imageio.mimsave(animation_path, frames, duration=0.5)
    assert animation_path.exists()


test_fight_policy(root_dir, temp_dir)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
display.Image(temp_dir/"animations/ctf_mvn_fight_policy.gif")
../../_images/785375f1e82ed325f50bc3ec63508577c07d12c24406589a8a0108f117808082.gif

def test_capture_policy(map_path, animation_path) -> None:
    """Record a 2-vs-2 CtF episode whose enemies use CapturePolicy and RwPolicy.

    Args:
        map_path: Path of the text map; it is loaded for ``CapturePolicy`` and
            also passed to the environment.
        animation_path: Destination file for the recorded GIF.

    Raises:
        AssertionError: If the GIF does not exist after saving.
    """
    field_map = load_text_map(map_path)
    enemy_policy = CapturePolicy(field_map)

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=[enemy_policy, RwPolicy()],
    )

    # NOTE(review): frames come from render() in "human" mode — confirm that
    # render() returns an image array in this mode.
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    output_file = Path(animation_path)
    # Create the destination directory so saving does not depend on an
    # earlier cell having created it.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)

    # ``exists`` must be called; the bare bound method is always truthy.
    assert output_file.exists()


animation_path: str = f"{temp_dir}/animations/ctf_mvn_capture_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_capture_policy(map_path, animation_path)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
display.Image(f"{temp_dir}/animations/ctf_mvn_capture_policy.gif")
../../_images/0960fbc1595f5a4b03f9be61a98de24d28c9099313ffeff1d41747d691303782.gif
def test_patrol_policy(animation_path, map_path) -> None:
    """Record a 2-vs-2 CtF episode whose enemies use PatrolPolicy.

    Args:
        animation_path: Destination file for the recorded GIF.
        map_path: Path of the text map; it is loaded for ``PatrolPolicy`` and
            also passed to the environment.

    Raises:
        AssertionError: If the GIF does not exist after saving.
    """
    field_map = load_text_map(map_path)
    enemy_policy = PatrolPolicy(field_map)

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=enemy_policy,
    )

    # NOTE(review): frames come from render() in "human" mode — confirm that
    # render() returns an image array in this mode.
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    output_file = Path(animation_path)
    # Create the destination directory so saving does not depend on an
    # earlier cell having created it.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)

    # ``exists`` must be called; the bare bound method is always truthy.
    assert output_file.exists()


animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_patrol_policy(animation_path, map_path)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
display.Image(f"{temp_dir}/animations/ctf_mvn_patrol_policy.gif")
../../_images/3fcce336dfda25bd0e165703cc43459508168c744323c16f6468bfb590ace3a0.gif
def test_patrol_fight_policy(map_path, animation_path) -> None:
    """Record a 2-vs-2 CtF episode whose enemies use PatrolFightPolicy.

    Args:
        map_path: Path of the text map; it is loaded for ``PatrolFightPolicy``
            and also passed to the environment.
        animation_path: Destination file for the recorded GIF.

    Raises:
        AssertionError: If the GIF does not exist after saving.
    """
    field_map = load_text_map(map_path)
    enemy_policy = PatrolFightPolicy(field_map)

    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
        enemy_policies=enemy_policy,
    )

    # NOTE(review): frames come from render() in "human" mode — confirm that
    # render() returns an image array in this mode.
    obs, _ = env.reset()
    frames = [env.render()]
    while True:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        frames.append(env.render())
        if terminated or truncated:
            break

    output_file = Path(animation_path)
    # Create the destination directory so saving does not depend on an
    # earlier cell having created it.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(animation_path, frames, duration=0.5)

    # ``exists`` must be called; the bare bound method is always truthy.
    assert output_file.exists()


animation_path: str = f"{temp_dir}/animations/ctf_mvn_patrol_fight_policy.gif"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_patrol_fight_policy(map_path, animation_path)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
display.Image(f"{temp_dir}/animations/ctf_mvn_patrol_fight_policy.gif")
../../_images/2479a7d0da5dcf5aae2c2a6f761fe94d6b7b34907bca7c66a68bb65faa3a820d.gif
def test_mvn_ctf_render(map_path, img_save_path) -> None:
    """Take one sampled step in a 2-vs-2 CtF env and save the rendered image.

    Args:
        map_path: Path of the text map passed to the environment.
        img_save_path: Destination file for the rendered image.

    Raises:
        AssertionError: If the image file does not exist after saving.
    """
    env = CtFMvNEnv(
        num_blue_agents=2,
        num_red_agents=2,
        map_path=map_path,
        render_mode="human",
        observation_option="flattened",
    )
    obs, _ = env.reset()

    # A single step is enough to produce a non-initial state to render.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

    # NOTE(review): the image comes from render() in "human" mode — confirm
    # that render() returns an image array in this mode.
    img = env.render()
    plt.imsave(img_save_path, img, dpi=600)

    # ``exists`` must be called; the bare bound method is always truthy.
    assert Path(img_save_path).exists()


(temp_dir/"plots").mkdir(parents=True, exist_ok=True)
img_save_path: str = f"{temp_dir}/plots/mvn_ctf_render.png"
map_path: str = f"{root_dir}/tests/assets/board.txt"
test_mvn_ctf_render(map_path, img_save_path)
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png
display.Image(f"{temp_dir}/plots/mvn_ctf_render.png")
../../_images/b55310145962a36e2f9734295c584e22c1c130af6764d4cd85e4631a74f76679.png

Maze 环境#

from gym_multigrid.envs.maze import MazeSingleAgentEnv


# Text map used by the single-agent maze demo.
map_path: str = f"{root_dir}/tests/assets/board_maze.txt"

env = MazeSingleAgentEnv(
    map_path=map_path,
    render_mode="human",
    max_steps=200,
    step_penalty_ratio=0,
)
obs, _ = env.reset()
env.render()

# Step with randomly chosen actions until the episode terminates or truncates.
terminated = truncated = False
while not (terminated or truncated):
    action = np.random.choice(list(env.actions_set))
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
../../_images/6637e430861e0933e44455608e05615b9d60f0e521c594c9bae83d689d8f4221.png

Collect 游戏#

Collect Game Respawn

| Attribute | Description |
| --- | --- |
| Action Space | Discrete(4) |
| Observation Space | np.array of shape grid.width x grid.height |
| Observation Encoding | (OBJECT_IDX, COLOR_IDX, STATE) |
| Reward | (0, 1) |
| Number of Agents | 2 |
| Termination Condition | None |
| Truncation Steps | 50 |
| Creation | gymnasium.make("multigrid-collect-respawn-clustered-v0") |

智能体在网格中移动以收集物体。物体被收集后会在随机位置重新生成。

import gymnasium as gym
from gym_multigrid.envs.collect_game import CollectGameEnv


# Constructor arguments for the collect-game environment.
# NOTE(review): the surrounding text describes objects respawning after being
# collected, but this example passes respawn=False — confirm which is intended.
kwargs = {
    "size": 10,
    "num_balls": [15],
    "agents_index": [3, 5],  # green, purple
    "balls_index": [0, 1, 2],  # red, orange, yellow
    "balls_reward": [1, 1, 1],
    "respawn": False,
}
env = CollectGameEnv(**kwargs)
# Reset before the first render so the first recorded frame shows the
# episode's initial state rather than whatever the env held before reset.
obs, info = env.reset()
frames = [env.render()]
while True:
    # One independently sampled action per agent.
    actions = [env.action_space.sample() for _ in env.agents]
    obs, reward, terminated, truncated, info = env.step(actions)
    frames.append(env.render())
    if terminated or truncated:
        print(f"episode ended after {env.step_count} steps")
        print(f"agents collected {env.collected_balls} objects")
        break
temp_dir = Path(temp_dir)
(temp_dir/"animations").mkdir(parents=True, exist_ok=True)
imageio.mimsave(temp_dir/"animations/multigrid-collect.gif", frames, duration=0.5)
display.Image(temp_dir/"animations/multigrid-collect.gif")
../../_images/ef248d774b8406daefe34f8332f686434f31345033da8eab4e7d53e813327266.gif