drlhp.PrefDB 测试

drlhp.PrefDB 测试

import numpy as np
from drlhp.pref_db import PrefDB

测试偏好数据库能否真正区分相似的片段(即检查其哈希函数是否按预期工作):下面两个片段仅有一个元素不同,应被存储为两个独立的片段。

# A segment and a near-duplicate that differs in a single element. If segments
# are keyed by their full contents, both must be kept, so two stored segments
# are expected.
p = PrefDB(maxlen=5)
segment_shape = (25, 84, 84, 4)
s1, s2 = np.ones(segment_shape), np.ones(segment_shape)
s2[12, 24, 24, 2] = 0  # change exactly one element of the second segment
p.append(s1, s2, [1.0, 0.0])
assert len(p.segments) == 2

执行多次追加/删除操作,并检查偏好和片段的数量是否始终如预期。片段可被多个偏好共享,只有当不再被任何剩余偏好引用时才会被删除。

# Exercise a sequence of appends and deletions, checking after each step that
# the number of preferences and of distinct stored segments is as expected.
# Segments are shared between preferences; per the checks below, a segment is
# only dropped once no remaining preference refers to it.
#
# A fixed-seed RNG makes the test reproducible. Each draw is a fresh
# (25, 84, 84, 4) integer array, so two draws colliding is practically
# impossible, and with the seed fixed the outcome is identical on every run.
rng = np.random.RandomState(0)


def _random_segment():
    """Return a fresh random integer array of shape (25, 84, 84, 4)."""
    return rng.randint(low=-10, high=10, size=(25, 84, 84, 4))


p = PrefDB(maxlen=10)

# Pref 0: (A, B) -> segments {A, B}.
s1 = _random_segment()  # A
s2 = _random_segment()  # B
p.append(s1, s2, [1.0, 0.0])
assert len(p.segments) == 2
assert len(p.prefs) == 1

# Pref 1: (A, B) again with the opposite label -> no new segments.
p.append(s1, s2, [0.0, 1.0])
assert len(p.segments) == 2
assert len(p.prefs) == 2

# Pref 2: (C, B) -> adds C.
s1 = _random_segment()  # C
p.append(s1, s2, [1.0, 0.0])
assert len(p.segments) == 3
assert len(p.prefs) == 3

# Pref 3: (C, D) -> adds D.
s2 = _random_segment()  # D
p.append(s1, s2, [1.0, 0.0])
assert len(p.segments) == 4
assert len(p.prefs) == 4

# Pref 4: (E, F) -> adds E and F.
s1 = _random_segment()  # E
s2 = _random_segment()  # F
p.append(s1, s2, [1.0, 0.0])
assert len(p.segments) == 6
assert len(p.prefs) == 5

# Delete pref 0 (A, B). Pref 1 still uses A and B, so the number of segments
# shouldn't decrease; the remaining preferences keep their order.
prefs_pre = list(p.prefs)
p.del_first()
assert len(p.prefs) == 4
assert p.prefs == prefs_pre[1:]
assert len(p.segments) == 6

# Delete pref 1 (A, B). A was only used by the first two preferences, so it
# goes; B is still used by pref 2.
p.del_first()
assert len(p.prefs) == 3
assert len(p.segments) == 5

# Delete pref 2 (C, B). B is now unreferenced; C is still used by pref 3.
p.del_first()
assert len(p.prefs) == 2
assert len(p.segments) == 4

# Delete pref 3 (C, D). Both are now unreferenced.
p.del_first()
assert len(p.prefs) == 1
assert len(p.segments) == 2

# Delete pref 4 (E, F). The database is empty again.
p.del_first()
assert len(p.prefs) == 0
assert len(p.segments) == 0

测试循环覆盖(环形缓冲区):当偏好数量超过 maxlen 时,最早的偏好及其不再被引用的片段应被移除。

# With maxlen=2, the third append must evict the oldest preference (0, 1, 10).
# Afterwards only the segments of the two surviving preferences remain.
p = PrefDB(maxlen=2)

appends = [(0, 1, 10), (2, 3, 11), (4, 5, 12)]
expected_lengths = [1, 2, 2]
for (seg_a, seg_b, pref), expected_len in zip(appends, expected_lengths):
    p.append(seg_a, seg_b, pref)
    assert len(p) == expected_len

assert len(p.segments) == 4
stored = p.segments.values()
for seg in (2, 3, 4, 5):
    assert seg in stored

# The surviving preferences are the two most recent, oldest first.
assert (p.prefs[0][2], p.prefs[1][2]) == (11, 12)