Пример #1
0
def test_to_leela_weights(weight_file):
    """Check that ``Network.to_leela_weights`` writes a file shaped like a reference.

    A network is built, dumped to a temporary file, and the dump is compared
    line by line with ``weight_file``: both must have the same number of
    lines, and each pair of lines must parse to the same number of floats.
    Only the layout is compared, not the numeric values.

    Args:
        weight_file: Open text file-like object exposing ``readlines()``
            (presumably supplied by a pytest fixture); every line must be
            whitespace-separated floats.
    """
    # Positional args follow the order used elsewhere in this file:
    # Network(board_size, in_channels, residual_channels, residual_layers).
    net = Network(19, 18, 16, 3)

    # mkstemp returns an already-open OS-level descriptor together with the
    # path. We only need the path, so close the descriptor right away
    # instead of leaking it.
    fd, tmp = tempfile.mkstemp()
    os.close(fd)

    try:
        net.to_leela_weights(tmp)

        with open(tmp, 'r') as tmpfile:
            tmplines = tmpfile.readlines()
        weightlines = weight_file.readlines()

        assert len(tmplines) == len(weightlines)
        for written, expected in zip(tmplines, weightlines):
            t1 = torch.tensor([float(tok) for tok in written.split()])
            t2 = torch.tensor([float(tok) for tok in expected.split()])
            assert t1.size() == t2.size()
    finally:
        # Remove the temp file even when an assertion above fails.
        os.remove(tmp)
Пример #2
0
def test_network(board_size, in_channels, residual_channels, residual_layers,
                 input_size, pol_output_size, val_output_size):
    """Run ``Network`` on a random batch with targets and check output sizes.

    The network is called with three tensors: a random input of shape
    ``input_size``, integer policy targets drawn from ``[0, 19 * 19 + 1)``,
    and value targets of all ones. It is expected to return
    ``((pol, val), (target_pol, target_val))``.

    Args:
        board_size, in_channels, residual_channels, residual_layers:
            Forwarded positionally to the ``Network`` constructor.
        input_size: Shape of the random input tensor; its first entry is
            used as the length of both target tensors.
        pol_output_size: Expected ``pol.size()``.
        val_output_size: Expected ``val.size()``.
    """
    model = Network(board_size, in_channels, residual_channels, residual_layers)

    # Build the three call arguments in the same order the call consumes
    # them, so the random generator is advanced identically.
    features = torch.randn(*input_size)
    batch_shape = (input_size[0], )
    policy_targets = torch.randint(19 * 19 + 1, batch_shape)
    value_targets = torch.ones(batch_shape)

    predictions, targets = model(features, policy_targets, value_targets)
    pol, val = predictions
    target_pol, target_val = targets

    # Predictions must match the parametrized sizes.
    assert pol.size() == pol_output_size
    assert val.size() == val_output_size

    # Targets are expected to come back as one entry per batch element.
    assert target_pol.size() == batch_shape
    assert target_val.size() == batch_shape
Пример #3
0
def test_network(
    board_size,
    in_channels,
    residual_channels,
    residual_layers,
    input_size,
    pol_output_size,
    val_output_size,
):
    """Run ``Network`` on a random input and check policy/value output sizes.

    Args:
        board_size, in_channels, residual_channels, residual_layers:
            Forwarded positionally to the ``Network`` constructor.
        input_size: Shape of the random input tensor fed to the network.
        pol_output_size: Expected ``size()`` of the first returned tensor.
        val_output_size: Expected ``size()`` of the second returned tensor.
    """
    model = Network(board_size, in_channels, residual_channels, residual_layers)

    # Construct the model before sampling the input, as the original did.
    features = torch.randn(*input_size)
    outputs = model(features)

    # The network is expected to return exactly two values: (policy, value).
    pol, val = outputs
    assert pol.size() == pol_output_size
    assert val.size() == val_output_size