创建 Collect 环境包装器

创建 Collect 环境包装器

import logging
import sys
from pathlib import Path
from d2py.utils.log_config import config_logging

# Absolute path of the current working directory; all paths below hang off it.
root_dir = Path.cwd().resolve()
# Make the gym-multigrid sources importable: they are expected under
# "tests/gym-multigrid" of the fourth ancestor of the working directory.
sys.path.append(str(root_dir.parents[3] / "tests/gym-multigrid"))

# Output directory for the log file and the rendered animations.
temp_dir = root_dir / "images"
temp_dir.mkdir(parents=True, exist_ok=True)

# Shared logger for the whole notebook, written to <temp_dir>/gym_multigrid.log.
# NOTE(review): maxBytes/backupCount presumably configure log rotation inside
# d2py's config_logging — confirm against that helper.
logger_name = "gym_multigrid"
logger = logging.getLogger(logger_name)
config_logging(
    f'{temp_dir}/{logger_name}.log',
    logger_name,
    maxBytes=50000,
    backupCount=2,
)
from pathlib import Path
import imageio
from gym_multigrid.envs.collect_game import CollectGameEnv

# Constructor arguments for the Collect game.
# NOTE(review): the index lists presumably select entries of gym_multigrid's
# colour table (three agent colours, one ball colour); the original inline
# comments named a different number of colours than there are indices —
# confirm the mapping in gym_multigrid.core.constants.
kwargs = {
    "size": 15,
    "num_balls": [5],
    "agents_index": [1, 2, 3],
    "balls_index": [0],
    "balls_reward": [1],
    "respawn": False,
}
origin_env = CollectGameEnv(**kwargs)

# Reset BEFORE the first render so that the opening frame shows the actual
# initial state of the episode that is played below (rendering an environment
# that has not been reset is outside the Gymnasium API contract).
obs, info = origin_env.reset()
frames = [origin_env.render()]

# Roll out one episode with uniformly sampled actions, one per agent,
# recording a frame after every step.
while True:
    actions = [origin_env.action_space.sample() for _ in origin_env.agents]
    obs, reward, terminated, truncated, info = origin_env.step(actions)
    frames.append(origin_env.render())
    if terminated or truncated:
        logger.info(f"episode ended after {origin_env.step_count} steps")
        logger.info(f"agents collected {origin_env.collected_balls} objects")
        break

# Write the recorded frames as an animated GIF under <temp_dir>/animations.
(temp_dir / "animations").mkdir(parents=True, exist_ok=True)
# NOTE(review): the unit of `duration` (seconds vs. milliseconds) depends on
# the installed imageio version/plugin — confirm the intended frame delay.
imageio.mimsave(temp_dir / "animations/multigrid-collect.gif", frames, duration=0.5)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 3
      1 from pathlib import Path
      2 import imageio
----> 3 from gym_multigrid.envs.collect_game import CollectGameEnv
      5 kwargs={
      6     "size": 15,
      7     "num_balls": [5,],
   (...)
     11     "respawn": False,
     12 }
     13 origin_env = CollectGameEnv(**kwargs)

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
      1 import numpy as np
      2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
      5 from gym_multigrid.core.world import CollectWorld
      6 from gym_multigrid.core.agent import CollectActions, Agent

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
     15 from gym_multigrid.core.constants import *
     18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
     22     """
     23     2D grid world game environment
     24     """
     26     metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}

File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
    393     world_cell = self.grid.get(x, y)
    395     return obs_cell is not None and obs_cell.type == world_cell.type
    397 def step(
    398     self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
    400     self.step_count += 1
    402     order = np.random.permutation(len(actions))

File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
    394     raise AttributeError(__former_attrs__[attr])
    396 if attr in __expired_attributes__:
--> 397     raise AttributeError(
    398         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    399         f"{__expired_attributes__[attr]}"
    400     )
    402 if attr == "chararray":
    403     warnings.warn(
    404         "`np.chararray` is deprecated and will be removed from "
    405         "the main namespace in the future. Use an array with a string "
    406         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
from IPython import display
# Show the GIF written by the rollout cell inline in the notebook output.
display.Image(temp_dir / "animations/multigrid-collect.gif")
../../../_images/0b1520c6582df415c8dd6ff4a2bb0f77045ba5386c09fb972ea861ed0b03d8c4.gif

定义记录每一步结果的环境包装器，并用它包装初始环境：

from dataclasses import dataclass
from typing import Any, SupportsFloat
import gymnasium as gym
from gymnasium import Wrapper, Env
from gymnasium.core import WrapperActType, WrapperObsType, RenderFrame, ObsType, ActType
from gym_multigrid.core.constants import TILE_PIXELS

@dataclass
class RewardWrapper(Wrapper):
    """Environment wrapper that logs the complete result of every ``step``.

    The wrapper does not modify observations, rewards or termination flags;
    it only records them through the logger named ``logger_name``.
    Rendering and resetting are inherited from ``gymnasium.Wrapper``.

    Attributes:
        env: The environment being wrapped.
        logger_name: Name passed to ``logging.getLogger`` to obtain the logger.
        logger: Logger that receives one INFO record per step.
        obs_shape: ``shape`` of the wrapped environment's observation space.
        agents: The wrapped environment's ``agents`` attribute, mirrored on
            the wrapper so callers can iterate over it directly.
    """

    env: Env[ObsType, ActType]
    logger_name: str  # name of the logger that receives the step records

    def __post_init__(self):
        """Complete initialisation after the dataclass-generated ``__init__``."""
        # The dataclass-generated __init__ only assigns the declared fields and
        # never calls Wrapper.__init__. Without this call the wrapper's private
        # state is missing and properties such as `action_space` raise
        # AttributeError.
        super().__init__(self.env)
        self.logger = logging.getLogger(self.logger_name)
        self.obs_shape = self.env.observation_space.shape
        self.agents = self.env.agents

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Forward ``action`` to the wrapped environment and log the outcome.

        Args:
            action: Action(s) passed unchanged to ``self.env.step``.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple exactly
            as returned by the wrapped environment.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.logger.info(f"obs, reward, terminated, truncated, info: {obs, reward, terminated, truncated, info}")
        return obs, reward, terminated, truncated, info
# Wrap the original Collect environment so that every step is logged.
env = RewardWrapper(origin_env, logger_name)

# Reset BEFORE the first render so that the opening frame shows the actual
# initial state of the episode that is played below.
obs, info = env.reset()
frames = [env.render()]

# Roll out one episode with uniformly sampled actions, one per agent.
while True:
    actions = [env.action_space.sample() for _ in env.agents]
    obs, reward, terminated, truncated, info = env.step(actions)
    frames.append(env.render())
    if terminated or truncated:
        # step_count / collected_balls live on the underlying environment, not
        # on the wrapper; go through `unwrapped` instead of relying on the
        # wrapper forwarding unknown attributes (removed in Gymnasium 1.0).
        logger.info(f"episode ended after {env.unwrapped.step_count} steps")
        logger.info(f"agents collected {env.unwrapped.collected_balls} objects")
        break
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 env = RewardWrapper(origin_env, logger_name)
      2 frames = [env.render()]
      3 obs, info = env.reset()

NameError: name 'origin_env' is not defined
# Inspect the observation-space shape that RewardWrapper.__post_init__ cached on the wrapper.
env.obs_shape
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 env.obs_shape

NameError: name 'env' is not defined