创建 Collect 环境包装器
import logging
import sys
from pathlib import Path

from d2py.utils.log_config import config_logging

# All paths are resolved relative to the notebook's current working directory.
root_dir = Path(".").resolve()

# The gym-multigrid sources live under <root_dir.parents[3]>/tests/gym-multigrid;
# append that directory to the import path so `import gym_multigrid` resolves.
sys.path.append(str(root_dir.parents[3].joinpath("tests", "gym-multigrid")))

# Output directory for the log file and rendered images.
temp_dir = root_dir.joinpath("images")
temp_dir.mkdir(exist_ok=True, parents=True)

# maxBytes/backupCount presumably configure a size-rotated log file — confirm in d2py.
logger_name = "gym_multigrid"
config_logging(f"{temp_dir}/{logger_name}.log", logger_name, maxBytes=50000, backupCount=2)
logger = logging.getLogger(logger_name)
from pathlib import Path

import imageio
import numpy as np

# NOTE: the vendored gym_multigrid still annotates with `np.float_`, which NumPy 2.0
# removed, so importing it raises AttributeError. Restore the alias only when it is
# missing. The proper fix is replacing `np.float_` with `np.float64` in
# gym_multigrid/multigrid.py, after which this shim can be deleted.
if not hasattr(np, "float_"):
    np.float_ = np.float64

from gym_multigrid.envs.collect_game import CollectGameEnv

kwargs = {
    "size": 15,
    "num_balls": [5],
    "agents_index": [1, 2, 3],  # colour index per agent (3 agents) — TODO confirm colour names
    "balls_index": [0],  # colour index per ball type — TODO confirm colour name
    "balls_reward": [1],
    "respawn": False,
}
origin_env = CollectGameEnv(**kwargs)

# Gymnasium requires reset() before render(); rendering afterwards also makes the
# first GIF frame show this episode's actual initial state.
obs, info = origin_env.reset()
frames = [origin_env.render()]
while True:
    # One random action per agent.
    actions = [origin_env.action_space.sample() for _ in origin_env.agents]
    obs, reward, terminated, truncated, info = origin_env.step(actions)
    frames.append(origin_env.render())
    if terminated or truncated:
        logger.info("episode ended after %s steps", origin_env.step_count)
        logger.info("agents collected %s objects", origin_env.collected_balls)
        break

temp_dir = Path(temp_dir)
gif_dir = temp_dir / "animations"
gif_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): newer imageio releases may interpret GIF `duration` in milliseconds
# rather than seconds — confirm the intended frame timing for the installed version.
imageio.mimsave(gif_dir / "multigrid-collect.gif", frames, duration=0.5)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[2], line 3
1 from pathlib import Path
2 import imageio
----> 3 from gym_multigrid.envs.collect_game import CollectGameEnv
5 kwargs={
6 "size": 15,
7 "num_balls": [5,],
(...)
11 "respawn": False,
12 }
13 origin_env = CollectGameEnv(**kwargs)
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/__init__.py:1
----> 1 from gym_multigrid.envs.collect_game import *
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/envs/collect_game.py:4
1 import numpy as np
2 from numpy.typing import NDArray
----> 4 from gym_multigrid.multigrid import MultiGridEnv
5 from gym_multigrid.core.world import CollectWorld
6 from gym_multigrid.core.agent import CollectActions, Agent
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:21
15 from gym_multigrid.core.constants import *
18 MultiGridEnvT = TypeVar("MultiGridEnvT", bound="MultiGridEnv")
---> 21 class MultiGridEnv(gym.Env):
22 """
23 2D grid world game environment
24 """
26 metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 10}
File ~/work/pybook/pybook/tests/gym-multigrid/gym_multigrid/multigrid.py:399, in MultiGridEnv()
393 world_cell = self.grid.get(x, y)
395 return obs_cell is not None and obs_cell.type == world_cell.type
397 def step(
398 self, actions: list[int] | NDArray[np.int_]
--> 399 ) -> tuple[NDArray[np.int_], NDArray[np.float_], bool, bool, dict]:
400 self.step_count += 1
402 order = np.random.permutation(len(actions))
File /opt/hostedtoolcache/Python/3.12.7/x64/lib/python3.12/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.
from IPython import display

# Render the GIF produced by the simulation cell inline in the notebook.
gif_path = temp_dir / "animations" / "multigrid-collect.gif"
display.Image(gif_path)
创建初始环境:
from dataclasses import dataclass
from typing import Any, SupportsFloat
import gymnasium as gym
from gymnasium import Wrapper, Env
from gymnasium.core import WrapperActType, WrapperObsType, RenderFrame, ObsType, ActType
from gym_multigrid.core.constants import TILE_PIXELS
class RewardWrapper(Wrapper):
    """Pass-through Gymnasium wrapper that logs every ``step`` transition.

    Despite its name, the reward is currently returned unchanged; the class is a
    starting point for reward shaping.

    Args:
        env: The environment to wrap (used here with a ``CollectGameEnv``).
        logger_name: Name passed to ``logging.getLogger`` for the step log.
    """

    def __init__(self, env: Env[ObsType, ActType], logger_name: str) -> None:
        # Wrapper.__init__ must run: it initialises the private attributes
        # (_action_space, _observation_space, _metadata, ...) that Wrapper's
        # space/metadata properties read. A @dataclass-generated __init__ skips it,
        # which breaks `self.action_space`.
        super().__init__(env)
        self.logger_name = logger_name  # name of the logger used by step()
        self.logger = logging.getLogger(logger_name)
        self.obs_shape = self.env.observation_space.shape
        # Re-exposed so callers can build one action per agent.
        self.agents = self.env.agents

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env, log the full transition, and return it unchanged."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Lazy %-formatting: the (possibly large) observation repr is only built
        # when INFO is enabled for this logger; the emitted text is unchanged.
        self.logger.info(
            "obs, reward, terminated, truncated, info: %s",
            (obs, reward, terminated, truncated, info),
        )
        return obs, reward, terminated, truncated, info
env = RewardWrapper(origin_env, logger_name)

# Reset before the first render() so frame 0 shows this episode's initial state
# (Gymnasium requires reset() before render()).
obs, info = env.reset()
frames = [env.render()]
while True:
    # One random action per agent.
    actions = [env.action_space.sample() for _ in env.agents]
    obs, reward, terminated, truncated, info = env.step(actions)
    frames.append(env.render())
    if terminated or truncated:
        # Gymnasium 1.0 wrappers no longer forward unknown attributes to the wrapped
        # env, so read the env-specific counters from the unwrapped environment.
        base_env = env.unwrapped
        logger.info("episode ended after %s steps", base_env.step_count)
        logger.info("agents collected %s objects", base_env.collected_balls)
        break
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[5], line 1
----> 1 env = RewardWrapper(origin_env, logger_name)
2 frames = [env.render()]
3 obs, info = env.reset()
NameError: name 'origin_env' is not defined
# Display the observation shape the wrapper copied from the wrapped env's observation_space.
obs_shape = env.obs_shape
obs_shape
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[6], line 1
----> 1 env.obs_shape
NameError: name 'env' is not defined