Example #1
    def inverse_loss(self, samples):
        observation = samples.observation[0]  # [T,B,C,H,W]->[B,C,H,W]
        last_observation = samples.observation[-1]

        if self.random_shift_prob > 0.:
            observation = random_shift(
                imgs=observation,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            last_observation = random_shift(
                imgs=last_observation,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )

        action = samples.action  # [T,B,A]
        # if self.onehot_actions:
        #     action = to_onehot(action, self._act_dim, dtype=torch.float)
        observation, last_observation, action = buffer_to(
            (observation, last_observation, action), device=self.device)

        _, conv_obs = self.encoder(observation)
        _, conv_last = self.encoder(last_observation)

        valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
        # If the last_observation is invalid, all timesteps are invalid:
        valid = valid[-1].repeat(self.delta_T, 1).transpose(1, 0)  # [B,T-1]

        if self.onehot_actions:
            logits = self.inverse_model(conv_obs, conv_last)  # [B,T-1,A]
            labels = action[:-1].transpose(1, 0)  # [B,T-1], excludes the last action
            labels[~valid] = IGNORE_INDEX

            b, t, a = logits.shape
            logits = logits.view(b * t, a)
            labels = labels.reshape(b * t)
            logits = logits - torch.max(logits, dim=1, keepdim=True)[0]
            inv_loss = self.c_e_loss(logits, labels)

            valid = valid.reshape(b * t).to(self.device)
            dist_info = DistInfo(prob=F.softmax(logits, dim=1))
            entropy = self.distribution.mean_entropy(
                dist_info=dist_info,
                valid=valid,
            )
            entropy_loss = -self.entropy_loss_coeff * entropy

            correct = torch.argmax(logits.detach(), dim=1) == labels
            accuracy = torch.mean(correct[valid].float())

        else:
            raise NotImplementedError

        perplexity = self.distribution.mean_perplexity(dist_info,
                                                       valid.to(self.device))

        return inv_loss, entropy_loss, accuracy, perplexity, conv_obs
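
The label masking above relies on torch.nn.CrossEntropyLoss's ignore_index argument: any label set to IGNORE_INDEX contributes nothing to the loss or its gradient. A minimal sketch of how c_e_loss and IGNORE_INDEX are presumably wired up (the value -100 matches PyTorch's default and is an assumption here):

    import torch

    IGNORE_INDEX = -100  # assumption; any value outside [0, num_classes) works
    c_e_loss = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    logits = torch.randn(6, 4)  # [B*T, A]
    labels = torch.tensor([0, 2, IGNORE_INDEX, 1, IGNORE_INDEX, 3])
    loss = c_e_loss(logits, labels)  # averages over the 4 valid entries only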
Example #2
    def ul_optimize_one_step(self, samples=None):
        self.ul_optimizer.zero_grad()
        if samples is None:
            if self.ul_pri_alpha > 0:
                samples = self.replay_buffer.sample_batch(self.ul_batch_size,
                                                          mode="UL")
            else:
                samples = self.replay_buffer.sample_batch(self.ul_batch_size)

            # This is why ul_delta_T must equal n_step_return (usually 1):
            anchor = samples.agent_inputs.observation
            positive = samples.target_inputs.observation

            if self.ul_random_shift_prob > 0.0:
                anchor = random_shift(
                    imgs=anchor,
                    pad=self.ul_random_shift_pad,
                    prob=self.ul_random_shift_prob,
                )
                positive = random_shift(
                    imgs=positive,
                    pad=self.ul_random_shift_pad,
                    prob=self.ul_random_shift_prob,
                )

            anchor, positive = buffer_to((anchor, positive),
                                         device=self.agent.device)

        else:
            # Assume samples were already augmented in the RL loss.
            anchor = samples.agent_inputs.observation
            positive = samples.target_inputs.observation

        with torch.no_grad():
            c_positive, _pos_conv = self.ul_target_encoder(positive)
        c_anchor, _anc_conv = self.ul_encoder(anchor)
        logits = self.ul_contrast(c_anchor, c_positive)  # anchor mlp in here.

        labels = torch.arange(c_anchor.shape[0],
                              dtype=torch.long,
                              device=self.agent.device)
        invalid = samples.done.type(torch.bool)  # [B]; if done, the following state is invalid
        labels[invalid] = IGNORE_INDEX
        ul_loss = self.c_e_loss(logits, labels)
        ul_loss.backward()
        if self.ul_clip_grad_norm is None:
            grad_norm = 0.0
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.ul_parameters(),
                                                       self.ul_clip_grad_norm)
        self.ul_optimizer.step()

        correct = torch.argmax(logits.detach(), dim=1) == labels
        accuracy = torch.mean(correct[~invalid].float())

        return ul_loss, accuracy, grad_norm
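
ul_target_encoder is queried under torch.no_grad() and receives no gradients; in ATC-style setups it is typically maintained as an exponential moving average (EMA) of the online encoder. A hedged sketch of such an update (the function name and tau value are assumptions, not taken from this codebase):

    import torch

    @torch.no_grad()
    def update_target_encoder(online, target, tau=0.01):
        # Polyak/EMA update: target <- (1 - tau) * target + tau * online.
        for p_online, p_target in zip(online.parameters(), target.parameters()):
            p_target.mul_(1.0 - tau).add_(p_online, alpha=tau)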
Example #3
    def ul_optimize_one_step(self):
        self.ul_optimizer.zero_grad()
        samples = self.replay_buffer.ul_sample_batch(self.ul_batch_B)

        anchor = samples.observation[:-self.ul_delta_T]
        positive = samples.observation[self.ul_delta_T:]
        t, b, c, h, w = anchor.shape
        anchor = anchor.reshape(t * b, c, h, w)
        positive = positive.reshape(t * b, c, h, w)

        if self.ul_random_shift_prob > 0.0:
            anchor = random_shift(
                imgs=anchor,
                pad=self.ul_random_shift_pad,
                prob=self.ul_random_shift_prob,
            )
            positive = random_shift(
                imgs=positive,
                pad=self.ul_random_shift_pad,
                prob=self.ul_random_shift_prob,
            )

        anchor, positive = buffer_to((anchor, positive),
                                     device=self.agent.device)

        with torch.no_grad():
            c_positive, _pos_conv = self.ul_target_encoder(positive)
        c_anchor, _anc_conv = self.ul_encoder(anchor)
        logits = self.ul_contrast(c_anchor, c_positive)  # anchor mlp in here.

        labels = torch.arange(c_anchor.shape[0],
                              dtype=torch.long,
                              device=self.agent.device)
        valid = valid_from_done(samples.done).type(torch.bool)  # [T,B], all steps
        valid = valid[self.ul_delta_T:].reshape(-1)  # validity at the positive's timestep
        labels[~valid] = IGNORE_INDEX

        ul_loss = self.c_e_loss(logits, labels)
        ul_loss.backward()
        if self.ul_clip_grad_norm is None:
            grad_norm = 0.0
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.ul_parameters(),
                                                       self.ul_clip_grad_norm)
        self.ul_optimizer.step()

        correct = torch.argmax(logits.detach(), dim=1) == labels
        accuracy = torch.mean(correct[valid].float())

        return ul_loss, accuracy, grad_norm
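
valid_from_done converts the done flags into a mask that zeroes out every timestep after the first done=True in each trajectory, so that positives crossing an episode boundary are excluded. A sketch consistent with rlpyt's implementation:

    import torch

    def valid_from_done(done):
        # done: [T, B]. Returns 1.0 up through the first done step and
        # 0.0 for every step after it.
        done = done.type(torch.float)
        valid = torch.ones_like(done)
        valid[1:] = 1 - torch.clamp(torch.cumsum(done[:-1], dim=0), max=1)
        return valid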
Example #4
    def random_shift_rl_samples(self, samples):
        if self.random_shift_prob == 0.0:
            return samples
        obs = samples.agent_inputs.observation
        target_obs = samples.target_inputs.observation
        aug_obs = random_shift(
            imgs=obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        aug_target_obs = random_shift(
            imgs=target_obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        aug_samples = samples._replace(
            agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
            target_inputs=samples.target_inputs._replace(
                observation=aug_target_obs),
        )
        return aug_samples
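
random_shift is the DrQ-style augmentation: replicate-pad the image by a few pixels, then crop back to the original size at a random offset, applied with probability prob. A simplified sketch (the real helper likely draws per-image offsets; this version shifts the whole batch together for brevity):

    import random

    import torch.nn.functional as F

    def random_shift(imgs, pad=4, prob=1.0):
        # imgs: [B, C, H, W] float tensor.
        if random.random() > prob:
            return imgs
        _, _, h, w = imgs.shape
        padded = F.pad(imgs, [pad, pad, pad, pad], mode="replicate")
        top = random.randint(0, 2 * pad)
        left = random.randint(0, 2 * pad)
        return padded[:, :, top:top + h, left:left + w]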
Example #5
    def atc_loss(self, samples):
        anchor = (samples.observation if self.delta_T == 0 else
                  samples.observation[:-self.delta_T])
        positive = samples.observation[self.delta_T:]
        t, b, c, h, w = anchor.shape
        anchor = anchor.view(t * b, c, h, w)  # Treat all T,B as separate.
        positive = positive.view(t * b, c, h, w)

        if self.random_shift_prob > 0.:
            anchor = random_shift(
                imgs=anchor,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            positive = random_shift(
                imgs=positive,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )

        anchor, positive = buffer_to((anchor, positive), device=self.device)

        with torch.no_grad():
            c_positive, _ = self.target_encoder(positive)
        c_anchor, conv_output = self.encoder(anchor)

        logits = self.contrast(anchor=c_anchor, positive=c_positive)
        labels = torch.arange(c_anchor.shape[0],
                              dtype=torch.long,
                              device=self.device)
        valid = valid_from_done(samples.done).type(torch.bool)
        valid = valid[self.delta_T:].reshape(-1)  # at location of positive
        labels[~valid] = IGNORE_INDEX
        atc_loss = self.c_e_loss(logits, labels)

        correct = torch.argmax(logits.detach(), dim=1) == labels
        accuracy = torch.mean(correct[valid].float())

        return atc_loss, accuracy, conv_output
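
self.contrast produces one InfoNCE logit per (anchor, positive) pair, so each row's correct class is its own batch index, which is why labels = torch.arange(B). A hypothetical sketch of a CURL/ATC-style bilinear contrast head (class and attribute names are illustrative):

    import torch

    class BilinearContrast(torch.nn.Module):
        def __init__(self, latent_dim):
            super().__init__()
            self.W = torch.nn.Parameter(torch.rand(latent_dim, latent_dim))

        def forward(self, anchor, positive):
            pred = anchor @ self.W  # [B, D]
            logits = pred @ positive.t()  # [B, B]; entry (i, j) scores anchor i vs positive j
            return logits - logits.max(dim=1, keepdim=True)[0].detach()  # stability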

Example #6

    def ats_loss(self, samples):
        anchor = (samples.observation if self.delta_T == 0 else
                  samples.observation[:-self.delta_T])
        positive = samples.observation[self.delta_T:]
        t, b, c, h, w = anchor.shape
        anchor = anchor.view(t * b, c, h, w)  # Treat all T,B as separate.
        positive = positive.view(t * b, c, h, w)

        if self.random_shift_prob > 0.:
            anchor = random_shift(
                imgs=anchor,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            positive = random_shift(
                imgs=positive,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )

        anchor, positive = buffer_to((anchor, positive), device=self.device)

        with torch.no_grad():
            z_positive, _ = self.target_encoder(positive)
        z_anchor, conv_output = self.encoder(anchor)
        q_anchor = self.predictor(z_anchor)

        q = F.normalize(q_anchor, dim=-1, p=2)
        z = F.normalize(z_positive, dim=-1, p=2)
        ats_losses = 2. - 2 * (q * z).sum(dim=-1)  # from BYOL

        valid = valid_from_done(samples.done).type(torch.bool)
        valid = valid[self.delta_T:].reshape(-1)
        valid = valid.to(self.device)
        ats_loss = valid_mean(ats_losses, valid)

        return ats_loss, conv_output
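
The term 2. - 2 * (q * z).sum(dim=-1) is the BYOL regression loss: for unit vectors q and z, ||q - z||^2 = 2 - 2 * <q, z>, so it equals a mean-squared error on the normalized embeddings. valid_mean then averages the per-sample losses over valid entries only; a sketch consistent with rlpyt's helper (the real version also accepts a dim argument):

    import torch

    def valid_mean(tensor, valid=None):
        # Mean over entries where valid is nonzero; plain mean if no mask given.
        if valid is None:
            return tensor.mean()
        valid = valid.type(tensor.dtype)
        return (tensor * valid).sum() / valid.sum()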
Example #7
    def data_aug_loss_samples(self, samples):
        """Perform data augmentation (on CPU)."""
        if self.augmentation is None:
            return samples

        obs = samples.agent_inputs.observation
        target_obs = samples.target_inputs.observation

        if self.augmentation == "random_shift":
            aug_obs = random_shift(
                imgs=obs,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            aug_target_obs = random_shift(
                imgs=target_obs,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
        elif self.augmentation == "subpixel_shift":
            aug_obs = subpixel_shift(
                imgs=obs,
                max_shift=self.max_pixel_shift,
            )
            aug_target_obs = subpixel_shift(
                imgs=target_obs,
                max_shift=self.max_pixel_shift,
            )
        else:
            raise NotImplementedError

        aug_samples = samples._replace(
            agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
            target_inputs=samples.target_inputs._replace(observation=aug_target_obs),
        )

        return aug_samples
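
subpixel_shift is presumably a smoother alternative to random_shift: rather than integer-pixel crops, it translates each image by a fractional offset with bilinear resampling. A hypothetical sketch using grid_sample (the real implementation may differ):

    import torch
    import torch.nn.functional as F

    def subpixel_shift(imgs, max_shift=1.0):
        # imgs: [B, C, H, W] float tensor; max_shift is measured in pixels.
        b, _, h, w = imgs.shape
        shift = (2 * torch.rand(b, 2, dtype=imgs.dtype, device=imgs.device) - 1) * max_shift
        theta = torch.zeros(b, 2, 3, dtype=imgs.dtype, device=imgs.device)
        theta[:, 0, 0] = 1.0
        theta[:, 1, 1] = 1.0
        theta[:, 0, 2] = shift[:, 0] * 2.0 / w  # pixels -> normalized grid units
        theta[:, 1, 2] = shift[:, 1] * 2.0 / h
        grid = F.affine_grid(theta, imgs.shape, align_corners=False)
        return F.grid_sample(imgs, grid, padding_mode="border", align_corners=False)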