drlhp.utils
测试#
import socket
from math import ceil
import numpy as np
from drlhp.utils import RunningStat, batch_iter, get_port_range
drlhp.utils.RunningStat
#
代码修改自 running_stat.py
for shp in ((), (3, ), (3, 4)):
li = []
rs = RunningStat(shp)
for i in range(5):
val = np.random.randn(*shp)
rs.push(val)
li.append(val)
m = np.mean(li, axis=0)
assert np.allclose(rs.mean, m)
if i == 0:
continue
# ddof=1 => calculate unbiased sample variance
v = np.var(li, ddof=1, axis=0)
assert np.allclose(rs.var, v)
drlhp.utils.get_port_range()
#
测试 1:如果从端口 60000
开始请求 3 个端口(这些端口上不应该有服务在监听),应当得到以下结果:
ports = get_port_range(60000, 3)
ports # 60000, 60001 and 60002
[60000, 60001, 60002]
测试2:如果在端口 60000
上设置监听,然后请求与测试 1 中相同的端口,函数应该跳过 60000
并给出下一个三个端口。
s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind(("127.0.0.1", 60000))
ports = get_port_range(60000, 3)
ports # 60001, 60002, 60003
Warning: port 60000 already in use
[60001, 60002, 60003]
测试3:如果在端口 60002
上设置监听,函数应该意识到它无法从 60000
开始分配连续的范围,并应该给出从 60003
开始的范围。
s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s2.bind(("127.0.0.1", 60002))
ports = get_port_range(60000, 3)
ports # 60003, 60004, 60005
Warning: port 60000 already in use
Warning: port 60002 already in use
[60003, 60004, 60005]
drlhp.utils.batch_iter()
#
检查 drlhp.utils.batch_iter()
是否返回了完全正确的数据。
l1 = list(range(16))
l2 = list(range(15))
l3 = list(range(13))
for l in [l1, l2, l3]:
for shuffle in [True, False]:
expected_data = l
actual_data = set()
expected_n_batches = ceil(len(l) / 4)
actual_n_batches = 0
for batch_n, x in enumerate(batch_iter(l,
batch_size=4,
shuffle=shuffle)):
if batch_n == expected_n_batches - 1 and len(l) % 4 != 0:
assert len(x) == len(l) % 4
else:
assert len(x) == 4
assert len(actual_data.intersection(set(x))) == 0
actual_data = actual_data.union(set(x))
actual_n_batches += 1
assert actual_n_batches == expected_n_batches
np.testing.assert_array_equal(list(actual_data), expected_data)
检查 shuffle=True
是否返回相同的数据,但顺序不同。
expected_data = list(range(16))
actual_data = []
for x in batch_iter(expected_data, batch_size=4, shuffle=True):
actual_data.extend(x)
assert len(actual_data) == len(expected_data)
assert set(actual_data) == set(expected_data)
try:
np.testing.assert_array_equal(actual_data, expected_data)
except AssertionError:
print("满足要求")
满足要求
检查连续调用是否以不同的顺序进行洗牌。
data = list(range(16))
out1 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
out1.extend(x)
out2 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
out2.extend(x)
assert set(out1) == set(out2)
try:
np.testing.assert_array_equal(out1, out2)
except AssertionError:
print("满足要求")
满足要求