drlhp.utils 测试

import socket
from math import ceil
import numpy as np
from drlhp.utils import RunningStat, batch_iter, get_port_range
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 4
      2 from math import ceil
      3 import numpy as np
----> 4 from drlhp.utils import RunningStat, batch_iter, get_port_range

File ~/work/pybook/pybook/doc/libs/drlhp/drlhp/utils.py:7
      4 import time
      6 import numpy as np
----> 7 import pyglet
      8 import pdb
      9 import sys

ModuleNotFoundError: No module named 'pyglet'

drlhp.utils.RunningStat

代码修改自 running_stat.py

# Compare RunningStat against NumPy for scalar, vector and matrix shapes:
# after every push the running mean (and, from the second sample on, the
# running variance) must match the values computed over all samples so far.
for shape in ((), (3, ), (3, 4)):
    samples = []
    tracker = RunningStat(shape)
    for step in range(5):
        sample = np.random.randn(*shape)
        tracker.push(sample)
        samples.append(sample)
        expected_mean = np.mean(samples, axis=0)
        assert np.allclose(tracker.mean, expected_mean)
        # The variance is only compared once at least two samples exist.
        if step > 0:
            # ddof=1 => unbiased sample variance
            expected_var = np.var(samples, ddof=1, axis=0)
            assert np.allclose(tracker.var, expected_var)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[2], line 3
      1 for shp in ((), (3, ), (3, 4)):
      2     li = []
----> 3     rs = RunningStat(shp)
      4     for i in range(5):
      5         val = np.random.randn(*shp)

NameError: name 'RunningStat' is not defined

drlhp.utils.get_port_range()

测试 1:如果从端口 60000 开始请求 3 个端口(这些端口上不应该有服务在监听),应当得到以下结果:

# Test 1: request 3 ports starting at 60000. The surrounding text assumes
# nothing is bound to these ports, so the requested range is expected as-is.
ports = get_port_range(60000, 3)
# Verify the expectation instead of only stating it in a comment.
assert list(ports) == [60000, 60001, 60002]
ports  # 60000, 60001 and 60002
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 ports = get_port_range(60000, 3)
      2 ports # 60000, 60001 and 60002

NameError: name 'get_port_range' is not defined

测试 2:如果在端口 60000 上设置监听,然后请求与测试 1 中相同的端口,函数应该跳过 60000 并给出接下来的三个端口。

# Test 2: bind a socket to port 60000, then repeat the request from test 1.
# The returned range is expected to start after the occupied port.
# NOTE(review): s1 is not closed in this cell, so port 60000 stays bound
# afterwards -- confirm that this is intended.
s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind(("127.0.0.1", 60000))
ports = get_port_range(60000, 3)
# Verify the expectation instead of only stating it in a comment.
assert list(ports) == [60001, 60002, 60003]
ports  # 60001, 60002, 60003
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[4], line 3
      1 s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      2 s1.bind(("127.0.0.1", 60000))
----> 3 ports = get_port_range(60000, 3)
      4 ports # 60001, 60002, 60003

NameError: name 'get_port_range' is not defined

测试 3:如果在端口 60002 上设置监听,函数应该意识到无法从 60000 开始分配一段连续的端口范围,并给出从 60003 开始的范围。

# Test 3: bind a socket to port 60002. A contiguous range of 3 ports can
# then no longer start at 60000, so the range is expected to start at 60003.
s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s2.bind(("127.0.0.1", 60002))
    ports = get_port_range(60000, 3)
    # Verify the expectation instead of only stating it in a comment.
    assert list(ports) == [60003, 60004, 60005]
finally:
    # Release the ports bound for the port-range tests (s1 was bound to
    # 60000 by the previous test) so they are not leaked.
    s2.close()
    s1.close()
ports  # 60003, 60004, 60005
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[5], line 3
      1 s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      2 s2.bind(("127.0.0.1", 60002))
----> 3 ports = get_port_range(60000, 3)
      4 ports # 60003, 60004, 60005

NameError: name 'get_port_range' is not defined

drlhp.utils.batch_iter()

检查 drlhp.utils.batch_iter() 是否返回了完全正确的数据。

# Check that batch_iter() yields every element exactly once, in batches of
# 4 (with a shorter final batch when the length is not a multiple of 4),
# for both shuffle settings.
l1 = list(range(16))
l2 = list(range(15))
l3 = list(range(13))
for data in [l1, l2, l3]:
    for shuffle in [True, False]:
        expected_data = data
        actual_data = set()
        expected_n_batches = ceil(len(data) / 4)
        actual_n_batches = 0
        batches = batch_iter(data, batch_size=4, shuffle=shuffle)
        for batch_n, x in enumerate(batches):
            is_last_batch = batch_n == expected_n_batches - 1
            if is_last_batch and len(data) % 4 != 0:
                # Only the final batch may be short.
                assert len(x) == len(data) % 4
            else:
                assert len(x) == 4
            # No element may be yielded twice.
            assert actual_data.isdisjoint(x)
            actual_data.update(x)
            actual_n_batches += 1
        assert actual_n_batches == expected_n_batches
        # Sets have no guaranteed iteration order, so sort before comparing
        # against the (already ascending) input list.
        np.testing.assert_array_equal(sorted(actual_data), expected_data)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[6], line 10
      8 expected_n_batches = ceil(len(l) / 4)
      9 actual_n_batches = 0
---> 10 for batch_n, x in enumerate(batch_iter(l,
     11                                         batch_size=4,
     12                                         shuffle=shuffle)):
     13     if batch_n == expected_n_batches - 1 and len(l) % 4 != 0:
     14         assert len(x) == len(l) % 4

NameError: name 'batch_iter' is not defined

检查 shuffle=True 是否返回相同的数据,但顺序不同。

# Check that shuffle=True yields the same elements as the input, but in a
# different order.
expected_data = list(range(16))
actual_data = []
for x in batch_iter(expected_data, batch_size=4, shuffle=True):
    actual_data.extend(x)
assert len(actual_data) == len(expected_data)
assert set(actual_data) == set(expected_data)
try:
    # This comparison is expected to FAIL: the order must differ.
    np.testing.assert_array_equal(actual_data, expected_data)
except AssertionError:
    print("满足要求")
else:
    # Without this branch an unshuffled result would pass silently.
    # (A genuine shuffle reproducing the original order of 16 elements is
    # possible in principle but vanishingly unlikely.)
    raise AssertionError(
        "shuffle=True returned the data in its original order")
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[7], line 3
      1 expected_data = list(range(16))
      2 actual_data = []
----> 3 for x in batch_iter(expected_data, batch_size=4, shuffle=True):
      4     actual_data.extend(x)
      5 assert len(actual_data) == len(expected_data)

NameError: name 'batch_iter' is not defined

检查连续调用是否以不同的顺序进行洗牌。

# Check that two consecutive shuffled passes over the same data yield the
# same elements in different orders.
data = list(range(16))
out1 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
    out1.extend(x)
out2 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
    out2.extend(x)
assert set(out1) == set(out2)
try:
    # This comparison is expected to FAIL: the two orders must differ.
    np.testing.assert_array_equal(out1, out2)
except AssertionError:
    print("满足要求")
else:
    # Without this branch identical orderings would pass silently.
    # (Two genuine shuffles of 16 elements coinciding is possible in
    # principle but vanishingly unlikely.)
    raise AssertionError(
        "consecutive shuffle=True passes returned the same order")
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[8], line 3
      1 data = list(range(16))
      2 out1 = []
----> 3 for x in batch_iter(data, batch_size=4, shuffle=True):
      4     out1.extend(x)
      5 out2 = []

NameError: name 'batch_iter' is not defined