예제 #1
0
    def test_ppo_bc(self):
        """End-to-end PPO+BC pipeline test: fit a behavior-cloning model,
        train an rllib agent with it, then sanity-check the reward and
        (optionally) reproducibility against a stored pickle."""
        # Behavior-cloning phase: quick fit on a single layout.
        bc_model_dir = self.temp_model_dir
        bc_params = get_bc_params(
            layouts=['inverse_marshmallow_experiment'],
            data_path=None,
            epochs=10)
        train_bc_model(bc_model_dir, bc_params)

        # RL phase: schedule anneals from pure self-play (0.0) at step 0
        # to a pure BC partner (1.0) by step 8e3.
        overrides = dict(
            results_dir=self.temp_results_dir,
            bc_schedule=[(0.0, 0.0), (8e3, 1.0)],
            num_training_iters=20,
            bc_model_dir=bc_model_dir,
            evaluation_interval=5,
            verbose=False)
        run = ex.run(config_updates=overrides,
                     options={'--loglevel': 'ERROR'})
        results = run.result

        # Training must reach at least the configured reward floor.
        self.assertGreaterEqual(results['average_total_reward'],
                                self.min_performance)

        if self.compute_pickle:
            self.expected['test_ppo_bc'] = results

        # Exact reproducibility check against the stored results.
        if self.strict:
            self.assertDictEqual(results, self.expected['test_ppo_bc'])
예제 #2
0
 def test_training(self, lstm=False):
     """Train a BC model for every parameter set under test and compare a
     forward pass of each against the stored expected outputs."""
     for name, params in self.bc_params_to_test.items():
         params["use_lstm"] = lstm
         trained = train_bc_model(self.model_dir, params)
         # LSTM runs are stored under a separate "lstm_"-prefixed key.
         key = f'test_training_{name}'
         if lstm:
             key = "lstm_" + key
         outputs = self._model_forward(trained, bc_params=params, lstm=lstm)
         self._compare_iterable_with_expected(key, outputs)
    def test_training(self):
        """Fit the BC model and check its forward pass on a dummy batch
        against the pickled expected output."""
        trained = train_bc_model(self.model_dir, self.bc_params)

        if self.compute_pickle:
            self.expected['test_training'] = trained(self.dummy_input)
        if self.strict:
            # Model outputs are floats: compare with allclose, not equality.
            self.assertTrue(
                np.allclose(trained(self.dummy_input),
                            self.expected["test_training"]))
예제 #4
0
    def test_agent_metrics_evaluation(self, lstm=False):
        """Train each BC configuration and compare its evaluation metrics
        with the stored expected values."""
        for name, params in self.bc_params_to_test.items():
            params["use_lstm"] = lstm
            # LSTMs are slow to train; keep their epoch budget minimal.
            params["training_params"]["epochs"] = 1 if lstm else 20
            trained = train_bc_model(self.model_dir, params)
            metrics = evaluate_bc_model_metrics(trained, params)

            # LSTM runs are stored under a separate "lstm_"-prefixed key.
            key = f'test_agent_metrics_evaluation_{name}'
            if lstm:
                key = "lstm_" + key
            self._compare_iterable_with_expected(key, metrics)
예제 #5
0
 def test_agent_evaluation(self, lstm=False):
     """Train each BC configuration, evaluate it, and check both a minimum
     reward floor and agreement with the stored expected results."""
     for name, params in self.bc_params_to_test.items():
         params["use_lstm"] = lstm
         # LSTMs are slow to train; keep their epoch budget minimal.
         params["training_params"]["epochs"] = 1 if lstm else 20
         trained = train_bc_model(self.model_dir, params)
         reward = evaluate_bc_model(trained, params)
         # LSTM runs are stored under a separate "lstm_"-prefixed key.
         key = f'test_agent_evaluation_{name}'
         if lstm:
             key = "lstm_" + key
         # Sanity Check
         self.assertGreaterEqual(reward, self.min_performance)
         self._compare_iterable_with_expected(key, reward)
    def test_agent_evaluation(self):
        """Train the BC model for a full run and check the evaluation
        reward against both a floor and the pickled expected value."""
        self.bc_params["training_params"]["epochs"] = 20
        trained = train_bc_model(self.model_dir, self.bc_params)
        reward = evaluate_bc_model(trained, self.bc_params)

        # Sanity Check
        self.assertGreaterEqual(reward, self.min_performance)

        if self.compute_pickle:
            self.expected['test_agent_evaluation'] = reward
        if self.strict:
            # Scalar float reward: compare with tolerance, not equality.
            self.assertAlmostEqual(reward,
                                   self.expected['test_agent_evaluation'])
예제 #7
0
    def test_behavior_cloning_policy(self, lstm=False):
        """Wrap each trained BC model in a BehaviorCloningPolicy and
        compare its action (and, where supported, order-prediction)
        outputs to the stored expected values."""
        for name, params in self.bc_params_to_test.items():
            params["use_lstm"] = lstm
            trained = train_bc_model(self.model_dir, params)
            policy = BehaviorCloningPolicy.from_model(
                trained, params, stochastic=False)
            obs = get_dummy_input(params)

            # LSTM runs are stored under a separate "lstm_"-prefixed key.
            action_key = f'test_behavior_cloning_policy_actions_{name}'
            if lstm:
                action_key = "lstm_" + action_key
            self._compare_iterable_with_expected(
                action_key, policy.compute_actions(obs, self.initial_states))

            if not params["predict_orders"]:
                # Policies without an order head must refuse predict_orders.
                self.assertRaises(AssertionError, policy.predict_orders,
                                  obs, self.initial_states)
            else:
                order_key = f'test_behavior_cloning_policy_orders_{name}'
                if lstm:
                    order_key = "lstm_" + order_key
                self._compare_iterable_with_expected(
                    order_key,
                    policy.predict_orders(obs, self.initial_states))