def forward(self, states, actions, labels_dict):
    """Run one TVAE training pass and record its losses in ``self.log``.

    Args:
        states: state sequences with time along dim 1; must be exactly one
            step longer than ``actions`` along that dim (the last state has
            no action).
        actions: action sequences with time along dim 1.
        labels_dict: label tensors keyed by label-function name; they are
            concatenated along the last dim to condition encoder and policy.

    Returns:
        ``self.log`` after a reset, holding ``metrics['kl_div_true']`` (KL
        with no free bits, detached), ``losses['kl_div']`` (KL with
        ``1 / z_dim`` free bits) and ``losses['nll']`` (action negative
        log-likelihood accumulated over timesteps).
    """
    self.log.reset()

    assert states.size(1) == actions.size(1) + 1  # last state has no action
    # Switch to time-major so iteration walks over timesteps.
    states, actions = states.transpose(0, 1), actions.transpose(0, 1)
    labels = torch.cat([*labels_dict.values()], dim=-1)

    # Encode the observed trajectory into a posterior over z.
    posterior = self.encode(states[:-1], actions=actions, labels=labels)
    raw_kl = Normal.kl_divergence(posterior, free_bits=0.0).detach()
    self.log.metrics['kl_div_true'] = torch.sum(raw_kl)
    free_bits_kl = Normal.kl_divergence(posterior, free_bits=1 / self.config['z_dim'])
    self.log.losses['kl_div'] = torch.sum(free_bits_kl)

    # Decode: score each true action given the ground-truth state and a
    # single z sample. zip stops after the last action, skipping the
    # final state.
    self.reset_policy(labels=labels, z=posterior.sample())
    for state_t, action_t in zip(states, actions):
        likelihood = self.decode_action(state_t)
        self.log.losses['nll'] -= likelihood.log_prob(action_t)
        if self.is_recurrent:
            self.update_hidden(state_t, action_t)

    return self.log
def forward(self, states, actions, labels_dict, env):
    """Compute the CTVAE (mutual-information variant) losses for one batch.

    Behaviour depends on ``self.stage``:

    * stage 1: only the dynamics-model loss on the ground-truth
      trajectories (``self.compute_dynamics_loss``).
    * stage >= 2: TVAE losses (KL + action NLL), a ``'<name>_mi'`` loss per
      label function from ``self.discrim_labels`` on a rollout generated
      with the dynamics model, and, when ``config['n_collect'] > 0``, a
      dynamics-model loss on a rollout collected from ``env``.
    * any other stage: nothing beyond resetting the log.

    Args:
        states: state sequences with time along dim 1; one step longer than
            ``actions`` along that dim.
        actions: action sequences with time along dim 1.
        labels_dict: label tensors keyed by label-function name. The
            enumeration order is used to index
            ``self.config['label_functions']``, so the two must agree.
        env: environment handed to ``self.generate_rollout_with_env``.

    Returns:
        ``self.log`` with the computed losses and metrics.
    """
    self.log.reset()
    assert actions.size(1)+1 == states.size(1) # final state has no corresponding action
    # Switch to time-major so states[t] / actions[t] select one timestep.
    states = states.transpose(0,1)
    actions = actions.transpose(0,1)
    labels = torch.cat(list(labels_dict.values()), dim=-1)
    # Pretrain dynamics model
    if self.stage == 1:
        self.compute_dynamics_loss(states, actions)
    # Train CTVAE
    elif self.stage >= 2:
        # Encode
        posterior = self.encode(states[:-1], actions=actions, labels=labels)
        # Detached KL without free bits, logged as a metric only.
        kld = Normal.kl_divergence(posterior, free_bits=0.0).detach()
        self.log.metrics['kl_div_true'] = torch.sum(kld)
        # KL with 1/z_dim free bits, used as the training loss.
        kld = Normal.kl_divergence(posterior, free_bits=1/self.config['z_dim'])
        self.log.losses['kl_div'] = torch.sum(kld)
        # Decode
        self.reset_policy(labels=labels, z=posterior.sample())
        for t in range(actions.size(0)):
            action_likelihood = self.decode_action(states[t])
            self.log.losses['nll'] -= action_likelihood.log_prob(actions[t])
            if self.is_recurrent:
                self.update_hidden(states[t], actions[t])
        # Generate rollout w/ dynamics model
        # The policy is reset without z, so the rollout is conditioned on labels only.
        self.reset_policy(labels=labels)
        rollout_states, rollout_actions = self.generate_rollout_with_dynamics(states, horizon=actions.size(0))
        # Maximize mutual information between rollouts and labels
        # The loss is the negative log-likelihood of the true labels under the
        # discriminator's output.
        # NOTE(review): rollout_states is unused; the discriminator receives the
        # ground-truth states[:-1] paired with rollout_actions. Confirm this
        # pairing is intended rather than rollout_states[:-1].
        for lf_idx, lf_name in enumerate(labels_dict):
            lf = self.config['label_functions'][lf_idx]
            lf_labels = labels_dict[lf_name]
            auxiliary = self.discrim_labels(states[:-1], rollout_actions, lf_idx, lf.categorical)
            self.log.losses['{}_mi'.format(lf_name)] = -auxiliary.log_prob(lf_labels)
        # Update dynamics model with n_collect rollouts from environment
        if self.config['n_collect'] > 0:
            # Only the first trajectory's labels are used; presumably the env
            # rollout is a single episode. TODO confirm against generate_rollout_with_env.
            self.reset_policy(labels=labels[:1])
            rollout_states_env, rollout_actions_env = self.generate_rollout_with_env(env, horizon=actions.size(0))
            # Env rollouts are moved onto the labels' device before the loss.
            self.compute_dynamics_loss(rollout_states_env.to(labels.device), rollout_actions_env.to(labels.device))
    return self.log
def forward(self, states, actions, labels_dict):
    """Compute the training losses for one batch of trajectories.

    Stage 1 with a positive consistency weight only pretrains the program
    (label-function) approximators on the ground-truth trajectories.
    Otherwise (stage >= 2, or the consistency loss is disabled) the TVAE is
    trained with KL + action NLL. The optional decoding, consistency,
    augmentation and contrastive losses are enabled by the weights in
    ``self.loss_params``.

    Args:
        states: state sequences with time along dim 1; must be one step
            longer than ``actions`` along that dim.
        actions: action sequences with time along dim 1.
        labels_dict: label tensors keyed by label-function name. The
            enumeration order is used to index
            ``self.config['label_functions']`` and the per-program heads, so
            they must agree. May be empty only when neither the consistency
            nor the decoding loss is enabled.

    Returns:
        ``self.log`` with the computed losses and metrics.
    """
    self.log.reset()

    consistency_weight = self.loss_params['consistency_loss_weight']
    decoding_weight = self.loss_params['decoding_loss_weight']

    # Consistency and decoding losses need labels.
    if consistency_weight > 0 or decoding_weight > 0:
        assert len(labels_dict) > 0

    assert actions.size(1) + 1 == states.size(1)  # final state has no corresponding action
    # Switch to time-major so states[t] / actions[t] select one timestep.
    states = states.transpose(0, 1)
    actions = actions.transpose(0, 1)

    labels = None
    if labels_dict:
        labels = torch.cat(list(labels_dict.values()), dim=-1)

    # Pretrain program approximators, if using consistency loss.
    if self.stage == 1 and consistency_weight > 0:
        for lf_idx, lf_name in enumerate(labels_dict):
            lf = self.config['label_functions'][lf_idx]
            lf_labels = labels_dict[lf_name]
            self.log.losses[lf_name] = compute_label_loss(
                states[:-1], actions, lf_labels,
                self.label_approx_birnn[lf_idx],
                self.label_approx_fc[lf_idx], lf.categorical)

            # Compute label loss with approx
            if self.log_metrics:
                approx_labels = self.label(states[:-1], actions, lf_idx,
                                           lf.categorical)
                assert approx_labels.size() == lf_labels.size()
                self.log.metrics['{}_approx'.format(lf.name)] = torch.sum(
                    approx_labels * lf_labels)

    # Train TVAE with programs.
    elif self.stage >= 2 or not consistency_weight > 0:
        contrastive_weight = self.loss_params['contrastive_loss_weight']
        # Shared keyword arguments for every contrastive-loss call. They are
        # only read when the loss is enabled.
        contrastive_kwargs = {}
        if contrastive_weight > 0:
            contrastive_kwargs = dict(
                temperature=self.loss_params['contrastive_temperature'],
                base_temperature=self.loss_params['contrastive_base_temperature'],
                loss_weight=contrastive_weight)

        # Encode
        posterior = self.encode(states[:-1], actions=actions, labels=labels)
        kld = Normal.kl_divergence(posterior, free_bits=0.0).detach()
        self.log.metrics['kl_div_true'] = torch.sum(kld)
        kld = Normal.kl_divergence(posterior,
                                   free_bits=1 / self.config['z_dim'])
        self.log.losses['kl_div'] = torch.sum(kld)

        # Decode
        self.reset_policy(labels=labels, z=posterior.sample())
        for t in range(actions.size(0)):
            action_likelihood = self.decode_action(states[t])
            self.log.losses['nll'] -= action_likelihood.log_prob(actions[t])
            if self.is_recurrent:
                self.update_hidden(states[t], actions[t])

        # Decoding loss: posterior mean against each program's labels.
        if decoding_weight > 0:
            for lf_idx, lf_name in enumerate(labels_dict):
                lf = self.config['label_functions'][lf_idx]
                self.log.losses["decoded_" + lf_name] = compute_decoding_loss(
                    posterior.mean, labels_dict[lf_name],
                    self.label_decoder_fc_decoding[lf_idx],
                    lf.categorical,
                    loss_weight=decoding_weight)

        # Consistency loss: roll out the policy from a posterior sample and
        # score the rollout with the program approximators.
        if consistency_weight > 0:
            self.reset_policy(
                labels=labels, z=posterior.sample(),
                temperature=self.loss_params['consistency_temperature'])
            rollout_states, rollout_actions = self.generate_rollout(
                states, horizon=actions.size(0))

            for lf_idx, lf_name in enumerate(labels_dict):
                lf = self.config['label_functions'][lf_idx]
                lf_labels = labels_dict[lf_name]
                self.log.losses[lf_name] = compute_label_loss(
                    rollout_states[:-1], rollout_actions, lf_labels,
                    self.label_approx_birnn[lf_idx],
                    self.label_approx_fc[lf_idx], lf.categorical,
                    loss_weight=consistency_weight)

                if self.log_metrics:
                    # Compute label loss with approx
                    approx_labels = self.label(rollout_states[:-1],
                                               rollout_actions, lf_idx,
                                               lf.categorical)
                    assert approx_labels.size() == lf_labels.size()
                    self.log.metrics['{}_approx'.format(lf.name)] = torch.sum(
                        approx_labels * lf_labels)

                    # Compute label loss with the true LF, applied to
                    # batch-major, detached CPU copies of the rollout.
                    rollout_lf_labels = lf.label(
                        rollout_states.transpose(0, 1).detach().cpu(),
                        rollout_actions.transpose(0, 1).detach().cpu(),
                        batch=True)
                    assert rollout_lf_labels.size() == lf_labels.size()
                    self.log.metrics['{}_true'.format(lf.name)] = torch.sum(
                        rollout_lf_labels * lf_labels.cpu())

        # If augmentations are provided, additionally train with those.
        # A missing or None entry is treated as "no augmentations".
        augmentations = []
        if 'augmentations' in self.config:
            augmentations = self.config['augmentations'] or []

        for aug in augmentations:
            augmented_states, augmented_actions = aug.augment(
                states.transpose(0, 1), actions.transpose(0, 1), batch=True)
            # Back to time-major, matching `states` / `actions`.
            augmented_states = augmented_states.transpose(0, 1)
            augmented_actions = augmented_actions.transpose(0, 1)

            aug_posterior = self.encode(augmented_states[:-1],
                                        actions=augmented_actions,
                                        labels=labels)
            kld = Normal.kl_divergence(aug_posterior, free_bits=0.0).detach()
            self.log.metrics['{}_kl_div_true'.format(aug.name)] = torch.sum(kld)
            kld = Normal.kl_divergence(aug_posterior,
                                       free_bits=1 / self.config['z_dim'])
            self.log.losses['{}_kl_div'.format(aug.name)] = torch.sum(kld)

            # Decode the augmented trajectory.
            self.reset_policy(labels=labels, z=aug_posterior.sample())
            for t in range(actions.size(0)):
                action_likelihood = self.decode_action(augmented_states[t])
                self.log.losses['{}_nll'.format(aug.name)] -= (
                    action_likelihood.log_prob(augmented_actions[t]))
                if self.is_recurrent:
                    self.update_hidden(augmented_states[t],
                                       augmented_actions[t])

            for lf_idx, lf_name in enumerate(labels_dict):
                lf_labels = labels_dict[lf_name]
                # Supervised contrastive loss between original and augmented
                # posteriors, using this program's projection head.
                if contrastive_weight > 0:
                    self.log.losses[aug.name + "_contrastive_" + lf_name] = (
                        compute_contrastive_loss(
                            posterior.mean, aug_posterior.mean,
                            self.label_decoder_fc_contrastive[lf_idx],
                            labels=lf_labels, **contrastive_kwargs))
                if decoding_weight > 0:
                    # Look up this program's label function here; reusing an
                    # `lf` left over from an earlier loop would apply the
                    # wrong `categorical` flag.
                    lf = self.config['label_functions'][lf_idx]
                    self.log.losses[aug.name + "_decoded_" + lf_name] = (
                        compute_decoding_loss(
                            aug_posterior.mean, lf_labels,
                            self.label_decoder_fc_decoding[lf_idx],
                            lf.categorical,
                            loss_weight=decoding_weight))

            if not labels_dict and contrastive_weight > 0:
                # Train with unsupervised contrastive loss.
                self.log.losses[aug.name + '_contrastive'] = (
                    compute_contrastive_loss(
                        posterior.mean, aug_posterior.mean,
                        self.label_decoder_fc_contrastive,
                        **contrastive_kwargs))

        # Contrastive loss when no augmentations are configured: the
        # posterior mean is contrasted with itself.
        if not augmentations and contrastive_weight > 0:
            if not labels_dict:
                self.log.losses['contrastive'] = compute_contrastive_loss(
                    posterior.mean, posterior.mean,
                    self.label_decoder_fc_contrastive,
                    **contrastive_kwargs)
            else:
                for lf_idx, lf_name in enumerate(labels_dict):
                    # Pass this program's projection head (not its integer
                    # index), as the augmented branch does.
                    self.log.losses["contrastive_" + lf_name] = (
                        compute_contrastive_loss(
                            posterior.mean, posterior.mean,
                            self.label_decoder_fc_contrastive[lf_idx],
                            labels=labels_dict[lf_name],
                            **contrastive_kwargs))

    return self.log
def forward(self, states, actions, labels_dict):
    """Compute TVAE training losses, optionally for a single-fly policy.

    Two optional config modes restrict training to one fly. A missing flag
    counts as off.

    * ``conditional_single_fly_policy_2_to_2``: the encoder sees only one
      fly's slice of states/actions (``0:2`` when ``policy_for_fly_1_2_to_2``
      is truthy, else ``2:4``). The decoder is additionally conditioned on
      the other fly's action slice.
    * ``conditional_single_fly_policy_4_to_2``: the decoder predicts one
      fly's action slice (``0:2`` when ``policy_for_fly_1_4_to_2`` is truthy,
      else ``2:4``). For the ``2:4`` policy, the input is the next-step
      ``0:2`` state slice concatenated with the current ``2:4`` slice. This
      flag takes precedence over the 2_to_2 flag during decoding only.

    Args:
        states: state sequences with time along dim 1; one step longer than
            ``actions`` (the final state has no corresponding action).
        actions: action sequences with time along dim 1.
        labels_dict: label tensors keyed by label-function name, concatenated
            along the last dim to condition the encoder and policy.

    Returns:
        ``self.log`` holding ``metrics['kl_div_true']``, ``losses['kl_div']``
        and ``losses['nll']``.
    """
    self.log.reset()

    assert actions.size(1) + 1 == states.size(1)  # final state has no corresponding action
    # Switch to time-major so states[t] / actions[t] select one timestep.
    states = states.transpose(0, 1)
    actions = actions.transpose(0, 1)
    labels = torch.cat(list(labels_dict.values()), dim=-1)

    # Mode flags are loop-invariant, so read them once.
    two_to_two = self.config.get("conditional_single_fly_policy_2_to_2")
    four_to_two = self.config.get("conditional_single_fly_policy_4_to_2")
    fly_1_2_to_2 = self.config["policy_for_fly_1_2_to_2"] if two_to_two else False
    fly_1_4_to_2 = self.config["policy_for_fly_1_4_to_2"] if four_to_two else False

    # Encode
    if two_to_two:
        fly = slice(0, 2) if fly_1_2_to_2 else slice(2, 4)
        posterior = self.encode(states[:-1, :, fly],
                                actions=actions[:, :, fly],
                                labels=labels)
    else:
        posterior = self.encode(states[:-1], actions=actions, labels=labels)

    kld = Normal.kl_divergence(posterior, free_bits=0.0).detach()
    self.log.metrics['kl_div_true'] = torch.sum(kld)
    kld = Normal.kl_divergence(posterior, free_bits=1 / self.config['z_dim'])
    self.log.losses['kl_div'] = torch.sum(kld)

    # Decode
    self.reset_policy(labels=labels, z=posterior.sample())
    # Partner action fed at t == 0 by the 2_to_2 policies. It is built from
    # `actions` so it shares their dtype and device; a CPU float32
    # torch.Tensor fails on GPU and mismatches the t > 0 inputs.
    zero_partner_action = actions.new_zeros((actions.size(1), 2))

    for t in range(actions.size(0)):
        if four_to_two:
            if fly_1_4_to_2:
                action_likelihood = self.decode_action(states[t])
                target = actions[t, :, 0:2]
            else:
                # This policy sees the t + 1 slice 0:2 (presumably fly 1 has
                # already moved) plus its own current slice 2:4.
                action_likelihood = self.decode_action(
                    torch.cat((states[t + 1, :, 0:2], states[t, :, 2:4]), dim=1))
                target = actions[t, :, 2:4]
        elif two_to_two:
            if fly_1_2_to_2:
                # Condition on the other fly's previous action (zeros at t == 0).
                partner = zero_partner_action if t == 0 else actions[t - 1, :, 2:4]
                action_likelihood = self.decode_action(states[t], actions=partner)
                target = actions[t, :, 0:2]
            else:
                # Condition on the other fly's current action (zeros at t == 0).
                # NOTE(review): actions[0, :, 0:2] is available at t == 0, so the
                # zero conditioning here may be unintended; confirm it matches
                # the rollout code before changing it.
                partner = zero_partner_action if t == 0 else actions[t, :, 0:2]
                action_likelihood = self.decode_action(
                    torch.cat((states[t + 1, :, 0:2], states[t, :, 2:4]), dim=1),
                    actions=partner)
                target = actions[t, :, 2:4]
        else:
            action_likelihood = self.decode_action(states[t])
            target = actions[t]
        self.log.losses['nll'] -= action_likelihood.log_prob(target)

        if self.is_recurrent:
            self.update_hidden(states[t], actions[t])

    return self.log