示例#1
0
    def testAttentionCritic(self):
        """Check that the critic's forward pass yields an entry per agent key.

        Builds an ``AttentionCritic``, batches three hand-written replay
        frames with ``preprocess_to_batch``, runs ``critic.forward`` on the
        batch and asserts that every key of the batched input is also a key
        of the returned mapping.
        """
        # NOTE(review): the (5, 3) / (5, 2) pairs presumably describe the
        # (observation size, action size) of agent types 0 and 1 -- this
        # matches the 5-element and 3-/2-element lists below; confirm against
        # the AttentionCritic constructor.
        critic = AttentionCritic([(5, 3), (5, 2)], attend_heads=4)

        # One dict per time step. Agent (0, '0-3') carries a True flag in the
        # second frame and is absent from the third one.
        # NOTE(review): the AgentReplayFrame fields look like
        # (obs, action, reward, done, next_obs) -- confirm.
        raw_frames = [
            {
                AgentKey(0, '0-1'): AgentReplayFrame(
                    [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
                AgentKey(0, '0-2'): AgentReplayFrame(
                    [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
                AgentKey(0, '0-3'): AgentReplayFrame(
                    [2, 0, 3, 0, 2], [1, 0, 0], 1, False, [3, 0, 1, 3, 4]),
                AgentKey(1, '0-1'): AgentReplayFrame(
                    [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
            },
            {
                AgentKey(0, '0-1'): AgentReplayFrame(
                    [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
                AgentKey(0, '0-2'): AgentReplayFrame(
                    [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
                AgentKey(0, '0-3'): AgentReplayFrame(
                    [2, 0, 3, 0, 2], [1, 0, 0], 0, True, [3, 0, 1, 3, 4]),
                AgentKey(1, '0-1'): AgentReplayFrame(
                    [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
            },
            {
                AgentKey(0, '0-1'): AgentReplayFrame(
                    [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
                AgentKey(0, '0-2'): AgentReplayFrame(
                    [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
                AgentKey(1, '0-1'): AgentReplayFrame(
                    [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
            },
        ]

        # Keep the batched mapping under its own name so that one variable is
        # not annotated with two different types.
        batched_frames: Dict[AgentKey, BatchedAgentReplayFrame] = \
            preprocess_to_batch(raw_frames)

        results: Dict[AgentKey, List[float]] = critic.forward(batched_frames)

        print(results)

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(k in results).
        for agent_key in batched_frames:
            self.assertIn(agent_key, results)
示例#2
0
from utils.critics import AttentionCritic
from utils.core import *
from utils.buffer import AgentReplayFrame
from torchviz import *

# Build a critic, run one forward pass over a small hand-written batch and
# render the autograd graph of one output with torchviz.
critic = AttentionCritic([(5, 3), (5, 2)], attend_heads=4)

# One dict per time step, keyed by agent. Agent (0, '0-3') carries a True flag
# in the second frame and is absent from the third one.
raw_frames = [
    {
        AgentKey(0, '0-1'): AgentReplayFrame(
            [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
        AgentKey(0, '0-2'): AgentReplayFrame(
            [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
        AgentKey(0, '0-3'): AgentReplayFrame(
            [2, 0, 3, 0, 2], [1, 0, 0], 1, False, [3, 0, 1, 3, 4]),
        AgentKey(1, '0-1'): AgentReplayFrame(
            [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
    },
    {
        AgentKey(0, '0-1'): AgentReplayFrame(
            [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
        AgentKey(0, '0-2'): AgentReplayFrame(
            [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
        AgentKey(0, '0-3'): AgentReplayFrame(
            [2, 0, 3, 0, 2], [1, 0, 0], 0, True, [3, 0, 1, 3, 4]),
        AgentKey(1, '0-1'): AgentReplayFrame(
            [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
    },
    {
        AgentKey(0, '0-1'): AgentReplayFrame(
            [2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
        AgentKey(0, '0-2'): AgentReplayFrame(
            [1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
        AgentKey(1, '0-1'): AgentReplayFrame(
            [2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4]),
    },
]

# Collapse the per-step dicts into one batched frame per agent key.
sample_frames: Dict[AgentKey, BatchedAgentReplayFrame] = \
    preprocess_to_batch(raw_frames)

results: Dict[AgentKey, List[float]] = critic.forward(sample_frames)
print(results)

# Graph the mean of the first output for agent (1, '0-1'), labelling nodes
# with the critic's named parameters, and write it out as a PNG.
graph_key = AgentKey(1, '0-1')
graph_root = results[graph_key][0].mean()
graph_params = dict(critic.named_parameters())

dot = make_dot(graph_root, params=graph_params)
dot.format = "png"
dot.render("myfile")