Example 1
def train(algo_name,
          params,
          dataset_path,
          model_save_path,
          experiment_name=None,
          with_timestamp=True,
          logdir='d3rlpy_logs',
          prev_model_path=None,
          test_size=0.2):
    dataset = MDPDataset.load(dataset_path)
    train_data, test_data = train_test_split(dataset, test_size=test_size)

    # train
    algo = create_algo(algo_name, dataset.is_action_discrete(), **params)
    algo.fit(train_data,
             experiment_name=experiment_name,
             with_timestamp=with_timestamp,
             logdir=logdir,
             save_interval=1000000)  # huge interval: effectively disables intermediate checkpoints

    # save final model
    algo.save_model(model_save_path)

    # evaluate
    scores = _evaluate(algo, test_data, dataset.is_action_discrete())

    # compare against the previous model
    if prev_model_path:
        base_algo = create_algo(algo_name, dataset.is_action_discrete(),
                                **params)
        base_algo.create_impl(dataset.get_observation_shape(),
                              dataset.get_action_size())
        base_algo.load_model(prev_model_path)
        score = _compare(algo, base_algo, test_data,
                         dataset.is_action_discrete())
        scores['algo_action_diff'] = score

    return scores
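
Example 1 trains an algorithm on a saved MDPDataset, writes the final model, and optionally scores it against a previous model. A minimal invocation sketch, assuming `MDPDataset` comes from `d3rlpy.dataset`, `train_test_split` from `sklearn.model_selection`, and `create_algo`, `_evaluate`, and `_compare` from the surrounding module; the algorithm name and hyperparameters are illustrative:

# Hypothetical call: 'cql', the paths, and the hyperparameter are placeholders
scores = train('cql',
               {'actor_learning_rate': 3e-4},
               dataset_path='dataset.h5',
               model_save_path='model.pt',
               experiment_name='offline_run',
               prev_model_path='model_prev.pt')
print(scores)  # metric dict; includes 'algo_action_diff' when comparing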
Example 2
    def get_model(self):
        from flask import jsonify, send_file

        # initialize algorithm
        algo = create_algo(self.algo_name, self.dataset.is_action_discrete(),
                           **self.algo_params)
        algo.create_impl(self.dataset.get_observation_shape(),
                         self.dataset.get_action_size())

        # load latest model
        trial = self.n_trials - 1
        while trial >= 0:
            model_path = self.model_save_path_tmpl % trial
            if os.path.exists(model_path):
                algo.load_model(model_path)
                break
            trial -= 1

        if trial < 0:
            # no saved model exists yet: return an error
            return jsonify({'status': 'empty'}), 500

        # save policy
        policy_path = os.path.join(self.dir_path, 'policy.pt')
        algo.save_policy(policy_path)

        # send back the policy data
        # NOTE: Flask >= 2.0 renames attachment_filename to download_name
        res = send_file(policy_path,
                        as_attachment=True,
                        attachment_filename='policy.pt')

        return res
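
Example 2 is a Flask handler that exports the latest trained model as a policy file. A client can exercise it with a plain HTTP GET; the host, port, and route below are assumptions taken from the `add_url_rule` calls in Example 5:

# Hypothetical client: assumes the worker listens on localhost:5000
import requests

res = requests.get('http://localhost:5000/model')
if res.status_code == 200:
    with open('policy.pt', 'wb') as f:
        f.write(res.content)  # the policy file produced by save_policy
else:
    print(res.json())  # {'status': 'empty'} when no model has been saved yet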
Example 3
def train(algo_name,
          params,
          dataset_path,
          experiment_name=None,
          logdir='d3rlpy_logs'):
    # prepare dataset
    dataset = MDPDataset.load(dataset_path)
    train_data, test_data = train_test_split(dataset, test_size=0.2)

    # prepare scorers for evaluation during training
    scorers = _get_scorers(dataset.is_action_discrete())

    # train
    algo = create_algo(algo_name, dataset.is_action_discrete(), **params)
    algo.fit(train_data,
             eval_episodes=test_data,
             scorers=scorers,
             experiment_name=experiment_name,
             with_timestamp=False,
             logdir=logdir,
             save_interval=1,
             show_progress=False,
             tensorboard=False)

    return True
Example 4
def train(algo_name,
          params,
          dataset_path,
          experiment_name=None,
          logdir="d3rlpy_logs"):
    # prepare dataset
    dataset = MDPDataset.load(dataset_path)
    train_data, test_data = train_test_split(dataset, test_size=0.2)

    # get dataset statistics
    stats = dataset.compute_stats()

    # prepare scorers for evaluation during training
    scorers = _get_scorers(dataset.is_action_discrete(), stats)

    # use min-max action scaling for continuous action spaces
    if not dataset.is_action_discrete():
        params["action_scaler"] = "min_max"

    # pull the training schedule out of params so that only algorithm
    # hyperparameters are forwarded to the constructor
    n_epochs = params.pop("n_epochs")
    n_steps_per_epoch = params.pop("n_steps_per_epoch")

    # train
    algo = create_algo(algo_name, dataset.is_action_discrete(), **params)
    algo.fit(
        train_data,
        n_steps=n_epochs * n_steps_per_epoch,
        n_steps_per_epoch=n_steps_per_epoch,
        eval_episodes=test_data,
        scorers=scorers,
        experiment_name=experiment_name,
        with_timestamp=False,
        logdir=logdir,
        save_interval=1,
        show_progress=False,
    )

    return True
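
Example 4 extends Example 3 with dataset statistics and min-max action scaling, and derives the step budget from the schedule keys carried in `params`. A hedged invocation sketch; all names and values are illustrative:

# Hypothetical call: schedule keys sit next to the algorithm hyperparameters
params = {
    'n_epochs': 10,
    'n_steps_per_epoch': 1000,
    'actor_learning_rate': 3e-4,  # placeholder hyperparameter
}
train('cql', params, 'dataset.h5', experiment_name='offline_run')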
Example 5
    def __init__(self,
                 algo_name,
                 algo_params,
                 dataset,
                 dir_path='d3rlpy_logs/worker',
                 with_timestamp=True,
                 model_path=None):
        from flask import Flask
        self.algo_name = algo_name
        self.algo_params = algo_params
        self.dataset = dataset
        self.n_trials = 0

        # setup flask server
        self.app = Flask(__name__)
        self.app.add_url_rule('/train',
                              'train',
                              self.train_algo,
                              methods=['POST'])
        self.app.add_url_rule('/data',
                              'data',
                              self.append_data,
                              methods=['POST'])
        self.app.add_url_rule('/model',
                              'model',
                              self.get_model,
                              methods=['GET'])
        self.app.add_url_rule('/status',
                              'status',
                              self.get_status,
                              methods=['GET'])

        # prepare directory
        if with_timestamp:
            dir_path += '_' + datetime.now().strftime('%Y%m%d%H%M%S')
        self.dir_path = os.path.abspath(dir_path)
        os.makedirs(self.dir_path)

        # save initial dataset
        self.dataset_path = os.path.join(self.dir_path, 'dataset.h5')
        dataset.dump(self.dataset_path)

        self.model_save_path_tmpl = os.path.join(self.dir_path, 'model_%d.pt')
        self.experiment_name_tmpl = 'worker_training_%d'
        self.train_uid = None
        self.latest_metrics = {}

        # make initial model
        if model_path:
            algo = create_algo(algo_name, dataset.is_action_discrete(),
                               **algo_params)
            algo.create_impl(dataset.get_observation_shape(),
                             dataset.get_action_size())
            algo.load_model(model_path)
            algo.save_model(self.model_save_path_tmpl % 0)
            self.n_trials += 1
        else:
            self._dispatch_training_job()
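
Example 5 wires the worker's HTTP routes and working directory together in `__init__`. A hedged sketch of how the server might be brought up; the enclosing class is not named in the excerpt, so `TrainingWorker` below is a stand-in, and the host, port, and hyperparameters are illustrative:

# Hypothetical startup: 'TrainingWorker' stands in for the excerpted class
dataset = MDPDataset.load('dataset.h5')
worker = TrainingWorker('cql', {'actor_learning_rate': 3e-4}, dataset)
worker.app.run(host='0.0.0.0', port=5000)  # serves /train, /data, /model, /status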