コード例 #1
0
def test_augmentation_pipeline(batch_size, observation_shape):
    """Exercise AugmentationPipeline with two mocked dummy augmentations.

    The pipeline is built from one augmentation and extended with a second
    via ``append``. The test then checks that a single ``transform`` call on
    the pipeline invokes each augmentation's ``transform`` exactly once,
    that the output equals the input shifted by 0.3 (0.1 + 0.2), and that
    the reported augmentation types and params list one entry per member.

    ``batch_size`` and ``observation_shape`` are presumably supplied by
    pytest fixtures or parametrization defined elsewhere -- confirm.
    """
    # Each stage adds a distinct constant so the combined effect is checkable.
    first_stage = DummyAugmentation()
    first_stage.transform = Mock(side_effect=lambda x: x + 0.1)
    second_stage = DummyAugmentation()
    second_stage.transform = Mock(side_effect=lambda x: x + 0.2)

    # Cover both construction-time registration and later appending.
    pipeline = AugmentationPipeline([first_stage])
    pipeline.append(second_stage)

    inputs = np.random.random((batch_size, *observation_shape))
    outputs = pipeline.transform(inputs)

    for stage in (first_stage, second_stage):
        stage.transform.assert_called_once()
    assert np.allclose(outputs, inputs + 0.3)

    n_stages = 2
    expected_types = ['dummy'] * n_stages
    expected_params = [{'param': 0.1} for _ in range(n_stages)]
    assert pipeline.get_augmentation_types() == expected_types
    assert pipeline.get_augmentation_params() == expected_params
コード例 #2
0
@pytest.mark.parametrize('gamma', [0.99])
@pytest.mark.parametrize('tau', [0.05])
@pytest.mark.parametrize('n_critics', [2])
@pytest.mark.parametrize('bootstrap', [False])
@pytest.mark.parametrize('share_encoder', [False, True])
@pytest.mark.parametrize('initial_temperature', [1.0])
@pytest.mark.parametrize('initial_alpha', [1.0])
@pytest.mark.parametrize('alpha_threshold', [0.05])
@pytest.mark.parametrize('lam', [0.75])
@pytest.mark.parametrize('n_action_samples', [4])
@pytest.mark.parametrize('mmd_sigma', [20.0])
@pytest.mark.parametrize('eps', [1e-8])
@pytest.mark.parametrize('use_batch_norm', [True, False])
@pytest.mark.parametrize('q_func_type', ['mean', 'qr', 'iqn', 'fqf'])
@pytest.mark.parametrize('scaler', [None, DummyScaler()])
@pytest.mark.parametrize('augmentation', [AugmentationPipeline()])
@pytest.mark.parametrize('n_augmentations', [1])
@pytest.mark.parametrize('encoder_params', [{}])
def test_bear_impl(observation_shape, action_size, actor_learning_rate,
                   critic_learning_rate, imitator_learning_rate,
                   temp_learning_rate, alpha_learning_rate, gamma, tau,
                   n_critics, bootstrap, share_encoder, initial_temperature,
                   initial_alpha, alpha_threshold, lam, n_action_samples,
                   mmd_sigma, eps, use_batch_norm, q_func_type, scaler,
                   augmentation, n_augmentations, encoder_params):
    impl = BEARImpl(observation_shape,
                    action_size,
                    actor_learning_rate,
                    critic_learning_rate,
                    imitator_learning_rate,
                    temp_learning_rate,