Example #1
 def generate_data(self, virtual_envs: VecVirtualEnv, policy_buffer: Buffer,
                   initial_states: torch.Tensor, actor):
     states = initial_states
     batch_size = initial_states.shape[0]
     num_total_samples = 0
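     # roll the policy through the virtual (learned-model) envs for up to num_rollout_steps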
     for step in range(self.num_rollout_steps):
         with torch.no_grad():
             actions = actor.act(states)['actions']
         next_states, rewards, dones, _ = virtual_envs.step_with_states(
             states, actions)
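         # mask is 0.0 where the episode terminated, 1.0 otherwise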
         masks = torch.tensor([[0.0] if done else [1.0] for done in dones],
                              dtype=torch.float32)
         policy_buffer.insert(states=states,
                              actions=actions,
                              masks=masks,
                              rewards=rewards,
                              next_states=next_states)
         num_total_samples += next_states.shape[0]
         # continue the rollout only from states that are not done
         states = next_states[torch.where(torch.gt(masks, 0.5))[0], :]
         if states.shape[0] == 0:
             logger.warn(
                 '[ Model Rollout ] Breaking early at step {}'.format(step))
             break
     if self.verbose:
         logger.log(
             '[ Model Rollout ] {} samples with average rollout length {:.2f}'
             .format(num_total_samples, num_total_samples / batch_size))
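The rollout loop drops finished trajectories at each step, so the batch shrinks as episodes terminate. A minimal, self-contained sketch of the same mask-and-filter pattern (the tensor values are illustrative, not from the codebase):

import torch

states = torch.randn(4, 3)                      # batch of 4 states
dones = [False, True, False, False]             # termination flags from the env
masks = torch.tensor([[0.0] if d else [1.0] for d in dones])
alive = states[torch.where(torch.gt(masks, 0.5))[0], :]
print(alive.shape)                              # torch.Size([3, 3]): the done state is dropped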
Example #2
 def update_rollout_length(self, epoch: int):
     min_epoch, max_epoch, min_length, max_length = self.rollout_schedule
     if epoch <= min_epoch:
         y = min_length
     else:
         dx = (epoch - min_epoch) / (max_epoch - min_epoch)
         dx = min(dx, 1)
         y = dx * (max_length - min_length) + min_length
     y = int(y)
     if self.verbose > 0 and self.num_rollout_steps != y:
         logger.log('[ Model Rollout ] Max rollout length {} -> {} '.format(
             self.num_rollout_steps, y))
     self.num_rollout_steps = y
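The schedule interpolates the rollout length linearly between min_length and max_length as the epoch moves from min_epoch to max_epoch, then saturates. A quick check with an illustrative schedule (the numbers are assumptions, not from the codebase):

# schedule format: [min_epoch, max_epoch, min_length, max_length]
rollout_schedule = [20, 100, 1, 15]
min_epoch, max_epoch, min_length, max_length = rollout_schedule
for epoch in (10, 60, 120):
    dx = min(max(epoch - min_epoch, 0) / (max_epoch - min_epoch), 1)
    print(epoch, int(dx * (max_length - min_length) + min_length))
# prints: 10 1, 60 8, 120 15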
Example #3
 def get_batch_generator_inf(self, batch_size, **kwargs):
     batch_sizes = (batch_size * self.weights).astype(int)
     if self.verbose:
         logger.log('[Buffer Mixing] Max error {}'.format(
             np.max(np.abs(batch_sizes / batch_size - self.weights))))
     rand_index = np.random.randint(len(batch_sizes))
     batch_sizes[rand_index] = batch_size - np.delete(
         batch_sizes, rand_index).sum()
     inf_gens = [
         buffer.get_batch_generator_inf(int(batch_size_), **kwargs)
         for buffer, batch_size_ in zip(self.buffers, batch_sizes)
     ]
     while True:
         buffer_samples = [next(gen) for gen in inf_gens]
         yield merge_dicts(buffer_samples, lambda x: torch.cat(x, dim=0))
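Because astype(int) truncates, the per-buffer sizes can undershoot batch_size; the randomly chosen entry then absorbs the remainder, so the sizes always sum to batch_size exactly. A quick sketch of the correction with illustrative weights:

import numpy as np

weights = np.array([0.95, 0.05])
batch_size = 256
sizes = (batch_size * weights).astype(int)      # [243, 12], sums to 255
i = np.random.randint(len(sizes))
sizes[i] = batch_size - np.delete(sizes, i).sum()
assert sizes.sum() == batch_size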
Example #4
def init_logging(config, hparam_dict):
    import datetime
    current_time = datetime.datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join(config.proj_dir, config.result_dir, current_time, 'log')
    eval_log_dir = os.path.join(config.proj_dir, config.result_dir, current_time, 'log_eval')
    save_dir = os.path.join(config.proj_dir, config.result_dir, current_time, 'save')
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(eval_log_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_hparams(hparam_dict, metric_dict={})

    logger.configure(log_dir, None, config.log_email, config.proj_name)
    logger.info('Hyperparams:')
    for key, value in hparam_dict.items():
        logger.log('{:35s}: {}'.format(key, value))

    return writer, log_dir, eval_log_dir, save_dir
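A hypothetical call, assuming a config object exposing the fields read above (proj_dir, result_dir, log_email, proj_name); the values are placeholders, and logger.configure is the project's own logging module:

from types import SimpleNamespace

cfg = SimpleNamespace(proj_dir='.', result_dir='results',
                      log_email=None, proj_name='demo')
writer, log_dir, eval_log_dir, save_dir = init_logging(
    cfg, {'lr': 3e-4, 'batch_size': 256})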
Example #5
    def update(self, model_buffer: Buffer) -> Dict[str, float]:
        model_loss_epoch = 0.
        l2_loss_epoch = 0.

        if self.max_num_epochs:
            epoch_iter = range(self.max_num_epochs)
        else:
            epoch_iter = count()

        train_indices, val_indices = split_model_buffer(
            model_buffer, self.training_ratio)

        num_epoch_after_update = 0
        num_updates = 0
        epoch = 0

        self.dynamics.reset_best_snapshots()

        for epoch in epoch_iter:
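            # one pass over the training split, then an early-stopping check on the validation split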
            train_gen = model_buffer.get_ensemble_batch_generator_epoch(
                self.batch_size, train_indices)
            val_gen = model_buffer.get_batch_generator_epoch(None, val_indices)

            for samples in train_gen:
                train_model_loss, train_l2_loss = self.compute_loss(
                    samples, True, True, True)
                train_model_loss = train_model_loss.sum()
                train_l2_loss = train_l2_loss.sum()
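                # penalize the spread between the learned max and min log-variance bounds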
                train_model_loss += 0.01 * (
                    torch.sum(self.dynamics.max_diff_state_logvar)
                    + torch.sum(self.dynamics.max_reward_logvar)
                    - torch.sum(self.dynamics.min_diff_state_logvar)
                    - torch.sum(self.dynamics.min_reward_logvar))

                model_loss_epoch += train_model_loss.item()
                l2_loss_epoch += train_l2_loss.item()

                self.dynamics_optimizer.zero_grad()
                (train_l2_loss + train_model_loss).backward()
                self.dynamics_optimizer.step()

                num_updates += 1

            with torch.no_grad():
                val_model_loss, _ = self.compute_loss(next(val_gen), False,
                                                      False, False)
            updated = self.dynamics.update_best_snapshots(
                val_model_loss, epoch)

            # updated == True means this epoch improved the best validation loss so far.
            if updated:
                num_epoch_after_update = 0
            else:
                num_epoch_after_update += 1
            # stop early if validation has not improved for more than 5 consecutive epochs.
            if num_epoch_after_update > 5:
                break

        model_loss_epoch /= num_updates
        l2_loss_epoch /= num_updates

        val_gen = model_buffer.get_batch_generator_epoch(None, val_indices)
        # load the best snapshots, as selected on the validation set.
        best_epochs = self.dynamics.load_best_snapshots()
        with torch.no_grad():
            val_model_loss, _ = self.compute_loss(next(val_gen), False, False,
                                                  False)
        self.dynamics.update_elite_indices(val_model_loss)

        if self.verbose > 0:
            logger.log('[ Model Training ] Converged at epoch {}'.format(epoch))
            logger.log(
                '[ Model Training ] Load best state_dict from epoch {}'.format(
                    best_epochs))
            logger.log(
                '[ Model Training ] Validation Model loss of elite networks: {}'
                .format(
                    val_model_loss.cpu().numpy()[self.dynamics.elite_indices]))

        return {'model_loss': model_loss_epoch, 'l2_loss': l2_loss_epoch}