def inverse_loss(self, samples):
    """Inverse-dynamics loss: predict the action from the observations at the
    ends of a length-delta_T window; returns the loss, an entropy regularizer,
    accuracy, perplexity, and the anchor conv features."""
    observation = samples.observation[0]  # [T,B,C,H,W]->[B,C,H,W]
    last_observation = samples.observation[-1]
    if self.random_shift_prob > 0.:
        observation = random_shift(
            imgs=observation,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        last_observation = random_shift(
            imgs=last_observation,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
    action = samples.action  # [T,B,A]
    # if self.onehot_actions:
    #     action = to_onehot(action, self._act_dim, dtype=torch.float)
    observation, last_observation, action = buffer_to(
        (observation, last_observation, action), device=self.device)
    _, conv_obs = self.encoder(observation)
    _, conv_last = self.encoder(last_observation)
    valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
    # All of a window's timesteps are invalid if its last_observation is:
    valid = valid[-1].repeat(self.delta_T, 1).transpose(1, 0)  # [B,T-1]
    if self.onehot_actions:
        logits = self.inverse_model(conv_obs, conv_last)  # [B,T-1,A]
        labels = action[:-1].transpose(1, 0)  # [B,T-1], not the last action
        labels[~valid] = IGNORE_INDEX
        b, t, a = logits.shape
        logits = logits.view(b * t, a)
        labels = labels.reshape(b * t)
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0]  # Stability.
        inv_loss = self.c_e_loss(logits, labels)
        valid = valid.reshape(b * t).to(self.device)
        dist_info = DistInfo(prob=F.softmax(logits, dim=1))
        entropy = self.distribution.mean_entropy(
            dist_info=dist_info,
            valid=valid,
        )
        entropy_loss = -self.entropy_loss_coeff * entropy
        correct = torch.argmax(logits.detach(), dim=1) == labels
        accuracy = torch.mean(correct[valid].float())
    else:
        raise NotImplementedError  # Continuous actions not supported here.
    perplexity = self.distribution.mean_perplexity(dist_info,
        valid.to(self.device))
    return inv_loss, entropy_loss, accuracy, perplexity, conv_obs
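# `valid_from_done`, used by the losses above and below, is assumed to turn a
# done-flag sequence [T,B] into a validity mask that zeroes out every timestep
# after the first done within each trajectory segment. A minimal sketch
# consistent with that usage (the repo's own helper may differ in detail):
def valid_from_done_sketch(done):
    """done: [T,B] bool/float tensor -> valid: [T,B] float mask."""
    done = done.type(torch.float)
    valid = torch.ones_like(done)
    # Steps remain valid up to and including the step where done first fires;
    # everything after that point in the segment is masked out.
    valid[1:] = 1 - torch.clamp(torch.cumsum(done[:-1], dim=0), max=1)
    return valid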
def ul_optimize_one_step(self, samples=None):
    """One gradient step of the UL (contrastive) objective, optionally on a
    freshly drawn replay batch."""
    self.ul_optimizer.zero_grad()
    if samples is None:
        if self.ul_pri_alpha > 0:
            samples = self.replay_buffer.sample_batch(self.ul_batch_size,
                mode="UL")
        else:
            samples = self.replay_buffer.sample_batch(self.ul_batch_size)
        # This is why we need ul_delta_T == n_step_return, usually == 1:
        anchor = samples.agent_inputs.observation
        positive = samples.target_inputs.observation
        if self.ul_random_shift_prob > 0.:
            anchor = random_shift(
                imgs=anchor,
                pad=self.ul_random_shift_pad,
                prob=self.ul_random_shift_prob,
            )
            positive = random_shift(
                imgs=positive,
                pad=self.ul_random_shift_pad,
                prob=self.ul_random_shift_prob,
            )
        anchor, positive = buffer_to((anchor, positive),
            device=self.agent.device)
    else:  # Assume samples were already augmented in the RL loss.
        anchor = samples.agent_inputs.observation
        positive = samples.target_inputs.observation
    with torch.no_grad():
        c_positive, _pos_conv = self.ul_target_encoder(positive)
    c_anchor, _anc_conv = self.ul_encoder(anchor)
    logits = self.ul_contrast(c_anchor, c_positive)  # anchor mlp in here.
    labels = torch.arange(c_anchor.shape[0], dtype=torch.long,
        device=self.agent.device)
    invalid = samples.done.type(torch.bool)  # [B]; if done, the following state is invalid.
    labels[invalid] = IGNORE_INDEX
    ul_loss = self.c_e_loss(logits, labels)
    ul_loss.backward()
    if self.ul_clip_grad_norm is None:
        grad_norm = 0.0
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.ul_parameters(), self.ul_clip_grad_norm)
    self.ul_optimizer.step()
    correct = torch.argmax(logits.detach(), dim=1) == labels
    accuracy = torch.mean(correct[~invalid].float())
    return ul_loss, accuracy, grad_norm
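# The `ul_contrast` / `contrast` modules above are assumed to score every
# anchor against every positive in the batch, ATC-style: a residual MLP on the
# anchor ("anchor mlp in here"), a bilinear predictor W, and [B,B] logits.
# A minimal sketch under those assumptions; sizes and the residual form are
# illustrative, not the repo's exact module:
class ContrastSketch(torch.nn.Module):
    def __init__(self, latent_size, anchor_hidden_size=None):
        super().__init__()
        hidden = anchor_hidden_size or latent_size
        self.anchor_mlp = torch.nn.Sequential(
            torch.nn.Linear(latent_size, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, latent_size),
        )
        self.W = torch.nn.Linear(latent_size, latent_size, bias=False)

    def forward(self, anchor, positive):
        anchor = anchor + self.anchor_mlp(anchor)  # Residual anchor MLP.
        pred = self.W(anchor)                      # [B,Z] bilinear prediction.
        logits = torch.matmul(pred, positive.T)    # [B,B]: row i vs all positives.
        logits = logits - logits.max(dim=1, keepdim=True)[0]  # Stability.
        return logits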
def ul_optimize_one_step(self):
    """One gradient step of the UL objective on a [T,B] replay batch, pairing
    each observation with the one ul_delta_T steps later."""
    self.ul_optimizer.zero_grad()
    samples = self.replay_buffer.ul_sample_batch(self.ul_batch_B)
    anchor = samples.observation[:-self.ul_delta_T]
    positive = samples.observation[self.ul_delta_T:]
    t, b, c, h, w = anchor.shape
    anchor = anchor.reshape(t * b, c, h, w)  # Treat all T,B as separate.
    positive = positive.reshape(t * b, c, h, w)
    if self.ul_random_shift_prob > 0.:
        anchor = random_shift(
            imgs=anchor,
            pad=self.ul_random_shift_pad,
            prob=self.ul_random_shift_prob,
        )
        positive = random_shift(
            imgs=positive,
            pad=self.ul_random_shift_pad,
            prob=self.ul_random_shift_prob,
        )
    anchor, positive = buffer_to((anchor, positive),
        device=self.agent.device)
    with torch.no_grad():
        c_positive, _pos_conv = self.ul_target_encoder(positive)
    c_anchor, _anc_conv = self.ul_encoder(anchor)
    logits = self.ul_contrast(c_anchor, c_positive)  # anchor mlp in here.
    labels = torch.arange(c_anchor.shape[0], dtype=torch.long,
        device=self.agent.device)
    valid = valid_from_done(samples.done).type(torch.bool)  # Use all steps,
    valid = valid[self.ul_delta_T:].reshape(-1)  # at positions of positive.
    labels[~valid] = IGNORE_INDEX
    ul_loss = self.c_e_loss(logits, labels)
    ul_loss.backward()
    if self.ul_clip_grad_norm is None:
        grad_norm = 0.0
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.ul_parameters(), self.ul_clip_grad_norm)
    self.ul_optimizer.step()
    correct = torch.argmax(logits.detach(), dim=1) == labels
    accuracy = torch.mean(correct[valid].float())
    return ul_loss, accuracy, grad_norm
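# The `ul_target_encoder` / `target_encoder` used above is assumed to track
# the online encoder by an exponential moving average (momentum encoder), as
# in ATC/MoCo/BYOL. A sketch; the update rate `tau` is illustrative:
@torch.no_grad()
def update_target_encoder_sketch(encoder, target_encoder, tau=0.01):
    """Polyak-average online parameters into the target network."""
    for p, p_t in zip(encoder.parameters(), target_encoder.parameters()):
        p_t.mul_(1 - tau).add_(tau * p)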
def random_shift_rl_samples(self, samples):
    if self.random_shift_prob == 0.0:
        return samples
    obs = samples.agent_inputs.observation
    target_obs = samples.target_inputs.observation
    aug_obs = random_shift(
        imgs=obs,
        pad=self.random_shift_pad,
        prob=self.random_shift_prob,
    )
    aug_target_obs = random_shift(
        imgs=target_obs,
        pad=self.random_shift_pad,
        prob=self.random_shift_prob,
    )
    aug_samples = samples._replace(
        agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
        target_inputs=samples.target_inputs._replace(
            observation=aug_target_obs),
    )
    return aug_samples
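# `random_shift`, used throughout, is the standard pad-and-random-crop image
# augmentation (DrQ-style): replication-pad by `pad` pixels, then crop back to
# the original size at a random offset, applied per image with probability
# `prob`. A minimal batched sketch under those assumptions (the real helper is
# likely vectorized; this loops for clarity and returns float images):
def random_shift_sketch(imgs, pad=4, prob=1.):
    """imgs: [B,C,H,W] uint8/float tensor -> shifted float tensor."""
    b, c, h, w = imgs.shape
    padded = F.pad(imgs.float(), [pad] * 4, mode="replicate")
    shifted = torch.empty_like(imgs, dtype=padded.dtype)
    for i in range(b):
        if torch.rand(()) < prob:
            x0 = int(torch.randint(0, 2 * pad + 1, ()))
            y0 = int(torch.randint(0, 2 * pad + 1, ()))
        else:
            x0 = y0 = pad  # Center crop == identity.
        shifted[i] = padded[i, :, y0:y0 + h, x0:x0 + w]
    return shifted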
def atc_loss(self, samples):
    """Augmented Temporal Contrast: InfoNCE loss between each augmented anchor
    observation and its augmented positive delta_T steps later."""
    anchor = (samples.observation if self.delta_T == 0 else
        samples.observation[:-self.delta_T])
    positive = samples.observation[self.delta_T:]
    t, b, c, h, w = anchor.shape
    anchor = anchor.view(t * b, c, h, w)  # Treat all T,B as separate.
    positive = positive.view(t * b, c, h, w)
    if self.random_shift_prob > 0.:
        anchor = random_shift(
            imgs=anchor,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        positive = random_shift(
            imgs=positive,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
    anchor, positive = buffer_to((anchor, positive), device=self.device)
    with torch.no_grad():
        c_positive, _ = self.target_encoder(positive)
    c_anchor, conv_output = self.encoder(anchor)
    logits = self.contrast(anchor=c_anchor, positive=c_positive)
    labels = torch.arange(c_anchor.shape[0], dtype=torch.long,
        device=self.device)
    valid = valid_from_done(samples.done).type(torch.bool)
    valid = valid[self.delta_T:].reshape(-1)  # at location of positive
    labels[~valid] = IGNORE_INDEX
    atc_loss = self.c_e_loss(logits, labels)
    correct = torch.argmax(logits.detach(), dim=1) == labels
    accuracy = torch.mean(correct[valid].float())
    return atc_loss, accuracy, conv_output
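# `c_e_loss` and IGNORE_INDEX above are assumed to be a standard PyTorch
# cross-entropy with masked-out labels; distinct names used here to avoid
# shadowing the repo's own definitions:
IGNORE_INDEX_SKETCH = -100  # Matches CrossEntropyLoss's default ignore_index.
c_e_loss_sketch = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX_SKETCH)
# Labels set to the ignore index contribute nothing to the loss, which is how
# the invalid (post-done) anchor/positive pairs are dropped above.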
def ats_loss(self, samples):
    """Augmented Temporal Similarity: BYOL-style regression of the predicted
    anchor latent onto the target latent delta_T steps later."""
    anchor = (samples.observation if self.delta_T == 0 else
        samples.observation[:-self.delta_T])
    positive = samples.observation[self.delta_T:]
    t, b, c, h, w = anchor.shape
    anchor = anchor.view(t * b, c, h, w)  # Treat all T,B as separate.
    positive = positive.view(t * b, c, h, w)
    if self.random_shift_prob > 0.:
        anchor = random_shift(
            imgs=anchor,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        positive = random_shift(
            imgs=positive,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
    anchor, positive = buffer_to((anchor, positive), device=self.device)
    with torch.no_grad():
        z_positive, _ = self.target_encoder(positive)
    z_anchor, conv_output = self.encoder(anchor)
    q_anchor = self.predictor(z_anchor)
    q = F.normalize(q_anchor, dim=-1, p=2)
    z = F.normalize(z_positive, dim=-1, p=2)
    ats_losses = 2. - 2 * (q * z).sum(dim=-1)  # from BYOL
    valid = valid_from_done(samples.done).type(torch.bool)
    valid = valid[self.delta_T:].reshape(-1)  # at location of positive
    valid = valid.to(self.device)
    ats_loss = valid_mean(ats_losses, valid)
    return ats_loss, conv_output
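# Illustrative sanity check: with unit-normalized q and z,
# ||q - z||^2 = 2 - 2 <q, z>, so `ats_losses` above is exactly the BYOL
# normalized-MSE written via the dot product.
def byol_loss_equivalence_check():
    q = F.normalize(torch.randn(8, 64), dim=-1)
    z = F.normalize(torch.randn(8, 64), dim=-1)
    lhs = (q - z).pow(2).sum(dim=-1)
    rhs = 2. - 2. * (q * z).sum(dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)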
def data_aug_loss_samples(self, samples):
    """Perform data augmentation (on CPU)."""
    if self.augmentation is None:
        return samples
    obs = samples.agent_inputs.observation
    target_obs = samples.target_inputs.observation
    if self.augmentation == "random_shift":
        aug_obs = random_shift(
            imgs=obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        aug_target_obs = random_shift(
            imgs=target_obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
    elif self.augmentation == "subpixel_shift":
        aug_obs = subpixel_shift(
            imgs=obs,
            max_shift=self.max_pixel_shift,
        )
        aug_target_obs = subpixel_shift(
            imgs=target_obs,
            max_shift=self.max_pixel_shift,
        )
    else:
        raise NotImplementedError
    aug_samples = samples._replace(
        agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
        target_inputs=samples.target_inputs._replace(
            observation=aug_target_obs),
    )
    return aug_samples
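# `subpixel_shift` above is assumed to translate each image by a random
# fractional-pixel offset in [-max_shift, max_shift] via bilinear resampling.
# A hypothetical sketch using grid_sample; the repo's helper may differ:
def subpixel_shift_sketch(imgs, max_shift=1.0):
    """imgs: [B,C,H,W] float tensor -> bilinearly shifted images."""
    b, c, h, w = imgs.shape
    # Random (dx, dy) per image, in pixels, converted to normalized coords.
    shifts = (torch.rand(b, 2, device=imgs.device) * 2 - 1) * max_shift
    theta = torch.zeros(b, 2, 3, device=imgs.device)
    theta[:, 0, 0] = 1.
    theta[:, 1, 1] = 1.
    theta[:, 0, 2] = shifts[:, 0] * 2 / w  # x translation, normalized.
    theta[:, 1, 2] = shifts[:, 1] * 2 / h  # y translation, normalized.
    grid = F.affine_grid(theta, imgs.shape, align_corners=False)
    return F.grid_sample(imgs, grid, padding_mode="border",
        align_corners=False)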