drlhp.pref_interface.PrefInterface 测试

drlhp.pref_interface.PrefInterface 测试

from multiprocessing import Queue

import numpy as np
import termcolor

from drlhp.pref_db import Segment
from drlhp.pref_interface import PrefInterface

def send_segments(n_segments, seg_pipe, segment_length=25,
                  frame_shape=(84, 84, 4)):
    """Build `n_segments` dummy segments and put each one on `seg_pipe`.

    Each segment receives `segment_length` all-zero frames, each with
    reward 0. It is then finalised with its loop index as `seg_id`
    (0 .. n_segments - 1) and put on `seg_pipe`.

    Args:
        n_segments: Number of segments to build and send.
        seg_pipe: Queue-like object with a `put` method (for example a
            `multiprocessing.Queue`) that receives the finished segments.
        segment_length: Number of frames appended to each segment.
            Defaults to 25, the value this helper previously hard-coded.
        frame_shape: Shape of the all-zero array appended as every frame.
            Defaults to (84, 84, 4), the previously hard-coded shape.
    """
    # A single zero array is reused for every frame of every segment; this
    # helper never writes to it, so no per-frame copy is made.
    # NOTE(review): if Segment or a downstream consumer mutates frames in
    # place, all frames alias this one array -- confirm that never happens.
    frame_stack = np.zeros(frame_shape)
    for seg_id in range(n_segments):
        segment = Segment()
        for _ in range(segment_length):
            segment.append(frame=frame_stack, reward=0)
        segment.finalise(seg_id=seg_id)
        seg_pipe.put(segment)
        
# Construct an interface that uses synthetic preferences with max_segs=1000,
# then print the object in red as a quick confirmation that construction
# succeeded.
p = PrefInterface(max_segs=1000, synthetic_prefs=True)
termcolor.cprint(p, color='red')
<drlhp.pref_interface.PrefInterface object at 0x7f75fc78d550>

检查片段是否正确存储在循环缓冲区中。

# Push integers through the queue one at a time and, after each batch,
# check the interface's segment buffer. With max_segs=5 the first five
# values fill the buffer. Per the expected contents below, each later value
# overwrites the oldest slot, and the write position wraps around.
pi = PrefInterface(synthetic_prefs=True, max_segs=5)
pipe = Queue()
batches = [
    (range(0, 5), [0, 1, 2, 3, 4]),
    (range(5, 8), [5, 6, 7, 3, 4]),
    (range(8, 11), [10, 6, 7, 8, 9]),
]
for values, expected in batches:
    for i in values:
        pipe.put(i)
        pi.recv_segments(pipe)
    np.testing.assert_array_equal(pi.segments, expected)