drlhp.utils 测试#

import socket
from math import ceil
import numpy as np
from drlhp.utils import RunningStat, batch_iter, get_port_range

drlhp.utils.RunningStat#

代码修改自 running_stat.py

for shp in ((), (3, ), (3, 4)):
    li = []
    rs = RunningStat(shp)
    for i in range(5):
        val = np.random.randn(*shp)
        rs.push(val)
        li.append(val)
        m = np.mean(li, axis=0)
        assert np.allclose(rs.mean, m)
        if i == 0:
            continue
        # ddof=1 => calculate unbiased sample variance
        v = np.var(li, ddof=1, axis=0)
        assert np.allclose(rs.var, v)

drlhp.utils.get_port_range()#

测试 1:如果从端口 60000 开始请求 3 个端口(这些端口上不应该有服务在监听),应当得到以下结果:

ports = get_port_range(60000, 3)
ports # 60000, 60001 and 60002
[60000, 60001, 60002]

测试2:如果在端口 60000 上设置监听,然后请求与测试 1 中相同的端口,函数应该跳过 60000 并给出下一个三个端口。

s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind(("127.0.0.1", 60000))
ports = get_port_range(60000, 3)
ports # 60001, 60002, 60003
Warning: port 60000 already in use
[60001, 60002, 60003]

测试3:如果在端口 60002 上设置监听,函数应该意识到它无法从 60000 开始分配连续的范围,并应该给出从 60003 开始的范围。

s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s2.bind(("127.0.0.1", 60002))
ports = get_port_range(60000, 3)
ports # 60003, 60004, 60005
Warning: port 60000 already in use
Warning: port 60002 already in use
[60003, 60004, 60005]

drlhp.utils.batch_iter()#

检查 drlhp.utils.batch_iter() 是否返回了完全正确的数据。

l1 = list(range(16))
l2 = list(range(15))
l3 = list(range(13))
for l in [l1, l2, l3]:
    for shuffle in [True, False]:
        expected_data = l
        actual_data = set()
        expected_n_batches = ceil(len(l) / 4)
        actual_n_batches = 0
        for batch_n, x in enumerate(batch_iter(l,
                                                batch_size=4,
                                                shuffle=shuffle)):
            if batch_n == expected_n_batches - 1 and len(l) % 4 != 0:
                assert len(x) == len(l) % 4
            else:
                assert len(x) == 4
            assert len(actual_data.intersection(set(x))) == 0
            actual_data = actual_data.union(set(x))
            actual_n_batches += 1
        assert actual_n_batches == expected_n_batches
        np.testing.assert_array_equal(list(actual_data), expected_data)

检查 shuffle=True 是否返回相同的数据,但顺序不同。

expected_data = list(range(16))
actual_data = []
for x in batch_iter(expected_data, batch_size=4, shuffle=True):
    actual_data.extend(x)
assert len(actual_data) == len(expected_data)
assert set(actual_data) == set(expected_data)
try:
    np.testing.assert_array_equal(actual_data, expected_data)
except AssertionError:
    print("满足要求")
满足要求

检查连续调用是否以不同的顺序进行洗牌。

data = list(range(16))
out1 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
    out1.extend(x)
out2 = []
for x in batch_iter(data, batch_size=4, shuffle=True):
    out2.extend(x)
assert set(out1) == set(out2)
try:
    np.testing.assert_array_equal(out1, out2)
except AssertionError:
    print("满足要求")
满足要求