def forward(self, states, actions, labels_dict):
        """Compute the TVAE losses (KL + per-step action NLL) for one batch.

        Args:
            states: state sequences with time along dim 1; must have exactly
                one more step than ``actions`` because the last state has no
                action. Presumably (batch, time, state_dim) -- confirm with
                caller.
            actions: action sequences with time along dim 1.
            labels_dict: mapping of label name -> label tensor; the values are
                concatenated along the last dim to form the conditioning labels.

        Returns:
            ``self.log`` (reset at the start of the call) holding
            ``metrics['kl_div_true']`` and the ``'kl_div'`` / ``'nll'`` losses.
        """
        self.log.reset()

        # The terminal state is never paired with an action.
        assert states.size(1) == actions.size(1) + 1

        # Switch to time-major layout so index 0 walks over time steps.
        seq_states = states.transpose(0, 1)
        seq_actions = actions.transpose(0, 1)
        cond_labels = torch.cat([*labels_dict.values()], dim=-1)

        # Encode every state that precedes an action.
        posterior = self.encode(seq_states[:-1], actions=seq_actions,
                                labels=cond_labels)

        # Log the unclipped KL (detached, metric only); optimize the
        # free-bits KL instead.
        true_kl = Normal.kl_divergence(posterior, free_bits=0.0).detach()
        self.log.metrics['kl_div_true'] = true_kl.sum()
        clipped_kl = Normal.kl_divergence(
            posterior, free_bits=1 / self.config['z_dim'])
        self.log.losses['kl_div'] = clipped_kl.sum()

        # Decode: score each observed action under a sampled latent, stepping
        # through the observed states (zip stops at the last action).
        self.reset_policy(labels=cond_labels, z=posterior.sample())
        for state_t, action_t in zip(seq_states, seq_actions):
            step_dist = self.decode_action(state_t)
            self.log.losses['nll'] -= step_dist.log_prob(action_t)
            if self.is_recurrent:
                self.update_hidden(state_t, action_t)

        return self.log
    def forward(self, states, actions, labels_dict, env):
        """Compute the CTVAE / dynamics-model losses for one batch.

        * ``self.stage == 1``: only the dynamics-model loss on the
          demonstration data.
        * ``self.stage >= 2``: TVAE KL + action NLL, then a mutual-information
          loss between rollouts from the learned dynamics model and the
          labels. If ``config['n_collect'] > 0``, a dynamics loss on a
          rollout from the real environment is also computed.
        * Any other stage: nothing beyond resetting the log.

        Args:
            states: state sequences with time along dim 1, one step longer
                than ``actions`` (the final state has no action). Presumably
                (batch, time, state_dim) -- confirm with caller.
            actions: action sequences with time along dim 1.
            labels_dict: ordered mapping of label name -> label tensor. Entry
                ``i`` is paired with ``config['label_functions'][i]``.
            env: environment handed to ``generate_rollout_with_env``.

        Returns:
            ``self.log`` (reset at the start of the call).
        """
        self.log.reset()

        # The terminal state is never paired with an action.
        assert states.size(1) == actions.size(1) + 1

        # Time-major layout from here on.
        seq_states = states.transpose(0, 1)
        seq_actions = actions.transpose(0, 1)
        cond_labels = torch.cat([*labels_dict.values()], dim=-1)
        horizon = seq_actions.size(0)

        if self.stage == 1:
            # Pretrain the dynamics model only.
            self.compute_dynamics_loss(seq_states, seq_actions)
            return self.log
        if self.stage < 2:
            return self.log

        # --- Train CTVAE ---
        posterior = self.encode(seq_states[:-1], actions=seq_actions,
                                labels=cond_labels)

        true_kl = Normal.kl_divergence(posterior, free_bits=0.0).detach()
        self.log.metrics['kl_div_true'] = true_kl.sum()
        clipped_kl = Normal.kl_divergence(
            posterior, free_bits=1 / self.config['z_dim'])
        self.log.losses['kl_div'] = clipped_kl.sum()

        # Per-step action NLL under a sampled latent.
        self.reset_policy(labels=cond_labels, z=posterior.sample())
        for state_t, action_t in zip(seq_states, seq_actions):
            step_dist = self.decode_action(state_t)
            self.log.losses['nll'] -= step_dist.log_prob(action_t)
            if self.is_recurrent:
                self.update_hidden(state_t, action_t)

        # Roll out through the learned dynamics model; reset_policy is called
        # here without a z argument.
        self.reset_policy(labels=cond_labels)
        _, rollout_actions = self.generate_rollout_with_dynamics(
            seq_states, horizon=horizon)

        # Mutual-information loss: the discriminator should recover each
        # label from the rolled-out actions.
        for lf_idx, (lf_name, lf_labels) in enumerate(labels_dict.items()):
            lf = self.config['label_functions'][lf_idx]
            auxiliary = self.discrim_labels(seq_states[:-1], rollout_actions,
                                            lf_idx, lf.categorical)
            self.log.losses['{}_mi'.format(lf_name)] = \
                -auxiliary.log_prob(lf_labels)

        # Refresh the dynamics model with a rollout from the real environment,
        # conditioned on the first label row only.
        if self.config['n_collect'] > 0:
            self.reset_policy(labels=cond_labels[:1])
            env_states, env_actions = self.generate_rollout_with_env(
                env, horizon=horizon)
            target_device = cond_labels.device
            self.compute_dynamics_loss(env_states.to(target_device),
                                       env_actions.to(target_device))

        return self.log
Esempio n. 3
0
    def forward(self, states, actions, labels_dict):
        """Compute all training losses and metrics for one batch.

        Behaviour depends on ``self.stage`` and ``self.loss_params``:

        * Stage 1 with ``consistency_loss_weight > 0``: pretrain the
          label-function approximators (``label_approx_birnn`` /
          ``label_approx_fc``) on the demonstration data.
        * Stage >= 2, or any stage when ``consistency_loss_weight <= 0``:
          train the TVAE (KL + per-step action NLL) together with whichever
          auxiliary losses have a positive weight. These are:
          - a decoding loss on the posterior mean;
          - a consistency loss on policy rollouts;
          - KL/NLL/contrastive/decoding losses for each configured
            augmentation;
          - a contrastive loss when no augmentations are configured.

        Args:
            states: state sequences with time along dim 1, one step longer
                than ``actions`` (the final state has no action). Presumably
                (batch, time, state_dim) -- confirm with caller.
            actions: action sequences with time along dim 1.
            labels_dict: ordered mapping of label-function name -> label
                tensor. Entry ``i`` is paired with
                ``self.config['label_functions'][i]``. It must be non-empty
                when the consistency or decoding loss is enabled; it may be
                empty otherwise, in which case no labels are used.

        Returns:
            ``self.log`` (reset at the start of the call) holding this
            batch's losses and metrics.
        """
        self.log.reset()
        consistency_weight = self.loss_params['consistency_loss_weight']

        # Consistency and decoding loss need labels.
        if (consistency_weight > 0
                or self.loss_params['decoding_loss_weight'] > 0):
            assert len(labels_dict) > 0

        # Final state has no corresponding action.
        assert actions.size(1) + 1 == states.size(1)
        # Time-major layout from here on: index 0 walks over time steps.
        states = states.transpose(0, 1)
        actions = actions.transpose(0, 1)

        labels = None
        if labels_dict:
            labels = torch.cat(list(labels_dict.values()), dim=-1)

        # Pretrain program approximators, if using consistency loss.
        if self.stage == 1 and consistency_weight > 0:
            for lf_idx, lf_name in enumerate(labels_dict):
                lf = self.config['label_functions'][lf_idx]
                lf_labels = labels_dict[lf_name]
                self.log.losses[lf_name] = compute_label_loss(
                    states[:-1], actions, lf_labels,
                    self.label_approx_birnn[lf_idx],
                    self.label_approx_fc[lf_idx], lf.categorical)

                # Metric: sum of elementwise product of approximated labels
                # and the given labels.
                if self.log_metrics:
                    approx_labels = self.label(states[:-1], actions, lf_idx,
                                               lf.categorical)
                    assert approx_labels.size() == lf_labels.size()
                    self.log.metrics['{}_approx'.format(lf.name)] = torch.sum(
                        approx_labels * lf_labels)

        # Train TVAE with programs.
        elif self.stage >= 2 or not consistency_weight > 0:
            # Encode
            posterior = self.encode(states[:-1],
                                    actions=actions,
                                    labels=labels)

            # The unclipped KL is only logged (detached); the free-bits KL
            # is the optimized loss.
            kld = Normal.kl_divergence(posterior, free_bits=0.0).detach()
            self.log.metrics['kl_div_true'] = torch.sum(kld)

            kld = Normal.kl_divergence(posterior,
                                       free_bits=1 / self.config['z_dim'])
            self.log.losses['kl_div'] = torch.sum(kld)

            # Decode: per-step action NLL under a sampled latent, stepping
            # through the observed states.
            self.reset_policy(labels=labels, z=posterior.sample())

            for t in range(actions.size(0)):
                action_likelihood = self.decode_action(states[t])
                self.log.losses['nll'] -= action_likelihood.log_prob(
                    actions[t])

                if self.is_recurrent:
                    self.update_hidden(states[t], actions[t])

            # Decoding loss on the posterior mean, one head per label function.
            if self.loss_params['decoding_loss_weight'] > 0:
                for lf_idx, lf_name in enumerate(labels_dict):
                    lf = self.config['label_functions'][lf_idx]
                    lf_labels = labels_dict[lf_name]
                    self.log.losses["decoded_" +
                                    lf_name] = compute_decoding_loss(
                                        posterior.mean,
                                        lf_labels,
                                        self.label_decoder_fc_decoding[lf_idx],
                                        lf.categorical,
                                        loss_weight=self.
                                        loss_params['decoding_loss_weight'])

            # Consistency loss: roll out the policy conditioned on a posterior
            # sample and score the rollout with the label approximators.
            if consistency_weight > 0:
                self.reset_policy(
                    labels=labels,
                    z=posterior.sample(),
                    temperature=self.loss_params['consistency_temperature'])

                rollout_states, rollout_actions = self.generate_rollout(
                    states, horizon=actions.size(0))

                for lf_idx, lf_name in enumerate(labels_dict):
                    lf = self.config['label_functions'][lf_idx]
                    lf_labels = labels_dict[lf_name]
                    self.log.losses[lf_name] = compute_label_loss(
                        rollout_states[:-1],
                        rollout_actions,
                        lf_labels,
                        self.label_approx_birnn[lf_idx],
                        self.label_approx_fc[lf_idx],
                        lf.categorical,
                        loss_weight=consistency_weight)

                    if self.log_metrics:
                        # Metric: approximated labels of the rollout vs.
                        # the target labels.
                        approx_labels = self.label(rollout_states[:-1],
                                                   rollout_actions, lf_idx,
                                                   lf.categorical)
                        assert approx_labels.size() == lf_labels.size()
                        self.log.metrics['{}_approx'.format(
                            lf.name)] = torch.sum(approx_labels * lf_labels)

                        # Metric: the true label function applied to the
                        # rollout (detached, on CPU, original layout) vs.
                        # the target labels.
                        rollout_lf_labels = lf.label(
                            rollout_states.transpose(0, 1).detach().cpu(),
                            rollout_actions.transpose(0, 1).detach().cpu(),
                            batch=True)
                        assert rollout_lf_labels.size() == lf_labels.size()
                        self.log.metrics['{}_true'.format(
                            lf.name)] = torch.sum(rollout_lf_labels *
                                                  lf_labels.cpu())

            # If augmentations are provided, additionally train with those.
            if 'augmentations' in self.config:
                for aug in self.config['augmentations']:
                    # aug.augment is called with the original (un-transposed)
                    # layout; its outputs are switched back to time-major.
                    augmented_states, augmented_actions = aug.augment(
                        states.transpose(0, 1),
                        actions.transpose(0, 1),
                        batch=True)

                    augmented_states = augmented_states.transpose(0, 1)
                    augmented_actions = augmented_actions.transpose(0, 1)
                    aug_posterior = self.encode(augmented_states[:-1],
                                                actions=augmented_actions,
                                                labels=labels)

                    kld = Normal.kl_divergence(aug_posterior,
                                               free_bits=0.0).detach()
                    self.log.metrics['{}_kl_div_true'.format(
                        aug.name)] = torch.sum(kld)

                    kld = Normal.kl_divergence(aug_posterior,
                                               free_bits=1 /
                                               self.config['z_dim'])
                    self.log.losses['{}_kl_div'.format(
                        aug.name)] = torch.sum(kld)

                    # Decode the augmented trajectory.
                    self.reset_policy(labels=labels, z=aug_posterior.sample())

                    for t in range(actions.size(0)):
                        action_likelihood = self.decode_action(
                            augmented_states[t])
                        self.log.losses['{}_nll'.format(
                            aug.name)] -= action_likelihood.log_prob(
                                augmented_actions[t])

                        if self.is_recurrent:
                            self.update_hidden(augmented_states[t],
                                               augmented_actions[t])

                    for lf_idx, lf_name in enumerate(labels_dict):
                        lf_labels = labels_dict[lf_name]

                        # Supervised contrastive loss between the original
                        # and augmented posterior means.
                        if self.loss_params['contrastive_loss_weight'] > 0:
                            self.log.losses[
                                aug.name + "_contrastive_" +
                                lf_name] = compute_contrastive_loss(
                                    posterior.mean,
                                    aug_posterior.mean,
                                    self.label_decoder_fc_contrastive[lf_idx],
                                    labels=lf_labels,
                                    temperature=self.
                                    loss_params['contrastive_temperature'],
                                    base_temperature=self.loss_params[
                                        'contrastive_base_temperature'],
                                    loss_weight=self.
                                    loss_params['contrastive_loss_weight'])

                        if self.loss_params['decoding_loss_weight'] > 0:
                            # Look up the label function for *this* index;
                            # never rely on `lf` bound by an earlier loop.
                            lf = self.config['label_functions'][lf_idx]
                            self.log.losses[
                                aug.name + "_decoded_" +
                                lf_name] = compute_decoding_loss(
                                    aug_posterior.mean,
                                    lf_labels,
                                    self.label_decoder_fc_decoding[lf_idx],
                                    lf.categorical,
                                    loss_weight=self.
                                    loss_params['decoding_loss_weight'])

                    if not labels_dict and self.loss_params[
                            'contrastive_loss_weight'] > 0:
                        # Train with unsupervised contrastive loss.
                        self.log.losses[
                            aug.name +
                            '_contrastive'] = compute_contrastive_loss(
                                posterior.mean,
                                aug_posterior.mean,
                                self.label_decoder_fc_contrastive,
                                temperature=self.
                                loss_params['contrastive_temperature'],
                                base_temperature=self.
                                loss_params['contrastive_base_temperature'],
                                loss_weight=self.
                                loss_params['contrastive_loss_weight'])

            # Add contrastive loss for cases where there are no augmentations.
            if (not self.config.get('augmentations')
                    and self.loss_params['contrastive_loss_weight'] > 0):

                if not labels_dict:
                    # Unsupervised: a single contrastive head.
                    self.log.losses['contrastive'] = compute_contrastive_loss(
                        posterior.mean,
                        posterior.mean,
                        self.label_decoder_fc_contrastive,
                        temperature=self.
                        loss_params['contrastive_temperature'],
                        base_temperature=self.
                        loss_params['contrastive_base_temperature'],
                        loss_weight=self.loss_params['contrastive_loss_weight']
                    )
                else:
                    # Supervised: one contrastive head per label function,
                    # matching the augmented case above.
                    for lf_idx, lf_name in enumerate(labels_dict):
                        lf_labels = labels_dict[lf_name]

                        self.log.losses[
                            "contrastive_" +
                            lf_name] = compute_contrastive_loss(
                                posterior.mean,
                                posterior.mean,
                                self.label_decoder_fc_contrastive[lf_idx],
                                labels=lf_labels,
                                temperature=self.
                                loss_params['contrastive_temperature'],
                                base_temperature=self.
                                loss_params['contrastive_base_temperature'],
                                loss_weight=self.
                                loss_params['contrastive_loss_weight'])

        return self.log
    def forward(self, states, actions, labels_dict):
        """Compute TVAE losses, optionally modelling one fly of a pair.

        Config modes (a missing or falsy key disables a mode):

        * ``conditional_single_fly_policy_2_to_2``:
          - Encoding uses only the chosen fly's columns (0:2 when
            ``policy_for_fly_1_2_to_2`` is truthy, otherwise 2:4).
          - Decoding scores only that fly's action, and the decoder is also
            given the other fly's action.
        * ``conditional_single_fly_policy_4_to_2``: decoding scores only one
          fly's action (``policy_for_fly_1_4_to_2`` selects fly 1). This mode
          is checked first in the decode loop.

        In every other case the full state/action vectors are encoded and
        decoded.

        Args:
            states: state sequences with time along dim 1, one step longer
                than ``actions`` (the final state has no action). Presumably
                (batch, time, 4) for two 2-D flies -- confirm with caller.
            actions: action sequences with time along dim 1.
            labels_dict: mapping of label name -> label tensor; values are
                concatenated along the last dim.

        Returns:
            ``self.log`` (reset at the start of the call) holding
            ``metrics['kl_div_true']`` and the ``'kl_div'`` / ``'nll'``
            losses.
        """
        self.log.reset()

        # Final state has no corresponding action.
        assert actions.size(1) + 1 == states.size(1)
        states = states.transpose(0, 1)
        actions = actions.transpose(0, 1)
        labels = torch.cat(list(labels_dict.values()), dim=-1)

        use_4_to_2 = self.config.get("conditional_single_fly_policy_4_to_2")
        use_2_to_2 = self.config.get("conditional_single_fly_policy_2_to_2")

        # Encode
        if use_2_to_2:
            cols = (slice(0, 2) if self.config["policy_for_fly_1_2_to_2"]
                    else slice(2, 4))
            posterior = self.encode(states[:-1, :, cols],
                                    actions=actions[:, :, cols],
                                    labels=labels)
        else:
            posterior = self.encode(states[:-1],
                                    actions=actions,
                                    labels=labels)

        kld = Normal.kl_divergence(posterior, free_bits=0.0).detach()
        self.log.metrics['kl_div_true'] = torch.sum(kld)

        kld = Normal.kl_divergence(posterior,
                                   free_bits=1 / self.config['z_dim'])
        self.log.losses['kl_div'] = torch.sum(kld)

        # Decode
        self.reset_policy(labels=labels, z=posterior.sample())

        for t in range(actions.size(0)):
            if use_4_to_2:
                if self.config["policy_for_fly_1_4_to_2"]:
                    action_likelihood = self.decode_action(states[t])
                    self.log.losses['nll'] -= action_likelihood.log_prob(
                        actions[t, :, 0:2])
                else:
                    # Fly 2 is decoded from fly 1's next position (t + 1)
                    # together with its own current position.
                    action_likelihood = self.decode_action(
                        torch.cat((states[t + 1, :, 0:2], states[t, :, 2:4]),
                                  dim=1))
                    self.log.losses['nll'] -= action_likelihood.log_prob(
                        actions[t, :, 2:4])
            elif use_2_to_2:
                if self.config["policy_for_fly_1_2_to_2"]:
                    # Fly 1 is conditioned on fly 2's previous action; none
                    # exists at t == 0, so zeros are used. new_zeros keeps the
                    # device and dtype of `actions` (a plain torch.Tensor(...)
                    # would always be a CPU tensor of the default dtype).
                    if t == 0:
                        other_action = actions.new_zeros((actions.size(1), 2))
                    else:
                        other_action = actions[t - 1, :, 2:4]
                    action_likelihood = self.decode_action(
                        states[t], actions=other_action)
                    self.log.losses['nll'] -= action_likelihood.log_prob(
                        actions[t, :, 0:2])
                else:
                    # NOTE(review): for t > 0 fly 2 is conditioned on fly 1's
                    # *current* action, but at t == 0 on zeros even though
                    # actions[0, :, 0:2] exists. Kept as-is -- confirm this
                    # matches how rollouts are generated.
                    if t == 0:
                        other_action = actions.new_zeros((actions.size(1), 2))
                    else:
                        other_action = actions[t, :, 0:2]
                    action_likelihood = self.decode_action(
                        torch.cat((states[t + 1, :, 0:2], states[t, :, 2:4]),
                                  dim=1),
                        actions=other_action)
                    self.log.losses['nll'] -= action_likelihood.log_prob(
                        actions[t, :, 2:4])
            else:
                action_likelihood = self.decode_action(states[t])
                self.log.losses['nll'] -= action_likelihood.log_prob(
                    actions[t])

            # The recurrent state is updated with the full state/action in
            # every mode.
            if self.is_recurrent:
                self.update_hidden(states[t], actions[t])

        return self.log