Example #1
    def run(self):
        self._set_io_size()
        self._build_session()

        timer_dict = OrderedDict()
        timer_dict['Program Start'] = time.time()
        rolling_stats = defaultdict(list)
        training_return = {}

        with self._session as sess:
            self._build_models()
            self._init_whitening_stats()

            data_dict = defaultdict(list)
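            # interact with the environment one timestep per loop iteration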
            while True:
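                # periodically aggregate the rolling statistics and log progress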
                if (self.timesteps_so_far % self.args.print_frequency) == 0:
                    timer_dict['** Program Total Time **'] = time.time()
                    training_return['stats'] = {}
                    for key in rolling_stats:
                        training_return['stats'][key] = np.mean(
                            rolling_stats[key][-self.args.print_frequency:])

                    if 'mean_rewards' in training_return['stats']:
                        training_return['stats']['mean_rewards'] *= \
                            self.args.episode_length

                    training_return['iteration'] = \
                        self.timesteps_so_far//self.args.print_frequency
                    training_return['totalsteps'] = self.timesteps_so_far

                    log_results(training_return, timer_dict)

                # step the environment once and append the transition data
                temp_dict = self._play()
                for key in temp_dict:
                    data_dict[key].append(temp_dict[key])

                # train once the warm-up period has passed, at the configured frequency
                if self.timesteps_so_far >= self.args.dqn_training_start and \
                        (self.timesteps_so_far % self.args.dqn_train_freq == 0):
                    stats, _ = self._train(data_dict)

                    for key in stats:
                        rolling_stats[key].append(stats[key])

                    # clear data_dict now that it has been consumed
                    data_dict = defaultdict(list)

                # once enough steps have been collected, checkpoint and stop
                if self.timesteps_so_far >= self.args.train_dqn_steps:
                    self._saver.save(
                        self._session,
                        osp.join(
                            logger._get_path(),
                            "pretrained_model_{}".format(self._name_scope)))
                    break
                else:
                    self.timesteps_so_far += 1

            final_weights = self._get_weights()

        return final_weights, self._environments_cache
Example #2
    def run(self):
        self._set_io_size()
        self._build_session()

        timer_dict = OrderedDict()
        timer_dict['Program Start'] = time.time()

        self._build_environments()

        with self._session as sess:
            self._build_models()
            if self.weights is not None:
                self._set_weights(self.weights)
            self._network['explore'].set_environments(self._environments_cache)
            self._network['explore'].action_size = self._action_size
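            # roll out exploration trajectories to fill the replay buffer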
            replay_buffer = self._generate_trajectories()
            self._init_whitening_stats()

            while True:
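                # sample stored trajectories and run one transfer-model update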
                data_dict = replay_buffer.sample(
                    self.args.transfer_sample_traj)
                episode_length = self.args.lookahead_increment

                training_info = {'episode_length': episode_length}

                stats, _ = self.models['transfer'].train(
                    data_dict, replay_buffer, training_info)

                timer_dict['** Program Total Time **'] = time.time()
                log_results(stats, timer_dict)

                if self.current_iteration > self.args.transfer_iterations:
                    # save the transfer model and stop
                    _log_path = logger._get_path()
                    _save_root = 'transfer'
                    _save_extension = _save_root + "_{}_{}.ckpt".format(
                        self._name_scope, self._timesteps_so_far
                    )

                    _save_dir = osp.join(_log_path, _save_extension)
                    self._saver.save(self._session, _save_dir)
                    break
                    break

                else:
                    self.current_iteration += 1
        return self._get_weights()
Example #3
def train(trainer, sampler, worker, network_type, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
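        # one iteration: collect rollouts, train, then sync weights back to the sampler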
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        rollout_data = sampler_agent._rollout_with_workers()

        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights for dynamics and policy network
        training_info = {}

        if args.pretrain_vae and current_iteration < args.pretrain_iterations:
            training_info['train_net'] = 'vae'

        elif args.decoupled_managers:
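            # alternate within each (manager_updates + actor_updates) cycle:
            # the first manager_updates iterations train the manager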
            if (current_iteration % \
                (args.manager_updates + args.actor_updates)) \
                < args.manager_updates:
                training_info['train_net'] = 'manager'

            else:
                training_info['train_net'] = 'actor'

        trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 3: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
Example #4
def train(trainer, sampler, worker, models,
          args=None, pretrain_dict=None,
          environments_cache=None):

    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    
    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, models, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, models, args)

    # optionally run a pretraining phase to obtain initial weights and
    # cached environments for the sampler
    if pretrain_dict is not None:
        pretrain_weights, environments_cache = \
            pretrain_dict['pretrain_fnc'](
                pretrain_dict['pretrain_thread'], models, args,
                environments_cache
            )

    else:
        pretrain_weights = environments_cache = None

    if pretrain_weights is not None:
        # sanity check: the pretrained weights should differ from the fresh init
        for key in pretrain_weights['base']:
            try:
                assert not np.array_equal(pretrain_weights['base'][key],
                                          init_weights['base'][key])
            except AssertionError:
                print(key, pretrain_weights['base'][key],
                      init_weights['base'][key])
        
    init_weights = init_weights \
        if pretrain_weights is None else pretrain_weights

    trainer_tasks.put(
        (parallel_util.TRAINER_SET_WEIGHTS,
         init_weights)
    )
    trainer_tasks.join()

    sampler_agent.set_weights(init_weights)
    if environments_cache is not None:
        sampler_agent.set_environments(environments_cache)

        trainer_tasks.put(
            (parallel_util.TRAINER_SET_ENVIRONMENTS,
            environments_cache)
        )

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
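        # one iteration of final-model training: rollout, train, verify, sync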
        timer_dict['** Program Total Time **'] = time.time()

        training_info = {}
        rollout_info = {}

        training_info['train_model'] = 'final'
        rollout_info['rollout_model'] = 'final'
            
        if args.freeze_actor_final:
            training_info['train_net'] = 'manager'

        elif args.decoupled_managers:
            if (current_iteration % \
                (args.manager_updates + args.actor_updates)) \
                < args.manager_updates:
                training_info['train_net'] = 'manager'

            else:
                training_info['train_net'] = 'actor'

        else:
            training_info['train_net'] = None
            
        rollout_data = sampler_agent._rollout_with_workers(rollout_info)

        timer_dict['Generate Rollout'] = time.time()

        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # update the weights and verify the frozen base network is unchanged
        weights = training_return['network_weights']
        for key in weights['base']:
            assert np.array_equal(weights['base'][key],
                                  init_weights['base'][key])
        sampler_agent.set_weights(weights)
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        if training_return['totalsteps'] > args.max_timesteps:
            # save the final network before stopping
            trainer_tasks.put(
                (parallel_util.SAVE_SIGNAL, {'net': 'final'})
            )
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
Example #5
def train(trainer, sampler, worker, models, args=None, pretrain_weights=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, models, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, models, args)

    init_weights = init_weights \
        if pretrain_weights is None else pretrain_weights

    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
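        # pick the training stage (base -> transfer -> final) by iteration count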
        timer_dict['** Program Total Time **'] = time.time()

        training_info = {}
        rollout_info = {}

        if current_iteration < args.train_dqn_iterations:
            training_info['train_model'] = 'base'
            rollout_info['rollout_model'] = 'base'

        elif current_iteration < (args.train_dqn_iterations + \
            args.train_transfer_iterations):
            training_info['train_model'] = 'transfer'
            rollout_info['rollout_model'] = 'transfer'

        else:
            training_info['train_model'] = 'final'
            rollout_info['rollout_model'] = 'final'

            if args.freeze_actor_final:
                training_info['train_net'] = 'manager'

            elif args.decoupled_managers:
                if (current_iteration % \
                    (args.manager_updates + args.actor_updates)) \
                    < args.manager_updates:
                    training_info['train_net'] = 'manager'

                else:
                    training_info['train_net'] = 'actor'

            else:
                training_info['train_net'] = None

        rollout_data = sampler_agent._rollout_with_workers(rollout_info)

        timer_dict['Generate Rollout'] = time.time()

        trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        # checkpoint each stage as it finishes
        if current_iteration == args.train_dqn_iterations:
            trainer_tasks.put((parallel_util.SAVE_SIGNAL, {'net': 'base'}))

        elif current_iteration == \
                (args.train_dqn_iterations + args.train_transfer_iterations):
            trainer_tasks.put((parallel_util.SAVE_SIGNAL, {'net': 'transfer'}))

        elif training_return['totalsteps'] > args.max_timesteps:
            trainer_tasks.put((parallel_util.SAVE_SIGNAL, {'net': 'final'}))

        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))