Example #1
    def forward(self, inputs, length, initial_inputs=None, static_inputs=None):
        """
        
        :param inputs: These are sliced by time. Time is the second dimension
        :param length: Rollout length
        :param initial_inputs: These are not sliced and are overridden by cell output
        :param static_inputs: These are not sliced and can't be overridden by cell output
        :return:
        """
        # NOTE: unrolling the cell directly will crash because the hidden state is not reset.
        # Use this function or CustomLSTMCell.unroll if needed
        initial_inputs, static_inputs = self.assert_begin(inputs, initial_inputs, static_inputs)

        step_inputs = initial_inputs.copy()
        step_inputs.update(static_inputs)
        lstm_outputs = []
        for t in range(length):
            step_inputs.update(map_dict(lambda x: x[:, t], inputs))  # Slicing
            output = self.cell(**step_inputs)
            
            self.assert_post(output, inputs, initial_inputs, static_inputs)
            # TODO Test what signature does with *args
            autoregressive_output = subdict(output, output.keys() & signature(self.cell.forward).parameters)
            step_inputs.update(autoregressive_output)
            lstm_outputs.append(output)
        
        # TODO recursively stack outputs
        lstm_outputs = listdict2dictlist(lstm_outputs)
        lstm_outputs = map_dict(lambda x: stack(x, dim=1), lstm_outputs)
            
        self.cell.reset()
        return lstm_outputs
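These snippets rely on small dict utilities (map_dict, listdict2dictlist) whose implementations are not shown here. A minimal sketch of what they presumably do, for reference only:

def map_dict(fn, d):
    """Applies fn to every value of d, keeping the keys (the project's real
    helper may additionally preserve the dict subclass, e.g. AttrDict)."""
    return {key: fn(value) for key, value in d.items()}

def listdict2dictlist(list_of_dicts):
    """Converts a list of dicts (all sharing the same keys) into a dict of lists."""
    keys = list_of_dicts[0].keys()
    return {key: [d[key] for d in list_of_dicts] for key in keys}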
Example #2
def num_parameters(model, level=0):
    """  Returns the number of parameters used in a module.
    
    Known bug: if some of the submodules are repeated, their parameters will be double counted
    :param model:
    :param level: if level==1, returns a dictionary of submodule names and corresponding parameter counts
    :return:
    """
    
    if level == 0:
        return sum([p.numel() for p in model.parameters()])
    elif level == 1:
        return map_dict(num_parameters, dict(model.named_children()))
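A quick usage sketch; the toy model and layer sizes below are made up for illustration, and map_dict is assumed to be importable from the project's utilities:

import torch.nn as nn

# Hypothetical toy model, only for demonstrating the two `level` modes.
toy_model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

print(num_parameters(toy_model))           # total count: 418
print(num_parameters(toy_model, level=1))  # per child: {'0': 352, '1': 0, '2': 66}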
Example #3
    def _sample_max_len_video(self, data_dict, end_ind, target_len):
        """ This function processes data tensors so as to have length equal to target_len
        by sampling / padding if necessary """
        extra_length = (end_ind + 1) - target_len
        if self.phase == 'train':
            offset = max(0, int(np.random.rand() * (extra_length + 1)))
        else:
            offset = 0

        data_dict = map_dict(lambda tensor: self._maybe_pad(tensor, offset, target_len), data_dict)
        if 'actions' in data_dict:
            data_dict.actions = data_dict.actions[:-1]
        end_ind = min(end_ind - offset, target_len - 1)

        return end_ind, data_dict
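The helper self._maybe_pad is not shown; judging by its call signature it crops the time dimension starting at `offset` and pads up to `target_len`. A hedged, hypothetical stand-in with that behavior:

import torch

def _maybe_pad_sketch(tensor, offset, target_len):
    """Hypothetical stand-in for self._maybe_pad: crop the time (first)
    dimension starting at `offset`, then zero-pad up to `target_len`."""
    cropped = tensor[offset:offset + target_len]
    if cropped.shape[0] < target_len:
        pad = torch.zeros(target_len - cropped.shape[0], *cropped.shape[1:],
                          dtype=cropped.dtype)
        cropped = torch.cat([cropped, pad], dim=0)
    return cropped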
Example #4
File: train.py  Project: clvrai/spirl
    def val(self):
        print('Running Testing')
        if self.args.test_prediction:
            start = time.time()
            self.model_test.load_state_dict(self.model.state_dict())
            losses_meter = RecursiveAverageMeter()
            self.model_test.eval()
            self.evaluator.reset()
            with autograd.no_grad():
                for batch_idx, sample_batched in enumerate(self.val_loader):
                    inputs = AttrDict(
                        map_dict(lambda x: x.to(self.device), sample_batched))

                    # run evaluator with val-mode model
                    with self.model_test.val_mode():
                        self.evaluator.eval(inputs, self.model_test)

                    # run non-val-mode model (inference) to check overfitting
                    output = self.model_test(inputs)
                    losses = self.model_test.loss(output, inputs)

                    losses_meter.update(losses)
                    del losses

                if not self.args.dont_save:
                    if self.evaluator is not None:
                        self.evaluator.dump_results(self.global_step)

                    self.model_test.log_outputs(output,
                                                inputs,
                                                losses_meter.avg,
                                                self.global_step,
                                                log_images=True,
                                                phase='val',
                                                **self._logging_kwargs)
                    print((
                        '\nTest set: Average loss: {:.4f} in {:.2f}s\n'.format(
                            losses_meter.avg.total.value.item(),
                            time.time() - start)))
            del output
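AttrDict, used here and in the following examples, is a dict that additionally exposes its keys as attributes (e.g. data_dict.actions, info.update(...)). A minimal sketch of such a class; the project's own AttrDict may do more:

class AttrDict(dict):
    """Minimal dict subclass with attribute access."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value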
Example #5
File: agent.py  Project: xiaofei-w/spirl
def _remove_batch(d):
    """Removes the batch dimension from all tensors/arrays in d."""
    return map_dict(
        lambda x: x[0] if
        (isinstance(x, torch.Tensor) or isinstance(x, np.ndarray)) else x,
        d)
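Usage sketch with a hypothetical batch dict (map_dict as assumed above):

import numpy as np
import torch

batched = {'obs': torch.zeros(1, 4), 'act': np.zeros((1, 2)), 'step': 7}
unbatched = _remove_batch(batched)
# unbatched['obs'].shape == (4,), unbatched['act'].shape == (2,); 'step' is left untouched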
Example #6
File: ac_agent.py  Project: clvrai/spirl
    def update(self, experience_batch):
        """Updates actor and critics."""
        # push experience batch into replay buffer
        self.add_experience(experience_batch)

        for _ in range(self._hp.update_iterations):
            # sample batch and normalize
            experience_batch = self._sample_experience()
            experience_batch = self._normalize_batch(experience_batch)
            experience_batch = map2torch(experience_batch, self._hp.device)
            experience_batch = self._preprocess_experience(experience_batch)

            policy_output = self._run_policy(experience_batch.observation)

            # update alpha
            alpha_loss = self._update_alpha(experience_batch, policy_output)

            # compute policy loss
            policy_loss = self._compute_policy_loss(experience_batch,
                                                    policy_output)

            # compute target Q value
            with torch.no_grad():
                policy_output_next = self._run_policy(
                    experience_batch.observation_next)
                value_next = self._compute_next_value(experience_batch,
                                                      policy_output_next)
                q_target = experience_batch.reward * self._hp.reward_scale + \
                                (1 - experience_batch.done) * self._hp.discount_factor * value_next
                if self._hp.clip_q_target:
                    q_target = self._clip_q_target(q_target)
                q_target = q_target.detach()
                check_shape(q_target, [self._hp.batch_size])

            # compute critic loss
            critic_losses, qs = self._compute_critic_loss(
                experience_batch, q_target)

            # update critic networks
            [
                self._perform_update(critic_loss, critic_opt, critic)
                for critic_loss, critic_opt, critic in zip(
                    critic_losses, self.critic_opts, self.critics)
            ]

            # update target networks
            [
                self._soft_update_target_network(critic_target, critic) for
                critic_target, critic in zip(self.critic_targets, self.critics)
            ]

            # update policy network on policy loss
            self._perform_update(policy_loss, self.policy_opt, self.policy)

            # logging
            info = AttrDict(  # losses
                policy_loss=policy_loss,
                alpha_loss=alpha_loss,
                critic_loss_1=critic_losses[0],
                critic_loss_2=critic_losses[1],
            )
            if self._update_steps % 100 == 0:
                info.update(
                    AttrDict(  # gradient norms
                        policy_grad_norm=avg_grad_norm(self.policy),
                        critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                        critic_2_grad_norm=avg_grad_norm(self.critics[1]),
                    ))
            info.update(
                AttrDict(  # misc
                    alpha=self.alpha,
                    pi_log_prob=policy_output.log_prob.mean(),
                    policy_entropy=policy_output.dist.entropy().mean(),
                    q_target=q_target.mean(),
                    q_1=qs[0].mean(),
                    q_2=qs[1].mean(),
                ))
            info.update(self._aux_info(experience_batch, policy_output))
            info = map_dict(ten2ar, info)

            self._update_steps += 1

        return info
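The implementation of self._soft_update_target_network is not shown; in SAC-style agents this is typically a Polyak (exponential moving average) update of the target parameters. A sketch under that assumption, with a hypothetical `tau` coefficient:

import torch

def soft_update_target_network_sketch(target, source, tau=0.005):
    """Hypothetical Polyak update: target <- (1 - tau) * target + tau * source.
    The real agent presumably reads the coefficient from its hyperparameters."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.mul_(1.0 - tau).add_(tau * s_param.data)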
Example #7
File: train.py  Project: clvrai/spirl
    def train_epoch(self, epoch):
        self.model.train()
        epoch_len = len(self.train_loader)
        end = time.time()
        batch_time = AverageMeter()
        upto_log_time = AverageMeter()
        data_load_time = AverageMeter()
        self.log_outputs_interval = self.args.log_interval
        self.log_images_interval = int(epoch_len /
                                       self.args.per_epoch_img_logs)

        print('starting epoch ', epoch)

        for self.batch_idx, sample_batched in enumerate(self.train_loader):
            data_load_time.update(time.time() - end)
            inputs = AttrDict(
                map_dict(lambda x: x.to(self.device), sample_batched))
            with self.training_context():
                self.optimizer.zero_grad()
                output = self.model(inputs)
                losses = self.model.loss(output, inputs)
                losses.total.value.backward()
                self.call_hooks(inputs, output, losses, epoch)

                self.optimizer.step()
                self.model.step()

            if self.args.train_loop_pdb:
                import pdb
                pdb.set_trace()

            upto_log_time.update(time.time() - end)
            if self.log_outputs_now and not self.args.dont_save:
                self.model.log_outputs(output,
                                       inputs,
                                       losses,
                                       self.global_step,
                                       log_images=self.log_images_now,
                                       phase='train',
                                       **self._logging_kwargs)
            batch_time.update(time.time() - end)
            end = time.time()

            if self.log_outputs_now:
                print('GPU {}: {}'.format(
                    os.environ["CUDA_VISIBLE_DEVICES"]
                    if self.use_cuda else 'none', self._hp.exp_path))
                print(
                    ('itr: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                     format(self.global_step, epoch, self.batch_idx,
                            len(self.train_loader),
                            100. * self.batch_idx / len(self.train_loader),
                            losses.total.value.item())))

                print(
                    'avg time for loading: {:.2f}s, logs: {:.2f}s, compute: {:.2f}s, total: {:.2f}s'
                    .format(data_load_time.avg,
                            batch_time.avg - upto_log_time.avg,
                            upto_log_time.avg - data_load_time.avg,
                            batch_time.avg))
                togo_train_time = batch_time.avg * (self._hp.num_epochs -
                                                    epoch) * epoch_len / 3600.
                print('ETA: {:.2f}h'.format(togo_train_time))

            del output, losses
            self.global_step = self.global_step + 1
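AverageMeter is referenced but not defined in this snippet; it appears to be the usual running-average meter with update(value) and an .avg field. A minimal sketch of that assumed interface:

class AverageMeterSketch:
    """Hypothetical minimal running average matching the usage above."""
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count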