def test_clone(self):
    """Cloning a CategoricalLSTMPolicy keeps its env spec and parameter values.

    Builds a policy on a DummyDiscreteEnv wrapped in GymEnv, clones it under
    a new name, and checks that the env spec matches. It also checks that
    both policies expose the same number of parameters and that each pair
    of parameter values is element-wise equal.
    """
    env = GymEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
    policy = CategoricalLSTMPolicy(env_spec=env.spec)
    policy_clone = policy.clone('CategoricalLSTMPolicyClone')
    assert policy.env_spec == policy_clone.env_spec
    cloned_params = list(policy_clone.parameters.values())
    original_params = list(policy.parameters.values())
    # zip() stops at the shorter input, so compare counts first; otherwise a
    # clone with missing or extra parameters would pass the loop below.
    assert len(cloned_params) == len(original_params)
    for cloned_param, param in zip(cloned_params, original_params):
        assert np.array_equal(cloned_param, param)
def test_clone(self):
    """A cloned policy takes the requested name and copies parameter values.

    LSTMModel inside categorical_lstm_policy is replaced by SimpleLSTMModel
    for the whole test body, including the clone() call and the assertions.
    """
    env = TfEnv(DummyDiscreteEnv(obs_dim=(1, ), action_dim=1))
    patch_target = ('garage.tf.policies.'
                    'categorical_lstm_policy.LSTMModel')
    with mock.patch(patch_target, new=SimpleLSTMModel):
        policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                       state_include_action=False)
        clone = policy.clone('cloned_policy')
        assert clone.name == 'cloned_policy'
        original_values = policy.get_param_values()
        cloned_values = clone.get_param_values()
        assert np.array_equal(original_values, cloned_values)
def test_clone(self):
    """Cloning a CategoricalLSTMPolicy preserves the source policy's env spec."""
    garage_env = GarageEnv(DummyDiscreteEnv(obs_dim=(10, ), action_dim=4))
    source_policy = CategoricalLSTMPolicy(env_spec=garage_env.spec)
    cloned_policy = source_policy.clone('CategoricalLSTMPolicyClone')
    assert source_policy.env_spec == cloned_policy.env_spec