Example 1
    def _do_training(self):
        batch = self.get_batch(training=True)
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        """
        Policy operations.
        """
        policy_actions, preactivations = self.policy(
            obs,
            goals,
            return_preactivations=True,
        )
        pre_activation_policy_loss = (
            (preactivations**2).sum(dim=1).mean())
        q_output = self.qf(obs, policy_actions, goals)
        raw_policy_loss = -q_output.mean()
        policy_loss = (raw_policy_loss +
                       self.pre_activation_weight * pre_activation_policy_loss)
        """
        Critic operations.
        """
        next_actions = self.target_policy(next_obs, goals)
        target_q_values = self.target_qf(
            next_obs,
            next_actions,
            goals,
        )
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
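        # Assuming per-step rewards lie in [-1, 0], the discounted return is
        # bounded in [-1 / (1 - discount), 0], so the bootstrapped target is
        # clamped to that range.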
        q_target = torch.clamp(q_target, -1. / (1 - self.discount), 0)
        q_pred = self.qf(obs, actions, goals)
        bellman_errors = (q_pred - q_target)**2
        qf_loss = self.qf_criterion(q_pred, q_target)
        """
        Update Networks
        """

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self._update_target_networks()

        if self.eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss'] = np.mean(
                ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Pre-activation Policy Loss'] = np.mean(
                ptu.get_numpy(pre_activation_policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors',
                    ptu.get_numpy(bellman_errors),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Example 2
    def next_state(self, state, action):
        if self._is_structured_qf:
            states = ptu.np_to_var(np.expand_dims(state, 0))
            actions = ptu.np_to_var(np.expand_dims(action, 0))
            discount = ptu.np_to_var(self.discount + np.zeros((1, 1)))
            return ptu.get_numpy(
                self.qf(states, actions, None, discount, True).squeeze(0))

        # Expand the single (state, action) pair once so every optimizer branch
        # below can use it.
        obs_dim = state.shape[0]
        states = self.expand_np_to_var(state)
        actions = self.expand_np_to_var(action)
        if self.state_optimizer == 'adam':
            discount = ptu.np_to_var(self.discount * np.ones(
                (self.sample_size, 1)))
            next_states_np = np.zeros((self.sample_size, obs_dim))
            next_states = ptu.np_to_var(next_states_np, requires_grad=True)
            optimizer = optim.Adam([next_states], self.learning_rate)
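            # Treat the candidate next states as free variables and run Adam to
            # maximize Q(s, a, s') by minimizing its negation.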

            for _ in range(self.num_optimization_steps):
                losses = -self.qf(
                    states,
                    actions,
                    next_states,
                    discount,
                )
                loss = losses.mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses_np = ptu.get_numpy(losses)
            best_action_i = np.argmin(losses_np)
            return ptu.get_numpy(next_states[best_action_i, :])
        elif self.state_optimizer == 'lbfgs':
            next_states = []
            for i in range(len(states)):
                state = states[i:i + 1, :]
                action = actions[i:i + 1, :]
                loss_f = self.create_loss(state, action, return_gradient=True)
                results = optimize.fmin_l_bfgs_b(
                    loss_f,
                    np.zeros((1, obs_dim)),
                    maxiter=self.num_optimization_steps,
                )
                next_state = results[0]
                next_states.append(next_state)
            next_states = np.array(next_states)
            return next_states
        elif self.state_optimizer == 'fmin':
            next_states = []
            for i in range(len(states)):
                state = states[i:i + 1, :]
                action = actions[i:i + 1, :]
                loss_f = self.create_loss(state, action)
                results = optimize.fmin(
                    loss_f,
                    np.zeros((1, obs_dim)),
                    maxiter=self.num_optimization_steps,
                )
                next_state = results[0]
                next_states.append(next_state)
            next_states = np.array(next_states)
            return next_states
        else:
            raise Exception("Unknown state optimizer mode: {}".format(
                self.state_optimizer))
Example 3
    def _decode(self, latents):
        self.vae.eval()
        reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
        decoded = ptu.get_numpy(reconstructions)
        return decoded
Example 4
    def get_action(self, obs):
        obs = np.expand_dims(obs, axis=0)
        obs = Variable(ptu.from_numpy(obs).float(), requires_grad=False)
        action, _, _ = self.__call__(obs, None)
        action = action.squeeze(0)
        return ptu.get_numpy(action), {}
Example 5
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        # Make sure policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(obs,
                                     goals,
                                     num_steps_left,
                                     reparameterize=self.train_policy_with_reparameterization,
                                     return_log_prob=True)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        if not self.dense_rewards and not self.dense_log_pi:
            log_pi = log_pi * terminals

        """
        QF Loss
        """
        target_v_values = self.target_vf(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left-1,
        )
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_v_values
        q_target = q_target.detach()
        bellman_errors_1 = (q1_pred - q_target) ** 2
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()

        if self.use_automatic_entropy_tuning:
            """
            Alpha Loss
            """
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1

        """
        VF Loss
        """
        q1_new_actions = self.qf1(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_new_actions = self.qf2(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q_new_actions = torch.min(q1_new_actions, q2_new_actions)
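        # Soft state-value target: V(s) <- min(Q1, Q2) - alpha * log pi(a|s) for
        # the freshly sampled action.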
        v_target = q_new_actions - alpha * log_pi
        v_pred = self.vf(
            observations=obs,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        v_target = v_target.detach()
        bellman_errors = (v_pred - v_target) ** 2
        vf_loss = bellman_errors.mean()

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        """
        Policy Loss
        """
        # paper says to do + but apparently that's a typo. Do Q - V.
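        # The reparameterized branch uses the pathwise gradient of
        # alpha * log pi - Q; the other branch uses a likelihood-ratio
        # (REINFORCE-style) estimator with Q - V as the baselined target.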
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        if self._n_train_steps_total % self.policy_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.vf, self.target_vf, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = ptu.get_numpy(alpha)[0]
                self.eval_statistics['Alpha Loss'] = ptu.get_numpy(alpha_loss)[0]
Example 6
def get_best_action(qfs, observation, t):
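    # Evaluate the time-indexed Q-function qfs[t] on a single observation and
    # return the index of the highest-valued discrete action.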
    obs = ptu.np_to_var(observation[None], requires_grad=False).float()
    q_values = qfs[t](obs).squeeze(0)
    q_values_np = ptu.get_numpy(q_values)
    return q_values_np.argmax()
Example 7
    def get_param_values_np(self):
        state_dict = self.state_dict()
        np_dict = OrderedDict()
        for key, tensor in state_dict.items():
            np_dict[key] = ptu.get_numpy(tensor)
        return np_dict
Example 8
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
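        # Target policy smoothing (as in TD3): perturb the target action with
        # clipped Gaussian noise before computing the target Q-values.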
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q2_values = self.target_qf2(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
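        # Clipped double-Q: take the minimum of the two target critics to curb
        # overestimation.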
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )

        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
        q_output = self.qf1(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )

        policy_loss = -q_output.mean()

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman1 Errors',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman2 Errors',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Example 9
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Classifier and Policy
        """
        class_actions = self.policy(obs)
        class_prob = self.classifier(obs, actions)
        prob_target = 1 + rewards[:, -1]

        neg_log_prob = - torch.log(self.classifier(obs, class_actions))
        policy_loss = (neg_log_prob).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, policy_mean, policy_log_std, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        # Entropy temperature for the soft backup; when automatic tuning is
        # enabled it comes from log_alpha, otherwise it is fixed at 1.
        alpha = self.log_alpha.exp() if self.use_automatic_entropy_tuning else 1
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(new_log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
        self._n_train_steps_total += 1
Example 10
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        q_pred = self.qf(obs, actions)
        v_pred = self.vf(obs)
        # Make sure policy accounts for squashing functions like tanh correctly!
        (
            new_actions, policy_mean, policy_log_std, log_pi, entropy,
            policy_stds, log_pi_mean
        ) = self.policy(
            obs,
            return_log_prob=True,
            return_entropy=(
                self.expected_log_pi_estim_strategy == EXACT
            ),
            return_log_prob_of_mean=(
                self.expected_log_pi_estim_strategy == MEAN_ACTION
            ),
        )
        expected_log_pi = -entropy if entropy is not None else None

        """
        QF Loss
        """
        target_v_values = self.target_vf(next_obs)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_v_values
        qf_loss = self.qf_criterion(q_pred, q_target.detach())

        """
        VF Loss
        """
        q_new_actions = self.qf(obs, new_actions)
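        # E[Q(s, a)] under the policy is estimated one of three ways here:
        # analytically by passing the action mean and stds to the Q-function,
        # by evaluating Q at the mean action, or by reusing the sampled-action value.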
        if self.expected_qf_estim_strategy == EXACT:
            expected_q = self.qf(obs, policy_mean, action_stds=policy_stds)
        elif self.expected_qf_estim_strategy == MEAN_ACTION:
            expected_q = self.qf(obs, policy_mean)
        elif self.expected_qf_estim_strategy == SAMPLE:
            expected_q = q_new_actions
        else:
            raise TypeError("Invalid E[Q(s, a)] estimation strategy: {}".format(
                self.expected_qf_estim_strategy
            ))
        if self.expected_log_pi_estim_strategy == EXACT:
            expected_log_pi_target = expected_log_pi
        elif self.expected_log_pi_estim_strategy == MEAN_ACTION:
            expected_log_pi_target = log_pi_mean
        elif self.expected_log_pi_estim_strategy == SAMPLE:
            expected_log_pi_target = log_pi
        else:
            raise TypeError(
                "Invalid E[log pi(a|s)] estimation strategy: {}".format(
                    self.expected_log_pi_estim_strategy
                )
            )
        v_target = expected_q - expected_log_pi_target
        vf_loss = self.vf_criterion(v_pred, v_target.detach())

        """
        Policy Loss
        """
        # paper says to do + but Tuomas said that's a typo. Do Q - V.
        log_policy_target = q_new_actions - v_pred
        policy_loss = (
            log_pi * (log_pi - log_policy_target).detach()
        ).mean()
        policy_reg_loss = self.policy_reg_weight * (
            (policy_mean ** 2).mean()
            + (policy_log_std ** 2).mean()
        )
        policy_loss = policy_loss + policy_reg_loss

        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self._update_target_network()

        """
        Save some statistics for eval
        """
        self.eval_statistics = OrderedDict()
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
Example 11
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        obs = obs.reshape((obs.shape[0], ) + (3, 48, 48))
        if not self.discrete:
            actions = batch['actions']
        else:
            actions = batch['actions'].argmax(dim=-1)
        next_obs = batch['next_observations']
        next_obs = next_obs.reshape((next_obs.shape[0], ) + (3, 48, 48))
        """
        Policy and Alpha Loss
        """
        if self.discrete:
            new_obs_actions, pi_probs, log_pi, entropies = self.policy(
                obs, None, return_log_prob=True)
            new_next_actions, pi_next_probs, new_log_pi, next_entropies = self.policy(
                next_obs, None, return_log_prob=True)
            q_vector = self.qf1.q_vector(obs)
            q2_vector = self.qf2.q_vector(obs)
            q_next_vector = self.qf1.q_vector(next_obs)
            q2_next_vector = self.qf2.q_vector(next_obs)
        else:
            new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
                obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if not self.discrete:
            if self.num_qs == 1:
                q_new_actions = self.qf1(obs, None, new_obs_actions)
            else:
                q_new_actions = torch.min(
                    self.qf1(obs, None, new_obs_actions),
                    self.qf2(obs, None, new_obs_actions),
                )

        if self.discrete:
            target_q_values = torch.min(q_vector, q2_vector)
            policy_loss = -((target_q_values * pi_probs).sum(dim=-1) +
                            alpha * entropies).mean()
        else:
            policy_loss = (alpha * log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            """Start with BC"""
            policy_log_prob = self.policy.log_prob(obs, None, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
            # print ('Policy Loss: ', policy_loss.item())
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, None, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, None, actions)

        # Make sure policy accounts for squashing functions like tanh correctly!
        if not self.discrete:
            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )
            new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
                obs,
                None,
                reparameterize=True,
                return_log_prob=True,
            )
        else:
            new_curr_actions, pi_curr_probs, new_curr_log_pi, new_curr_entropies = self.policy(
                obs, None, return_log_prob=True)

        if not self.max_q_backup:
            if not self.discrete:
                if self.num_qs == 1:
                    target_q_values = self.target_qf1(next_obs, None,
                                                      new_next_actions)
                else:
                    target_q_values = torch.min(
                        self.target_qf1(next_obs, None, new_next_actions),
                        self.target_qf2(next_obs, None, new_next_actions),
                    )
            else:
                target_q_values = torch.min(
                    (self.target_qf1.q_vector(next_obs) *
                     pi_next_probs).sum(dim=-1),
                    (self.target_qf2.q_vector(next_obs) *
                     pi_next_probs).sum(dim=-1))
                target_q_values = target_q_values.unsqueeze(-1)

            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        if self.max_q_backup:
            """when using max q backup"""
            if not self.discrete:
                next_actions_temp, _ = self._get_policy_actions(
                    next_obs, num_actions=10, network=self.policy)
                target_qf1_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf1,
                ).max(1)[0].view(-1, 1)
                target_qf2_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf2,
                ).max(1)[0].view(-1, 1)
                target_q_values = torch.min(
                    target_qf1_values, target_qf2_values
                )  # + torch.max(target_qf1_values, target_qf2_values) * 0.25
            else:
                target_qf1_values = self.target_qf1.q_vector(next_obs).max(dim=-1)[0]
                target_qf2_values = self.target_qf2.q_vector(next_obs).max(dim=-1)[0]
                target_q_values = torch.min(target_qf1_values,
                                            target_qf2_values).unsqueeze(-1)

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values

        # Only detach the target when using target networks (i.e., when not
        # training directly on the Bellman residual).
        if self._use_target_nets:
            q_target = q_target.detach()

        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        if self.hinge_bellman:
            qf1_loss = self.softplus(q_target - q1_pred).mean()
            qf2_loss = self.softplus(q_target - q2_pred).mean()

        ## add min_q
        if self.with_min_q:
            if not self.discrete:
                random_actions_tensor = torch.FloatTensor(
                    q2_pred.shape[0] * self.num_random,
                    actions.shape[-1]).uniform_(-1, 1).cuda()
                curr_actions_tensor, curr_log_pis = self._get_policy_actions(
                    obs, num_actions=self.num_random, network=self.policy)
                new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
                    next_obs, num_actions=self.num_random, network=self.policy)
                q1_rand = self._get_tensor_values(obs,
                                                  random_actions_tensor,
                                                  network=self.qf1)
                q2_rand = self._get_tensor_values(obs,
                                                  random_actions_tensor,
                                                  network=self.qf2)
                q1_curr_actions = self._get_tensor_values(obs,
                                                          curr_actions_tensor,
                                                          network=self.qf1)
                q2_curr_actions = self._get_tensor_values(obs,
                                                          curr_actions_tensor,
                                                          network=self.qf2)
                q1_next_actions = self._get_tensor_values(
                    obs, new_curr_actions_tensor, network=self.qf1)
                q2_next_actions = self._get_tensor_values(
                    obs, new_curr_actions_tensor, network=self.qf2)

                # q1_next_states_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf1)
                # q2_next_states_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf2)

                cat_q1 = torch.cat([
                    q1_rand,
                    q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions
                ], 1)
                cat_q2 = torch.cat([
                    q2_rand,
                    q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions
                ], 1)
                std_q1 = torch.std(cat_q1, dim=1)
                std_q2 = torch.std(cat_q2, dim=1)

                if self.min_q_version == 3:
                    # importance-sampled version
                    random_density = np.log(0.5**curr_actions_tensor.shape[-1])
                    cat_q1 = torch.cat([
                        q1_rand - random_density,
                        q1_next_actions - new_log_pis.detach(),
                        q1_curr_actions - curr_log_pis.detach()
                    ], 1)
                    cat_q2 = torch.cat([
                        q2_rand - random_density,
                        q2_next_actions - new_log_pis.detach(),
                        q2_curr_actions - curr_log_pis.detach()
                    ], 1)

                if self.min_q_version == 0:
                    min_qf1_loss = cat_q1.mean() * self.min_q_weight
                    min_qf2_loss = cat_q2.mean() * self.min_q_weight
                elif self.min_q_version == 1:
                    """Expectation under softmax distribution"""
                    softmax_dist_1 = self.softmax(
                        cat_q1 / self.temp).detach() * self.temp
                    softmax_dist_2 = self.softmax(
                        cat_q2 / self.temp).detach() * self.temp
                    min_qf1_loss = (cat_q1 *
                                    softmax_dist_1).mean() * self.min_q_weight
                    min_qf2_loss = (cat_q2 *
                                    softmax_dist_2).mean() * self.min_q_weight
                elif self.min_q_version == 2 or self.min_q_version == 3:
                    """log sum exp for the min"""
                    min_qf1_loss = torch.logsumexp(
                        cat_q1 / self.temp,
                        dim=1,
                    ).mean() * self.min_q_weight * self.temp
                    min_qf2_loss = torch.logsumexp(
                        cat_q2 / self.temp,
                        dim=1,
                    ).mean() * self.min_q_weight * self.temp

                if self.data_subtract:
                    """Subtract the log likelihood of data"""
                    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
                    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight
            else:
                q1_policy = (q_vector * pi_probs).sum(dim=-1)
                q2_policy = (q2_vector * pi_probs).sum(dim=-1)
                q1_next_actions = (q_next_vector * pi_next_probs).sum(dim=-1)
                q2_next_actions = (q2_next_vector * pi_next_probs).sum(dim=-1)

                if self.min_q_version == 0:
                    min_qf1_loss = (q1_policy.mean() + q1_next_actions.mean() +
                                    q_vector.mean() + q_next_vector.mean()
                                    ).mean() * self.min_q_weight
                    min_qf2_loss = (q2_policy.mean() + q1_next_actions.mean() +
                                    q2_vector.mean() + q2_next_vector.mean()
                                    ).mean() * self.min_q_weight
                elif self.min_q_version == 1:
                    min_qf1_loss = (q_vector.mean() + q_next_vector.mean()
                                    ).mean() * self.min_q_weight
                    min_qf2_loss = (q2_vector.mean() + q2_next_vector.mean()
                                    ).mean() * self.min_q_weight
                else:
                    softmax_dist_q1 = self.softmax(
                        q_vector / self.temp).detach() * self.temp
                    softmax_dist_q2 = self.softmax(
                        q2_vector / self.temp).detach() * self.temp
                    min_qf1_loss = (q_vector *
                                    softmax_dist_q1).mean() * self.min_q_weight
                    min_qf2_loss = (q2_vector *
                                    softmax_dist_q2).mean() * self.min_q_weight

                if self.data_subtract:
                    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
                    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

                std_q1 = torch.std(q_vector, dim=-1)
                std_q2 = torch.std(q2_vector, dim=-1)
                q1_on_policy = q1_policy.mean()
                q2_on_policy = q2_policy.mean()
                q1_random = q_vector.mean()
                q2_random = q2_vector.mean()
                q1_next_actions_mean = q1_next_actions.mean()
                q2_next_actions_mean = q2_next_actions.mean()

            if self.use_projected_grad:
                min_qf1_grad = torch.autograd.grad(
                    min_qf1_loss,
                    inputs=[p for p in self.qf1.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                min_qf2_grad = torch.autograd.grad(
                    min_qf2_loss,
                    inputs=[p for p in self.qf2.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                qf1_loss_grad = torch.autograd.grad(
                    qf1_loss,
                    inputs=[p for p in self.qf1.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                qf2_loss_grad = torch.autograd.grad(
                    qf2_loss,
                    inputs=[p for p in self.qf2.parameters()],
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)

                # this is for the offline setting
                # qf1_total_grad = self.compute_mt_grad(qf1_loss_grad, min_qf1_grad)
                # qf2_total_grad = self.compute_mt_grad(qf2_loss_grad, min_qf2_grad)
                qf1_total_grad = self.compute_new_grad(min_qf1_grad,
                                                       qf1_loss_grad)
                qf2_total_grad = self.compute_new_grad(min_qf2_grad,
                                                       qf2_loss_grad)
            else:
                if self.with_lagrange:
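                    # Lagrange variant: alpha_prime is a learned multiplier that
                    # keeps the conservative penalty close to target_action_gap.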
                    alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                              min=0,
                                              max=2000000.0)
                    orig_min_qf1_loss = min_qf1_loss
                    orig_min_qf2_loss = min_qf2_loss
                    min_qf1_loss = alpha_prime * (min_qf1_loss -
                                                  self.target_action_gap)
                    min_qf2_loss = alpha_prime * (min_qf2_loss -
                                                  self.target_action_gap)
                    self.alpha_prime_optimizer.zero_grad()
                    alpha_prime_loss = -0.5 * (min_qf1_loss + min_qf2_loss)
                    alpha_prime_loss.backward(retain_graph=True)
                    self.alpha_prime_optimizer.step()
                qf1_loss = qf1_loss + min_qf1_loss
                qf2_loss = qf2_loss + min_qf2_loss
        """
        Update networks
        """
        # Update the Q-functions
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        if self.with_min_q and self.use_projected_grad:
            for (p, proj_grad) in zip(self.qf1.parameters(), qf1_total_grad):
                p.grad.data = proj_grad
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            if self.with_min_q and self.use_projected_grad:
                for (p, proj_grad) in zip(self.qf2.parameters(),
                                          qf2_total_grad):
                    p.grad.data = proj_grad
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._use_target_nets:
            if self._n_train_steps_total % self.target_update_period == 0:
                ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                        self.soft_target_tau)
                if self.num_qs > 1:
                    ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                            self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            if not self.discrete:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                target_q_values = torch.min(q_vector, q2_vector)
                policy_loss = -((target_q_values * pi_probs).sum(dim=-1) +
                                alpha * entropies).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(
                    ptu.get_numpy(qf2_loss))

            if self.with_min_q and not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 in-distribution values',
                        ptu.get_numpy(q1_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 in-distribution values',
                        ptu.get_numpy(q2_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 random values',
                        ptu.get_numpy(q1_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 random values',
                        ptu.get_numpy(q2_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 next_actions values',
                        ptu.get_numpy(q1_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 next_actions values',
                        ptu.get_numpy(q2_next_actions),
                    ))
            elif self.with_min_q and self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics['QF1 on policy average'] = np.mean(
                    ptu.get_numpy(q1_on_policy))
                self.eval_statistics['QF2 on policy average'] = np.mean(
                    ptu.get_numpy(q2_on_policy))
                self.eval_statistics['QF1 random average'] = np.mean(
                    ptu.get_numpy(q1_random))
                self.eval_statistics['QF2 random average'] = np.mean(
                    ptu.get_numpy(q2_random))
                self.eval_statistics[
                    'QF1 next_actions_mean average'] = np.mean(
                        ptu.get_numpy(q1_next_actions_mean))
                self.eval_statistics[
                    'QF2 next_actions_mean average'] = np.mean(
                        ptu.get_numpy(q2_next_actions_mean))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            if self.num_qs > 1:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Q2 Predictions',
                        ptu.get_numpy(q2_pred),
                    ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            if not self.discrete:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy mu',
                        ptu.get_numpy(policy_mean),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy log std',
                        ptu.get_numpy(policy_log_std),
                    ))
            else:
                self.eval_statistics['Policy entropy'] = ptu.get_numpy(
                    entropies).mean()

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                self.eval_statistics['Alpha Prime'] = alpha_prime.item()
                self.eval_statistics[
                    'Alpha Prime Loss'] = alpha_prime_loss.item()
                self.eval_statistics['Min Q1 Loss'] = orig_min_qf1_loss.item()
                self.eval_statistics['Min Q2 Loss'] = orig_min_qf2_loss.item()

        self._n_train_steps_total += 1
Example 12
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']

        q_pred = self.qf(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        v_pred = self.vf(
            observations=obs,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        # Check policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(
            observations=obs,
            goals=goals,
            num_steps_left=num_steps_left,
            return_log_prob=True,
        )
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        """
        QF Loss
        """
        target_v_values = self.target_vf(
            next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_v_values
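        # Optionally add a bonus, scaled by the remaining horizon, to the
        # targets of terminal transitions.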
        if self.give_terminal_reward:
            terminal_rewards = self.terminal_bonus * num_steps_left
            q_target = q_target + terminals * terminal_rewards
        qf_loss = self.qf_criterion(q_pred, q_target.detach())
        """
        VF Loss
        """
        q_new_actions = self.qf(
            observations=obs,
            actions=new_actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        v_target = q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        """
        Policy Loss
        """
        # paper says to do + but apparently that's a typo. Do Q - V.
        log_policy_target = q_new_actions - v_pred
        policy_loss = (log_pi * (log_pi - log_policy_target).detach()).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss
        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self._update_target_network()

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
Example 13
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, torch.autograd.Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
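# Usage sketch: np_ify passes plain Python/NumPy values through unchanged
# (e.g. np_ify(3.5) == 3.5) and converts torch Variables to NumPy arrays via
# ptu.get_numpy.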
Example 14
    def _reconstruct_img(self, flat_img):
        zs = self.vae.encode(ptu.np_to_var(flat_img[None]))[0]
        imgs = ptu.get_numpy(self.vae.decode(zs))
        imgs = imgs.reshape(1, self.input_channels, self.imsize, self.imsize)
        return imgs[0]
Example 15
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q2_values = self.target_qf2(
            observations=next_obs,
            actions=noisy_next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        q2_pred = self.qf2(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )

        bellman_errors_1 = (q1_pred - q_target)**2
        bellman_errors_2 = (q2_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
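        # Penalize pre-tanh actions whose magnitude exceeds 20 so the tanh
        # squashing does not saturate (which would kill policy gradients).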
        policy_saturation_cost = F.relu(torch.abs(pre_tanh_value) - 20.0)
        q_output = self.qf1(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )

        policy_loss = -q_output.mean()
        if self.use_policy_saturation_cost:
            policy_loss = policy_loss + policy_saturation_cost.mean()

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman1 Errors',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman2 Errors',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Saturation Cost',
                    ptu.get_numpy(policy_saturation_cost),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action Pre-tanh',
                    ptu.get_numpy(pre_tanh_value),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Q Output',
                    ptu.get_numpy(q_output),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Noisy Next Actions',
                    ptu.get_numpy(noisy_next_actions),
                    exclude_abs=False,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Replay Buffer Actions',
                    ptu.get_numpy(actions),
                    exclude_abs=False,
                ))

            be_np = ptu.get_numpy(bellman_errors_1)
            num_steps_left_np = ptu.get_numpy(num_steps_left)

            ### tau == 0 ###
            idx_0 = np.argwhere(num_steps_left_np == 0)
            be_0 = 0
            if len(idx_0) > 0:
                be_0 = be_np[idx_0[:, 0]]
            self.eval_statistics['QF1 Loss tau=0'] = np.mean(be_0)

            ### tau == 1 ###
            idx_1 = np.argwhere(num_steps_left_np == 1)
            be_1 = 0
            if len(idx_1) > 0:
                be_1 = be_np[idx_1[:, 0]]
            self.eval_statistics['QF1 Loss tau=1'] = np.mean(be_1)

            ### tau in [2, 5) ###
            idx_2_to_5 = np.argwhere(
                np.logical_and(num_steps_left_np >= 2, num_steps_left_np < 5))
            be_2_to_5 = 0
            if len(idx_2_to_5) > 0:
                be_2_to_5 = be_np[idx_2_to_5[:, 0]]
            self.eval_statistics['QF1 Loss tau=2_to_5'] = np.mean(be_2_to_5)

            ### tau in [5, 10) ###
            idx_5_to_10 = np.argwhere(
                np.logical_and(num_steps_left_np >= 5, num_steps_left_np < 10))
            be_5_to_10 = 0
            if len(idx_5_to_10) > 0:
                be_5_to_10 = be_np[idx_5_to_10[:, 0]]
            self.eval_statistics['QF1 Loss tau=5_to_10'] = np.mean(be_5_to_10)

            ### tau in [10, max_tau] ###
            idx_10_to_end = np.argwhere(
                np.logical_and(num_steps_left_np >= 10,
                               num_steps_left_np < self.max_tau + 1))
            be_10_to_end = 0
            if len(idx_10_to_end) > 0:
                be_10_to_end = be_np[idx_10_to_end[:, 0]]
            self.eval_statistics['QF1 Loss tau=10_to_end'] = np.mean(
                be_10_to_end)
Esempio n. 16
0
    def _train_given_data(
        self,
        rewards,
        terminals,
        obs,
        actions,
        next_obs,
        logger_prefix="",
    ):
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics[logger_prefix + 'QF1 Loss'] = np.mean(
                ptu.get_numpy(qf1_loss))
            self.eval_statistics[logger_prefix + 'QF2 Loss'] = np.mean(
                ptu.get_numpy(qf2_loss))
            self.eval_statistics[logger_prefix + 'Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    logger_prefix + 'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Esempio n. 17
0
    def _do_training(self):
        losses = []
        if self.collection_mode == 'batch':
            """
            Batch mode we'll assume you want to do epoch-style training
            """
            all_obs = self.replay_buffer._observations[:self.replay_buffer.
                                                       _top]
            all_actions = self.replay_buffer._actions[:self.replay_buffer._top]
            all_next_obs = self.replay_buffer._next_obs[:self.replay_buffer.
                                                        _top]

            num_batches = len(all_obs) // self.batch_size
            idx = np.asarray(range(len(all_obs)))
            np.random.shuffle(idx)
            for bn in range(num_batches):
                idxs = idx[bn * self.batch_size:(bn + 1) * self.batch_size]
                obs = all_obs[idxs]
                actions = all_actions[idxs]
                next_obs = all_next_obs[idxs]

                obs = ptu.np_to_var(obs, requires_grad=False)
                actions = ptu.np_to_var(actions, requires_grad=False)
                next_obs = ptu.np_to_var(next_obs, requires_grad=False)

                ob_deltas_pred = self.model(obs, actions)
                ob_deltas = next_obs - obs
                if self.delta_normalizer:
                    normalized_errors = (
                        self.delta_normalizer.normalize(ob_deltas_pred) -
                        self.delta_normalizer.normalize(ob_deltas))
                    squared_errors = normalized_errors**2
                else:
                    squared_errors = (ob_deltas_pred - ob_deltas)**2
                loss = squared_errors.mean()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(ptu.get_numpy(loss))
        else:
            batch = self.get_batch()
            obs = batch['observations']
            actions = batch['actions']
            next_obs = batch['next_observations']
            ob_deltas_pred = self.model(obs, actions)
            ob_deltas = next_obs - obs
            if self.delta_normalizer:
                normalized_errors = (
                    self.delta_normalizer.normalize(ob_deltas_pred) -
                    self.delta_normalizer.normalize(ob_deltas))
                squared_errors = normalized_errors**2
            else:
                squared_errors = (ob_deltas_pred - ob_deltas)**2
            loss = squared_errors.mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(ptu.get_numpy(loss))

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Model Loss',
                    losses,
                    always_show_all_stats=True,
                    exclude_max_min=True,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Obs Deltas',
                    ptu.get_numpy(ob_deltas),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Predicted Obs Deltas',
                    ptu.get_numpy(ob_deltas_pred),
                ))
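The two model-fitting branches above both reduce to one squared-error loss on observation deltas. A minimal sketch of that loss, assuming model(obs, actions) predicts next_obs - obs and that delta_normalizer, if provided, exposes a normalize(tensor) method as used above; delta_model_loss is a hypothetical helper name.

def delta_model_loss(model, obs, actions, next_obs, delta_normalizer=None):
    # Predict the change in observation rather than the next observation itself.
    pred_deltas = model(obs, actions)
    true_deltas = next_obs - obs
    if delta_normalizer is not None:
        # Compare in normalized delta space so all dimensions contribute evenly.
        errors = (delta_normalizer.normalize(pred_deltas)
                  - delta_normalizer.normalize(true_deltas))
    else:
        errors = pred_deltas - true_deltas
    return (errors ** 2).mean()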
Esempio n. 18
0
    def train_policy(self, subtraj_batch, start_indices):
        policy_dict = self.get_policy_output_dict(subtraj_batch)

        policy_loss = policy_dict['Loss']
        qf_loss = 0
        if self.save_memory_gradients:
            dloss_dlast_writes = subtraj_batch['dloss_dwrites'][:, -1, :]
            new_last_writes = policy_dict['New Writes'][:, -1, :]
            qf_loss += (dloss_dlast_writes * new_last_writes).sum()
        if self.use_action_policy_params_for_entire_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward(
                retain_variables=self.action_policy_optimize_bellman
            )
            if self.action_policy_optimize_bellman:
                bellman_errors = policy_dict['Bellman Errors']
                qf_loss += (
                    self.bellman_error_loss_weight * bellman_errors.mean()
                )
                qf_loss.backward()
            self.policy_optimizer.step()
        else:
            self.action_policy_optimizer.zero_grad()
            self.write_policy_optimizer.zero_grad()
            policy_loss.backward(retain_variables=True)

            if self.write_policy_optimizes == 'qf':
                self.write_policy_optimizer.step()
                if self.action_policy_optimize_bellman:
                    bellman_errors = policy_dict['Bellman Errors']
                    qf_loss += (
                        self.bellman_error_loss_weight * bellman_errors.mean()
                    )
                    qf_loss.backward()
                self.action_policy_optimizer.step()
            else:
                if self.write_policy_optimizes == 'bellman':
                    self.write_policy_optimizer.zero_grad()
                if self.action_policy_optimize_bellman:
                    bellman_errors = policy_dict['Bellman Errors']
                    qf_loss += (
                        self.bellman_error_loss_weight * bellman_errors.mean()
                    )
                    qf_loss.backward()
                    self.action_policy_optimizer.step()
                else:
                    self.action_policy_optimizer.step()
                    bellman_errors = policy_dict['Bellman Errors']
                    qf_loss += (
                        self.bellman_error_loss_weight * bellman_errors.mean()
                    )
                    qf_loss.backward()
                self.write_policy_optimizer.step()

        if self.save_new_memories_back_to_replay_buffer:
            self.replay_buffer.train_replay_buffer.update_write_subtrajectories(
                ptu.get_numpy(policy_dict['New Writes']), start_indices
            )
        if self.save_memory_gradients:
            new_dloss_dmemory = ptu.get_numpy(self.saved_grads['dl_dmemory'])
            self.replay_buffer.train_replay_buffer.update_dloss_dmemories_subtrajectories(
                new_dloss_dmemory, start_indices
            )
Esempio n. 19
0
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, ptu.TorchVariable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
Esempio n. 20
0
    def compute_loss(
            self, batch, update_eval_statistics=False,
    ) -> SACLosses:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha*log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Save some statistics for eval
        """
        if update_eval_statistics:

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        return SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
        )
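compute_loss only builds the losses; how a trainer consumes the returned SACLosses is not shown here. A hedged sketch, assuming one optimizer per network (apply_sac_losses and its optimizer arguments are hypothetical names):

import torch

def apply_sac_losses(losses, policy_opt, qf1_opt, qf2_opt, alpha_opt=None):
    # The entropy-coefficient update only applies when automatic tuning is on,
    # i.e. when alpha_loss is a tensor rather than the constant 0.
    if alpha_opt is not None and torch.is_tensor(losses.alpha_loss):
        alpha_opt.zero_grad()
        losses.alpha_loss.backward()
        alpha_opt.step()
    policy_opt.zero_grad()
    losses.policy_loss.backward()
    policy_opt.step()
    qf1_opt.zero_grad()
    losses.qf1_loss.backward()
    qf1_opt.step()
    qf2_opt.zero_grad()
    losses.qf2_loss.backward()
    qf2_opt.step()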
Esempio n. 21
0
 def encode(imgs, x_0):
     latent = vae.encode(ptu.from_numpy(imgs), x_0, distrib=False)
     return ptu.get_numpy(latent)
Esempio n. 22
0
 def _get_policy_action(self, observation, t):
     obs = ptu.np_to_var(observation[None], requires_grad=False).float()
     return ptu.get_numpy(self.policies[t](obs).squeeze(0))
Esempio n. 23
0
 def _encode(self, imgs):
     #MAKE FLOAT
     self.vae.eval()
     latents = self.vae.encode(ptu.from_numpy(imgs), cont=True)
     latents = np.array(ptu.get_numpy(latents))
     return latents
Esempio n. 24
0
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Compute loss
        """
        qf_losses = []
        policy_losses = []
        for t in range(self.max_horizon):
            if t == self.max_horizon - 1:
                q_target = self.reward_scale * rewards
            else:
                target_q_values = self.qfs[t + 1](
                    next_obs,
                    self.policies[t + 1](next_obs),
                )
                q_target = (self.reward_scale * rewards +
                            (1. - terminals) * self.discount * target_q_values)
            q_pred = self.qfs[t](obs, actions)
            qf_loss = self.qf_criterion(q_pred, q_target.detach())

            policy_loss = -self.qfs[t](obs, self.policies[t](obs)).mean()

            self.qf_optimizers[t].zero_grad()
            qf_loss.backward()
            self.qf_optimizers[t].step()

            self.policy_optimizers[t].zero_grad()
            policy_loss.backward()
            self.policy_optimizers[t].step()
            """
            Save some statistics for eval
            """
            if self.need_to_update_eval_statistics:
                qf_loss_np = np.mean(ptu.get_numpy(qf_loss))
                self.eval_statistics['QF {} Loss'.format(t)] = qf_loss_np
                policy_loss_np = ptu.get_numpy(policy_loss)
                self.eval_statistics['Policy {} Loss'.format(t)] = (
                    policy_loss_np)
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Q {} Predictions'.format(t),
                        ptu.get_numpy(q_pred),
                    ))
                qf_losses.append(qf_loss_np)
                policy_losses.append(policy_loss_np)
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            # Aggregate the per-timestep losses once, after the loop over horizons.
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Loss (all nets)',
                    qf_losses,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Loss (all nets)',
                    policy_losses,
                ))
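The target in the loop above is a finite-horizon backup: the last timestep uses the reward alone, while earlier timesteps bootstrap from the next timestep's critic and policy. A minimal sketch of just that target, under the same assumptions as the code above (finite_horizon_q_target is an illustrative name):

def finite_horizon_q_target(t, max_horizon, rewards, terminals, next_obs,
                            qfs, policies, discount=1.0, reward_scale=1.0):
    if t == max_horizon - 1:
        # Last step of the horizon: nothing left to bootstrap from.
        return reward_scale * rewards
    next_q = qfs[t + 1](next_obs, policies[t + 1](next_obs))
    return reward_scale * rewards + (1. - terminals) * discount * next_q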
Esempio n. 25
0
 def eval_model_np(state, action):
     state = ptu.Variable(ptu.FloatTensor([[state]]), requires_grad=False)
     action = ptu.Variable(ptu.FloatTensor([[action]]), requires_grad=False)
     a, v = model(state, action)
     q = a + v
     return ptu.get_numpy(q)[0]
Esempio n. 26
0
    def _get_action(self, current_ob, goal):
        if (
                self.replan_every_time_step
                or self.t_in_plan == self.planning_horizon
                or self.last_solution is None
        ):
            full_solution = self.replan(current_ob, goal)

            x_torch = ptu.np_to_var(full_solution, requires_grad=True)
            current_ob_torch = ptu.np_to_var(current_ob)

            _, actions, next_obs = self.batchify(x_torch, current_ob_torch)
            self.subgoal_seq = np.array(
                [current_ob] + [ptu.get_numpy(o) for o in next_obs]
            )
            self.learned_actions = self.learned_policy.eval_np(
                self.subgoal_seq[:-1],
                self.subgoal_seq[1:],
                np.zeros((self.planning_horizon, 1))
            )
            self.lbfgs_actions = np.array([ptu.get_numpy(a) for a in actions])
            if self.use_learned_policy:
                self.planned_action_seq = self.learned_actions
            else:
                self.planned_action_seq = self.lbfgs_actions

            self.last_solution = full_solution
            self.t_in_plan = 0

        action = self.planned_action_seq[self.t_in_plan]
        new_goal = self.subgoal_seq[self.t_in_plan+1]
        self._current_goal = new_goal
        oracle_qmax_action = self.get_oracle_qmax_action(current_ob,
                                                         new_goal)
        if self.use_oracle_argmax_policy:
            action = oracle_qmax_action

        # self.cost_function(full_solution, current_ob, verbose=True)
        # adam_action = self.choose_action_to_reach_adam(current_ob, new_goal)
        # lbfgs_action_again = self.choose_action_to_reach_lbfgs_again(
        #     current_ob, new_goal
        # )
        # lbfgs_action = self.lbfgs_actions[self.t_in_plan]
        # learned_action = self.learned_actions[self.t_in_plan]
        # print("---")
        # print("learned action", learned_action)
        # print("\terror: {}".format(np.linalg.norm(learned_action-oracle_qmax_ac)))tion
        # print("lbfgs action", lbfgs_action)
        # print("\terror: {}".format(np.linalg.norm(lbfgs_action-oracle_qmax_ac)))tion
        # print("lbfgs again action", lbfgs_action_again)
        # print("\terror: {}".format(np.linalg.norm(lbfgs_action_again-oracle_qmax_ac)))tion
        # print("adam_action", adam_action)
        # print("\terror: {}".format(np.linalg.norm(adam_action-oracle_qmax_ac)))tion
        # print("oracle best action", oracle_action)
        # print("action", action)
        agent_info = dict(
            planned_action_seq=self.planned_action_seq[self.t_in_plan:],
            subgoal_seq=self.subgoal_seq[self.t_in_plan:],
            oracle_qmax_action=oracle_qmax_action,
            learned_action=self.learned_actions[self.t_in_plan],
            lbfgs_action_seq=self.lbfgs_actions,
            learned_action_seq=self.learned_actions,
            full_action_seq=self.planned_action_seq,
            full_obs_seq=self.subgoal_seq,
        )

        self.t_in_plan += 1
        return action, agent_info
Esempio n. 27
0
 def encode(self, x):
     x_torch = ptu.from_numpy(x)
     embedding_torch = self._mlp(x_torch)
     return ptu.get_numpy(embedding_torch)
Esempio n. 28
0
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy operations.
        """
        policy_actions = self.policy(obs)
        sampled_actions = ptu.Variable(torch.randn(
            policy_actions.shape)) * self.sample_std + policy_actions
        sampled_actions = sampled_actions.detach()
        deviations = (policy_actions - sampled_actions)**2
        avg_deviations = deviations.mean(dim=1, keepdim=True)
        policy_loss = (
            avg_deviations *
            (self.qf(obs, sampled_actions) - self.target_vf(obs))).mean()
        """
        Qf operations.
        """
        target_q_values = self.target_vf(next_obs)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
        q_pred = self.qf(obs, actions)
        bellman_errors = (q_pred - q_target)**2
        qf_loss = self.qf_criterion(q_pred, q_target)
        """
        Vf operations.
        """
        v_target = self.qf(obs, self.policy(obs)).detach()
        v_pred = self.vf(obs)
        vf_loss = self.vf_criterion(v_pred, v_target)
        """
        Update Networks
        """

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        self._update_target_networks()

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Targets',
                    ptu.get_numpy(v_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors',
                    ptu.get_numpy(bellman_errors),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
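self._update_target_networks() is not shown in this snippet; in code of this style it is commonly a Polyak (soft) update of the target networks. A hedged sketch of such an update (soft_update is an illustrative name, not the original method):

def soft_update(source, target, tau):
    # Move each target parameter a fraction tau toward the source parameter.
    for src_param, tgt_param in zip(source.parameters(), target.parameters()):
        tgt_param.data.copy_(tau * src_param.data + (1.0 - tau) * tgt_param.data)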
Esempio n. 29
0
 def _encode(self, imgs):
     self.vae.eval()
     latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
     return ptu.get_numpy(latent_distribution_params[0])
Esempio n. 30
0
    def fit(self, data, weights=None):
        if weights is None:
            weights = np.ones(len(data))
        sum_of_weights = weights.flatten().sum()
        weights = weights / sum_of_weights
        all_weights_pt = ptu.from_numpy(weights)

        indexed_train_data = IndexedData(data)
        if self.skew_sampling:
            base_sampler = WeightedRandomSampler(weights, len(weights))
        else:
            base_sampler = RandomSampler(indexed_train_data)

        train_dataloader = DataLoader(
            indexed_train_data,
            sampler=BatchSampler(
                base_sampler,
                batch_size=self.batch_size,
                drop_last=False,
            ),
        )
        if self.reset_vae_every_epoch:
            raise NotImplementedError()

        epoch_stats_list = defaultdict(list)
        for _ in range(self.num_inner_vae_epochs):
            for _, indexed_batch in enumerate(train_dataloader):
                idxs, batch = indexed_batch
                batch = batch[0].float().to(ptu.device)

                latents, means, log_vars, stds = (
                    self.encoder.get_encoding_and_suff_stats(
                        batch
                    )
                )
                beta = 1
                kl = self.kl_to_prior(means, log_vars, stds)
                reconstruction_log_prob = self.compute_log_prob(
                    batch, self.decoder, latents
                )

                elbo = - kl * beta + reconstruction_log_prob
                if self.weight_loss:
                    idxs = torch.cat(idxs)
                    batch_weights = all_weights_pt[idxs].unsqueeze(1)
                    loss = -(batch_weights * elbo).sum()
                else:
                    loss = - elbo.mean()
                self.encoder_opt.zero_grad()
                self.decoder_opt.zero_grad()
                loss.backward()
                self.encoder_opt.step()
                self.decoder_opt.step()

                epoch_stats_list['losses'].append(ptu.get_numpy(loss))
                epoch_stats_list['kls'].append(ptu.get_numpy(kl.mean()))
                epoch_stats_list['log_probs'].append(
                    ptu.get_numpy(reconstruction_log_prob.mean())
                )
                epoch_stats_list['latent-mean'].append(
                    ptu.get_numpy(latents.mean())
                )
                epoch_stats_list['latent-std'].append(
                    ptu.get_numpy(latents.std())
                )
                for k, v in create_stats_ordered_dict(
                    'weights',
                    ptu.get_numpy(all_weights_pt)
                ).items():
                    epoch_stats_list[k].append(v)

        self._epoch_stats = {
            'unnormalized weight sum': sum_of_weights,
        }
        for k in epoch_stats_list:
            self._epoch_stats[k] = np.mean(epoch_stats_list[k])
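The loss inside fit() is an importance-weighted negative ELBO. A minimal sketch of just that term, assuming kl and reconstruction_log_prob are per-sample tensors as produced above and that the weights sum to one (weighted_negative_elbo is a hypothetical helper):

def weighted_negative_elbo(kl, reconstruction_log_prob,
                           batch_weights=None, beta=1.0):
    # ELBO per sample: reconstruction term minus beta-weighted KL to the prior.
    elbo = -beta * kl + reconstruction_log_prob
    if batch_weights is not None:
        # Importance-weight each sample before reducing.
        return -(batch_weights * elbo).sum()
    return -elbo.mean()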