def test_rollout_v1():
    """Check that ``Rollout.v1`` maps a (16, 23, 8, 8) input to (16, 73, 8, 8)."""
    m = models.create('Rollout.v1')
    # (batch_size, in_channels, 8, 8)
    x = torch.randn(16, 23, 8, 8)
    # (batch_size, num_move_planes, 8, 8)
    output = m(x)
    assert output.shape == (16, 73, 8, 8)
def test_value_v0():
    """Check that ``Value.v0`` maps a (16, 23, 8, 8) input to one scalar per sample."""
    m = models.create('Value.v0')
    # (batch_size, in_channels, 8, 8)
    x = torch.randn(16, 23, 8, 8)
    # (batch_size, 1)
    output = m(x)
    assert output.shape == (16, 1)
def test_policy_v0():
    """Check that ``Policy.v0`` maps a (16, 23, 8, 8) input to (16, 73, 8, 8)."""
    m = models.create('Policy.v0')
    # (batch_size, in_channels, 8, 8)
    x = torch.randn(16, 23, 8, 8)
    # (batch_size, num_move_planes, 8, 8)
    output = m(x)
    assert output.shape == (16, 73, 8, 8)
def test_value_v2():
    """Check ``Value.v2``: its ``batch_norm`` flag is truthy, and it maps a
    (16, 21, 8, 8) input to one scalar per sample.
    """
    m = models.create('Value.v2')
    assert m.batch_norm
    # (batch_size, in_channels, 8, 8) -- this variant is fed 21 input channels.
    x = torch.randn(16, 21, 8, 8)
    # (batch_size, 1)
    output = m(x)
    assert output.shape == (16, 1)
def test_res_v0():
    """Check the output shapes of the ``ResNet.v0`` tower and its two heads.

    ``models.create('ResNet.v0')`` returns a ``(tower, policy, value)`` triple.
    The policy and value heads both consume the tower's output.
    """
    tower, policy, value = models.create('ResNet.v0')
    # (batch_size, in_channels, 8, 8)
    x = torch.randn(16, 21, 8, 8)
    # Run the shared tower once and feed the same features to both heads.
    features = tower(x)
    # (batch_size, num_move_planes * 8 * 8)
    policy_output = policy(features)
    assert policy_output.shape == (16, NUM_MOVE_PLANES * 8 * 8)
    # (batch_size, 1)
    value_output = value(features)
    assert value_output.shape == (16, 1)