drlhp.utils 测试
import socket
from math import ceil
import numpy as np
from drlhp.utils import RunningStat, batch_iter, get_port_range
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[1], line 4
2 from math import ceil
3 import numpy as np
----> 4 from drlhp.utils import RunningStat, batch_iter, get_port_range
File ~/work/pybook/pybook/doc/libs/drlhp/drlhp/utils.py:7
4 import time
6 import numpy as np
----> 7 import pyglet
8 import pdb
9 import sys
ModuleNotFoundError: No module named 'pyglet'
drlhp.utils.RunningStat
代码修改自 running_stat.py
# Compare RunningStat against NumPy's batch mean / unbiased variance after
# every pushed sample, for scalar, vector and matrix shapes.
# A fixed, local seed makes any failure reproducible on re-run without
# touching NumPy's global random state.
rng = np.random.RandomState(0)
for shp in ((), (3, ), (3, 4)):
    li = []
    rs = RunningStat(shp)
    for i in range(5):
        # randn(*()) returns a plain float, matching the scalar shape ().
        val = rng.randn(*shp)
        rs.push(val)
        li.append(val)
        m = np.mean(li, axis=0)
        assert np.allclose(rs.mean, m), (shp, i, rs.mean, m)
        if i == 0:
            # ddof=1 needs at least two samples; with one, np.var gives nan.
            continue
        # ddof=1 => calculate unbiased sample variance
        v = np.var(li, ddof=1, axis=0)
        assert np.allclose(rs.var, v), (shp, i, rs.var, v)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[2], line 3
1 for shp in ((), (3, ), (3, 4)):
2 li = []
----> 3 rs = RunningStat(shp)
4 for i in range(5):
5 val = np.random.randn(*shp)
NameError: name 'RunningStat' is not defined
drlhp.utils.get_port_range()
测试 1：如果从端口 60000 开始请求 3 个端口（这些端口不应被任何服务占用），应当得到以下结果：
# Test 1: nothing should occupy 60000-60002, so the range starts at 60000.
ports = get_port_range(60000, 3)
# Assert instead of relying on eyeballing the output; list() keeps the
# check independent of the exact sequence type that is returned.
assert list(ports) == [60000, 60001, 60002], ports
ports  # 60000, 60001 and 60002
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[3], line 1
----> 1 ports = get_port_range(60000, 3)
2 ports # 60000, 60001 and 60002
NameError: name 'get_port_range' is not defined
测试 2：如果先在端口 60000 上绑定一个套接字，然后请求与测试 1 中相同的端口，函数应该跳过 60000，并给出接下来的三个端口。
# Test 2: occupy port 60000 so get_port_range has to skip it.
# s1 deliberately stays bound so test 3 runs with 60000 still occupied.
s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind(("127.0.0.1", 60000))
ports = get_port_range(60000, 3)
# list() keeps the check independent of the returned sequence type.
assert list(ports) == [60001, 60002, 60003], ports
ports  # 60001, 60002, 60003
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[4], line 3
1 s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
2 s1.bind(("127.0.0.1", 60000))
----> 3 ports = get_port_range(60000, 3)
4 ports # 60001, 60002, 60003
NameError: name 'get_port_range' is not defined
测试 3：如果再在端口 60002 上绑定一个套接字，函数应该意识到它无法从 60000 开始分配连续的三个端口（由于 60002 已被占用，从 60001 开始也不行），因此应该给出从 60003 开始的范围。
# Test 3: with 60002 occupied as well, no contiguous 3-port range can start
# before 60003.
s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s2.bind(("127.0.0.1", 60002))
    ports = get_port_range(60000, 3)
    # list() keeps the check independent of the returned sequence type.
    assert list(ports) == [60003, 60004, 60005], ports
finally:
    # Release both test sockets; otherwise they stay bound for the lifetime
    # of the process and re-running these cells fails with
    # "Address already in use" on the bind calls.
    s2.close()
    s1.close()
ports  # 60003, 60004, 60005
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[5], line 3
1 s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
2 s2.bind(("127.0.0.1", 60002))
----> 3 ports = get_port_range(60000, 3)
4 ports # 60003, 60004, 60005
NameError: name 'get_port_range' is not defined
drlhp.utils.batch_iter()
检查 drlhp.utils.batch_iter() 是否返回了完整且正确的数据：每个元素恰好出现一次，批次数量正确，并且只有最后一个批次可以不满。
# Check that batch_iter yields every element exactly once, in
# ceil(len / batch_size) batches where only the last one may be short,
# both with and without shuffling.
batch_size = 4
l1 = list(range(16))
l2 = list(range(15))
l3 = list(range(13))
for items in [l1, l2, l3]:
    for shuffle in [True, False]:
        expected_data = items
        actual_data = set()
        expected_n_batches = ceil(len(items) / batch_size)
        remainder = len(items) % batch_size
        actual_n_batches = 0
        for batch_n, x in enumerate(batch_iter(items,
                                               batch_size=batch_size,
                                               shuffle=shuffle)):
            if batch_n == expected_n_batches - 1 and remainder != 0:
                assert len(x) == remainder
            else:
                assert len(x) == batch_size
            # No element may appear in more than one batch.
            assert actual_data.isdisjoint(x)
            actual_data.update(x)
            actual_n_batches += 1
        assert actual_n_batches == expected_n_batches
        # Sort explicitly: set iteration order is not guaranteed, so
        # list(actual_data) could differ from expected_data in order only.
        np.testing.assert_array_equal(sorted(actual_data), expected_data)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[6], line 10
8 expected_n_batches = ceil(len(l) / 4)
9 actual_n_batches = 0
---> 10 for batch_n, x in enumerate(batch_iter(l,
11 batch_size=4,
12 shuffle=shuffle)):
13 if batch_n == expected_n_batches - 1 and len(l) % 4 != 0:
14 assert len(x) == len(l) % 4
NameError: name 'batch_iter' is not defined
检查 shuffle=True 时是否返回相同的数据，但顺序不同。
# With shuffle=True, batch_iter must return the same elements, but not in
# the original order.
expected_data = list(range(16))
actual_data = []
for x in batch_iter(expected_data, batch_size=4, shuffle=True):
    actual_data.extend(x)
assert len(actual_data) == len(expected_data)
assert set(actual_data) == set(expected_data)
# Assert (rather than only printing on success) so an unshuffled result
# makes the cell fail. A genuine shuffle of 16 items reproduces the
# identity order with probability 1/16!, so a false failure is negligible.
assert not np.array_equal(actual_data, expected_data), (
    "shuffle=True returned the data in its original order")
print("满足要求")
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[7], line 3
1 expected_data = list(range(16))
2 actual_data = []
----> 3 for x in batch_iter(expected_data, batch_size=4, shuffle=True):
4 actual_data.extend(x)
5 assert len(actual_data) == len(expected_data)
NameError: name 'batch_iter' is not defined
检查连续调用是否以不同的顺序进行洗牌。
# Two consecutive shuffled passes over the same data must contain the same
# elements, but in different orders.
data = list(range(16))
# Flatten the batches of each pass into one sequence.
out1 = [v for batch in batch_iter(data, batch_size=4, shuffle=True)
        for v in batch]
out2 = [v for batch in batch_iter(data, batch_size=4, shuffle=True)
        for v in batch]
assert set(out1) == set(out2)
# Assert (rather than only printing on success) so identical orders make
# the cell fail instead of passing silently.
assert not np.array_equal(out1, out2), (
    "consecutive shuffled calls produced the same order")
print("满足要求")
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[8], line 3
1 data = list(range(16))
2 out1 = []
----> 3 for x in batch_iter(data, batch_size=4, shuffle=True):
4 out1.extend(x)
5 out2 = []
NameError: name 'batch_iter' is not defined