Example #1
class AWACTrainer(TorchTrainer):
    def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        buffer_policy=None,
        discount=0.99,
        reward_scale=1.0,
        beta=1.0,
        beta_schedule_kwargs=None,
        policy_lr=1e-3,
        qf_lr=1e-3,
        policy_weight_decay=0,
        q_weight_decay=0,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
        bc_num_pretrain_steps=0,
        q_num_pretrain1_steps=0,
        q_num_pretrain2_steps=0,
        bc_batch_size=128,
        alpha=1.0,
        policy_update_period=1,
        q_update_period=1,
        weight_loss=True,
        compute_bc=True,
        use_awr_update=True,
        use_reparam_update=False,
        bc_weight=0.0,
        rl_weight=1.0,
        reparam_weight=1.0,
        awr_weight=1.0,
        post_pretrain_hyperparams=None,
        post_bc_pretrain_hyperparams=None,
        awr_use_mle_for_vf=False,
        vf_K=1,
        awr_sample_actions=False,
        buffer_policy_sample_actions=False,
        awr_min_q=False,
        brac=False,
        reward_transform_class=None,
        reward_transform_kwargs=None,
        terminal_transform_class=None,
        terminal_transform_kwargs=None,
        pretraining_logging_period=1000,
        train_bc_on_rl_buffer=False,
        use_automatic_beta_tuning=False,
        beta_epsilon=1e-10,
        normalize_over_batch=True,
        normalize_over_state="advantage",
        Z_K=10,
        clip_score=None,
        validation_qlearning=False,
        mask_positive_advantage=False,
        buffer_policy_reset_period=-1,
        num_buffer_policy_train_steps_on_reset=100,
        advantage_weighted_buffer_loss=True,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.buffer_policy = buffer_policy
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_awr_update = use_awr_update
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item(
                    )  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.awr_use_mle_for_vf = awr_use_mle_for_vf
        self.vf_K = vf_K
        self.awr_sample_actions = awr_sample_actions
        self.awr_min_q = awr_min_q

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.optimizers = {}

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )
        self.optimizers[self.policy] = self.policy_optimizer
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )

        if buffer_policy and train_bc_on_rl_buffer:
            self.buffer_policy_optimizer = optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=policy_weight_decay,
                lr=policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            self.optimizer_class = optimizer_class
            self.policy_weight_decay = policy_weight_decay
            self.policy_lr = policy_lr

        self.use_automatic_beta_tuning = use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer
        self.beta_epsilon = beta_epsilon
        if self.use_automatic_beta_tuning:
            self.log_beta = ptu.zeros(1, requires_grad=True)
            self.beta_optimizer = optimizer_class(
                [self.log_beta],
                lr=policy_lr,
            )
        else:
            self.beta = beta
            self.beta_schedule_kwargs = beta_schedule_kwargs
            if beta_schedule_kwargs is None:
                self.beta_schedule = ConstantSchedule(beta)
            else:
                schedule_class = beta_schedule_kwargs.pop(
                    "schedule_class", PiecewiseLinearSchedule)
                self.beta_schedule = schedule_class(**beta_schedule_kwargs)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.bc_num_pretrain_steps = bc_num_pretrain_steps
        self.q_num_pretrain1_steps = q_num_pretrain1_steps
        self.q_num_pretrain2_steps = q_num_pretrain2_steps
        self.bc_batch_size = bc_batch_size
        self.rl_weight = rl_weight
        self.bc_weight = bc_weight
        self.eval_policy = MakeDeterministic(self.policy)
        self.compute_bc = compute_bc
        self.alpha = alpha
        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period
        self.weight_loss = weight_loss

        self.reparam_weight = reparam_weight
        self.awr_weight = awr_weight
        self.post_pretrain_hyperparams = post_pretrain_hyperparams
        self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
        self.update_policy = True
        self.pretraining_logging_period = pretraining_logging_period
        self.normalize_over_batch = normalize_over_batch
        self.normalize_over_state = normalize_over_state
        self.Z_K = Z_K

        self.reward_transform_class = reward_transform_class or LinearTransform
        self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1,
                                                                       b=0)
        self.terminal_transform_class = terminal_transform_class or LinearTransform
        self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1,
                                                                           b=0)
        self.reward_transform = self.reward_transform_class(
            **self.reward_transform_kwargs)
        self.terminal_transform = self.terminal_transform_class(
            **self.terminal_transform_kwargs)
        self.use_reparam_update = use_reparam_update
        self.clip_score = clip_score
        self.buffer_policy_sample_actions = buffer_policy_sample_actions

        self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy
        self.validation_qlearning = validation_qlearning
        self.brac = brac
        self.mask_positive_advantage = mask_positive_advantage
        self.buffer_policy_reset_period = buffer_policy_reset_period
        self.num_buffer_policy_train_steps_on_reset = num_buffer_policy_train_steps_on_reset
        self.advantage_weighted_buffer_loss = advantage_weighted_buffer_loss

    @staticmethod
    def get_batch_from_buffer(replay_buffer, batch_size):
        """

        :param replay_buffer:
        :param batch_size:
        :return:
        """
        batch = replay_buffer.random_batch(batch_size)
        batch = np_to_pytorch_batch(batch)
        return batch

    def run_bc_batch(self, replay_buffer, policy):
        """Get a batch from the replay buffer and run the policy on it.
        Return 3 losses and policy statistics
        bc stands for behavior cloning?

        :param replay_buffer:
        :param policy:
        :return:
        """
        batch = self.get_batch_from_buffer(replay_buffer, self.bc_batch_size)
        o = batch["observations"]
        u = batch["actions"]
        # g = batch["resampled_goals"]
        # og = torch.cat((o, g), dim=1)
        og = o
        # pred_u, *_ = self.policy(og)
        dist = policy(og)
        pred_u, log_pi = dist.rsample_and_logprob()
        stats = dist.get_diagnostics()

        mse = (pred_u - u)**2
        mse_loss = mse.mean()

        policy_logpp = dist.log_prob(u)
        logp_loss = -policy_logpp.mean()
        policy_loss = logp_loss

        return policy_loss, logp_loss, mse_loss, stats

    def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
    ):
        """Given a policy, first get its optimizer, then run the policy on the train buffer, get the
        losses, and back propagate the loss. After training on a batch, test on the test buffer and
        get the statistics

        :param policy:
        :param train_buffer:
        :param test_buffer:
        :param steps:
        :param label:
        :return:
        """
        logger.remove_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )

        optimizer = self.optimizers[policy]
        prev_time = time.time()
        for i in range(steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
                train_buffer, policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            optimizer.zero_grad()
            train_policy_loss.backward()
            optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
                test_buffer, policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch":
                    i,
                    "pretrain_bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time":
                    time.time() - prev_time,
                }

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                snapshot_path = (logger.get_snapshot_dir() +
                                 '/bc_%s.pkl' % label)
                with open(snapshot_path, "wb") as f:
                    pickle.dump(self.policy, f)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)

    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)

    def set_algorithm_weights(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def test_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['validation/QF1 Loss'] = np.mean(
            ptu.get_numpy(qf1_loss))
        self.eval_statistics['validation/QF2 Loss'] = np.mean(
            ptu.get_numpy(qf2_loss))
        self.eval_statistics['validation/Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Log Pis',
                ptu.get_numpy(log_pi),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(),
                                       "validation/policy/")
        self.eval_statistics.update(policy_statistics)

    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

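        # BRAC-style behavior regularization: reward the log-likelihood of the
        # taken action under the (pretrained) buffer policy.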
        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

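        # SAC-style temperature update: alpha shrinks when the policy entropy
        # exceeds target_entropy and grows when it falls below it.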
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

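        # Soft Bellman target:
        #   y = reward_scale * r + gamma * (1 - d)
        #       * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))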
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
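            # Monte-Carlo estimate of V(s) = E_{a~pi}[min_i Q_i(s, a)] using
            # vf_K sampled actions (a single reparameterized sample if vf_K=1).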
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1, keepdim=True)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

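        # Lagrangian-style temperature update: beta grows when
        # KL(pi || pi_buffer) exceeds beta_epsilon and shrinks otherwise.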
        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1)
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
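            # Numerically stable version of the above: estimate
            # log Z = log((1/K) * sum_k exp(adv_k / beta)) with logsumexp.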
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError(
                "Unknown normalize_over_state: {}".format(
                    self.normalize_over_state))

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

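        # AWR weighting: w(s, a) is proportional to exp(score / beta); the
        # batch softmax below is this weight normalized over the batch.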
        if self.weight_loss and weights is None:
            if self.normalize_over_batch is True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch is False:
                weights = score
            else:
                raise ValueError(
                    "Unknown normalize_over_batch: {}".format(
                        self.normalize_over_batch))
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data[
                    'observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data[
                    'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        stats = super().get_diagnostics()
        stats.update(self.eval_statistics)
        return stats

    def end_epoch(self, epoch):
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        nets = [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]
        if self.buffer_policy:
            nets.append(self.buffer_policy)
        return nets

    def get_snapshot(self):
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
            buffer_policy=self.buffer_policy,
        )
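
A minimal construction sketch for the trainer above, assuming rlkit-style components: the import paths, ConcatMlp, and TanhGaussianPolicy are assumptions about the surrounding codebase (make_env is a hypothetical helper, not a name from this example), and the policy must return a distribution exposing rsample_and_logprob(), log_prob(), and mle_estimate(), which the trainer calls.

from rlkit.torch.networks import ConcatMlp  # assumed import path
from rlkit.torch.sac.policies import TanhGaussianPolicy  # assumed import path

env = make_env()  # hypothetical helper returning a continuous-control env
obs_dim = env.observation_space.low.size
action_dim = env.action_space.low.size

qf_kwargs = dict(
    input_size=obs_dim + action_dim,
    output_size=1,
    hidden_sizes=[256, 256],
)
# Twin Q networks plus their target copies, as the constructor expects.
qf1, qf2 = ConcatMlp(**qf_kwargs), ConcatMlp(**qf_kwargs)
target_qf1, target_qf2 = ConcatMlp(**qf_kwargs), ConcatMlp(**qf_kwargs)
policy = TanhGaussianPolicy(
    obs_dim=obs_dim,
    action_dim=action_dim,
    hidden_sizes=[256, 256],
)
trainer = AWACTrainer(
    env=env,
    policy=policy,
    qf1=qf1,
    qf2=qf2,
    target_qf1=target_qf1,
    target_qf2=target_qf2,
    beta=1.0,  # AWR temperature used in exp(advantage / beta)
)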
Example #2
class ConvLSTMTrainer(object):
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        positive_range=2,
        negative_range=10,
        triplet_sample_num=8,
        triplet_loss_margin=0.5,
        batch_size=128,
        log_interval=0,
        recon_loss_coef=1,
        triplet_loss_coef=[],
        triplet_loss_type=[],
        ae_loss_coef=1,
        matching_loss_coef=1,
        vae_matching_loss_coef=1,
        matching_loss_one_side=False,
        contrastive_loss_coef=0,
        lstm_kl_loss_coef=0,
        adaptive_margin=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):

        print("In LSTM trainer, ae_loss_coef is: ", ae_loss_coef)
        print("In LSTM trainer, matching_loss_coef is: ", matching_loss_coef)
        print("In LSTM trainer, vae_matching_loss_coef is: ",
              vae_matching_loss_coef)

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        self.recon_loss_coef = recon_loss_coef
        self.triplet_loss_coef = triplet_loss_coef
        self.ae_loss_coef = ae_loss_coef
        self.matching_loss_coef = matching_loss_coef
        self.vae_matching_loss_coef = vae_matching_loss_coef
        self.contrastive_loss_coef = contrastive_loss_coef
        self.lstm_kl_loss_coef = lstm_kl_loss_coef
        self.matching_loss_one_side = matching_loss_one_side

        # triplet loss range
        self.positive_range = positive_range
        self.negative_range = negative_range
        self.triplet_sample_num = triplet_sample_num
        self.triplet_loss_margin = triplet_loss_margin
        self.triplet_loss_type = triplet_loss_type
        self.adaptive_margin = adaptive_margin

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None

    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std

    def update_train_weights(self):
        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
            if self.use_parallel_dataloading:
                self.train_dataloader = DataLoader(
                    self.train_dataset_pt,
                    sampler=InfiniteWeightedRandomSampler(
                        self.train_dataset, self._train_weights),
                    batch_size=self.batch_size,
                    drop_last=False,
                    num_workers=self.train_data_workers,
                    pin_memory=True,
                )
                self.train_dataloader = iter(self.train_dataloader)

    def _compute_train_weights(self):
        method = self.skew_config.get('method', 'squared_error')
        power = self.skew_config.get('power', 1)
        batch_size = 512
        size = self.train_dataset.shape[0]
        next_idx = min(batch_size, size)
        cur_idx = 0
        weights = np.zeros(size)
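        # Score the dataset in chunks of `batch_size` rows to bound memory use.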
        while cur_idx < self.train_dataset.shape[0]:
            idxs = np.arange(cur_idx, next_idx)
            data = self.train_dataset[idxs, :]
            if method == 'vae_prob':
                data = normalize_image(data)
                weights[idxs] = compute_p_x_np_to_np(
                    self.model,
                    data,
                    power=power,
                    **self.priority_function_kwargs)
            else:
                raise NotImplementedError(
                    'Method {} not supported'.format(method))
            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, size)

        if method == 'vae_prob':
            weights = relative_probs_from_log_probs(weights)
        return weights

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, train=True, epoch=None):
        if self.use_parallel_dataloading:
            if not train:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.train_dataset if train else self.test_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if train and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(
            dataset[ind, :])  # this should be a batch of trajectories
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean

        samples = np.swapaxes(
            samples, 0, 1)  # reorder to (traj_len, batch_size, feature_size)
        return ptu.from_numpy(samples)

    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)

    def matching_loss_vae(self, traj_torch):
        _, _, vae_latent_distribution_params, _ = self.model(traj_torch)
        vae_latents_ori = vae_latent_distribution_params[0]

        # new way of correct masking
        masked_traj = traj_torch.detach().clone()
        traj_len, batch_size, imlen = masked_traj.shape
        mask = (np.random.uniform(size=(self.imsize, self.imsize)) >
                0.5).astype(float)
        mask = np.stack([mask, mask, mask], axis=0).flatten()
        masked_traj = masked_traj * ptu.from_numpy(
            mask)  # mask all images for training vae latent space

        _, _, vae_latent_distribution_params, _ = self.model(masked_traj)
        vae_latents_masked = vae_latent_distribution_params[0]

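        # One-sided matching stops gradients through the unmasked branch, so
        # the unmasked encoding acts as a fixed target for the masked one.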
        if self.matching_loss_one_side:
            loss = F.mse_loss(vae_latents_ori.detach(), vae_latents_masked)
        else:
            loss = F.mse_loss(vae_latents_ori, vae_latents_masked)

        return loss

    def contrastive_loss(self, traj):
        masked_traj = traj.detach().clone()
        traj_len, batch_size, imlen = traj.shape
        masked_idx = np.random.randint(low=10, high=traj_len, size=batch_size)
        for i in range(batch_size):
            mask = (np.random.uniform(size=(self.imsize, self.imsize)) >
                    0.5).astype(float)
            mask = np.stack([mask, mask, mask], axis=0).flatten()
            masked_traj[masked_idx[i]][i] *= ptu.from_numpy(mask)

        latents_ori = self.model.encode(traj)[0]
        latents_masked = self.model.encode(masked_traj)[0]

        loss = ptu.zeros(1)
        for j in range(batch_size):
            latents_ori_vec = latents_ori[masked_idx[j]:, j]
            latents_masked_vec = latents_masked[masked_idx[j]:, j]
            loss += F.mse_loss(latents_masked_vec, latents_ori_vec)

        positive_loss = loss / batch_size

        encodings = latents_ori
        seq_len, batch_size, _ = encodings.shape

        anchors, negatives, margins = [], [], []

        for t in range(seq_len):
            neg_range_prev_end, neg_range_after_beg = max(
                0, t - self.negative_range), min(seq_len - 1,
                                                 t + self.negative_range)
            for _ in range(self.triplet_sample_num):
                neg_indices = np.array(
                    [x for x in range(neg_range_prev_end)] +
                    [x for x in range(neg_range_after_beg + 1, seq_len)],
                    dtype=np.int32)
                neg_idx = np.random.randint(0, len(neg_indices), batch_size)
                neg_idx = neg_indices[neg_idx]  # batch_size

                if self.adaptive_margin > 0:
                    time_differences = np.abs(neg_idx - t)  # batch_size
                    adaptive_margins = self.adaptive_margin * time_differences
                    margins.append(ptu.from_numpy(adaptive_margins))
                else:
                    margins.append(
                        ptu.from_numpy(
                            np.array([
                                self.triplet_loss_margin
                                for _ in range(batch_size)
                            ])))

                anchor_samples = encodings[t]  # batch_size, feature_size
                negative_samples = encodings[neg_idx, np.arange(batch_size)]

                anchors.append(anchor_samples)
                negatives.append(negative_samples)

        anchors = torch.cat(anchors, dim=0)
        negatives = torch.cat(negatives, dim=0)
        margins = torch.cat(margins)

        negative_distances = (anchors - negatives).pow(2).sum(dim=1)
        losses = F.relu(margins - negative_distances)
        negative_loss = losses.mean()

        return positive_loss + negative_loss

    def matching_loss(self, traj_torch):
        masked_traj = traj_torch.detach().clone()
        traj_len, batch_size, imlen = traj_torch.shape
        masked_idx = np.random.randint(low=10, high=traj_len, size=batch_size)
        for i in range(batch_size):
            mask = (np.random.uniform(size=(self.imsize, self.imsize)) >
                    0.5).astype(np.float32)
            mask = np.stack([mask, mask, mask], axis=0).flatten()

            masked_traj[masked_idx[i]][i] *= ptu.from_numpy(mask)

        latents_ori = self.model.encode(traj_torch)[0]
        latents_masked = self.model.encode(masked_traj)[0]

        loss = ptu.zeros(1)
        for j in range(batch_size):
            latents_ori_vec = latents_ori[masked_idx[j]:, j]
            latents_masked_vec = latents_masked[masked_idx[j]:, j]
            if self.matching_loss_one_side:
                loss += F.mse_loss(latents_masked_vec,
                                   latents_ori_vec.detach())
            else:
                loss += F.mse_loss(latents_masked_vec, latents_ori_vec)
        return loss / batch_size
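
    # Illustration, not part of the original trainer: the one-sided variant
    # detaches the clean latents, so gradients only pull the masked encoding
    # toward the unmasked one, never the reverse. A minimal sketch of that
    # stop-gradient behavior:
    #
    #     import torch
    #     import torch.nn.functional as F
    #     clean = torch.randn(8, 16, requires_grad=True)
    #     masked = torch.randn(8, 16, requires_grad=True)
    #     F.mse_loss(masked, clean.detach()).backward()
    #     assert clean.grad is None and masked.grad is not None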

    def triplet_loss_3(self, traj_torch):
        warm_len = 10
        seq_len, batch_size, imlen = traj_torch.shape
        traj = traj_torch.clone().detach()
        traj = traj[:, :batch_size // 2, :]
        seq_len, batch_size, imlen = traj.shape
        anchors, positives, negatives = [], [], []

        for t in range(seq_len):
            # print(t)
            neg_range_prev_end, neg_range_after_beg = max(
                0, t - self.negative_range), min(seq_len - 1,
                                                 t + self.negative_range)
            for _ in range(self.triplet_sample_num):
                neg_indices = np.array(
                    [x for x in range(neg_range_prev_end)] +
                    [x for x in range(neg_range_after_beg + 1, seq_len)],
                    dtype=np.int32)
                neg_idx = np.random.randint(0, len(neg_indices), batch_size)
                neg_idx = neg_indices[neg_idx]

                # get the anchor encodings
                anchor_traj = ptu.zeros((warm_len, batch_size, imlen))
                if t - warm_len + 1 >= 0:
                    anchor_traj = traj[t - warm_len + 1:t + 1, :, :]
                else:
                    anchor_traj[-(t + 1):] = traj[:t + 1]
                    for _ in range(warm_len - t - 1):
                        anchor_traj[_] = traj[0]
                anchor_encodings = self.model.encode(anchor_traj)[
                    0]  # always assume we use the mean as the encoding

                # get the positive encodings: mask out part of the anchor samples
                pos_traj = anchor_traj.clone().detach()
                mask = (np.random.uniform(size=(self.imsize, self.imsize)) >
                        0.5).astype(np.float32)
                mask = np.stack([mask, mask, mask], axis=0).flatten()
                pos_traj[-1, :, :] *= ptu.from_numpy(mask)
                pos_encodings = self.model.encode(pos_traj)[0]

                # get the negative encodings
                neg_traj = ptu.zeros((warm_len, batch_size, imlen))
                for b_idx in range(batch_size):
                    n_idx = neg_idx[b_idx]
                    if n_idx - warm_len + 1 >= 0:
                        neg_traj[:, b_idx, :] = traj[n_idx - warm_len +
                                                     1:n_idx + 1, b_idx, :]
                    else:
                        neg_traj[-(n_idx + 1):, b_idx, :] = traj[:n_idx + 1,
                                                                 b_idx, :]
                        for _ in range(warm_len - n_idx - 1):
                            neg_traj[_, b_idx, :] = traj[0, b_idx, :]
                neg_encodings = self.model.encode(neg_traj)[0]

                anchors.append(anchor_encodings)
                positives.append(pos_encodings)
                negatives.append(neg_encodings)

        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        positive_distances = (anchors - positives).pow(2).sum(dim=1)
        negative_distances = (anchors - negatives).pow(2).sum(dim=1)
        losses = F.relu(
            positive_distances - negative_distances + self.triplet_loss_margin)
        return losses.mean()
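
    # Illustration, not part of the original trainer: triplet_loss_2/3 feed
    # the LSTM encoder a fixed-length window of frames ending at index t;
    # when fewer than `warm_len` frames precede t, the window is left-padded
    # by repeating the first frame. A standalone sketch of that padding:
    #
    #     import torch
    #     def warmup_window(traj, t, warm_len=10):
    #         # traj: [seq_len, batch, feat] -> [warm_len, batch, feat]
    #         if t - warm_len + 1 >= 0:
    #             return traj[t - warm_len + 1:t + 1]
    #         pad = traj[0:1].repeat(warm_len - t - 1, 1, 1)
    #         return torch.cat([pad, traj[:t + 1]], dim=0)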

    def triplet_loss_2(self, traj_torch):
        '''
        use the same len of images to warm up the lstm encoding.
        '''
        warm_len = 10
        seq_len, batch_size, imlen = traj_torch.shape
        # traj = traj_torch.clone().detach()
        traj = traj_torch
        seq_len, batch_size, imlen = traj.shape
        anchors, positives, negatives = [], [], []

        for t in range(seq_len):
            # print(t)
            pos_range_beg, pos_range_end = max(0, t - self.positive_range), min(
                seq_len - 1, t + self.positive_range)
            neg_range_prev_end, neg_range_after_beg = max(
                0, t - self.negative_range), min(seq_len - 1,
                                                 t + self.negative_range)
            for _ in range(self.triplet_sample_num):
                pos_indices = np.array(
                    [x for x in range(pos_range_beg, t)] +
                    [x for x in range(t + 1, pos_range_end + 1)],
                    dtype=np.int32)
                pos_idx = np.random.randint(0, len(pos_indices), batch_size)
                pos_idx = pos_indices[pos_idx]

                neg_indices = np.array(
                    [x for x in range(neg_range_prev_end)] +
                    [x for x in range(neg_range_after_beg + 1, seq_len)],
                    dtype=np.int32)
                neg_idx = np.random.randint(0, len(neg_indices), batch_size)
                neg_idx = neg_indices[neg_idx]

                # get the anchor encodings
                anchor_traj = ptu.zeros((warm_len, batch_size, imlen))
                if t - warm_len + 1 >= 0:
                    anchor_traj = traj[t - warm_len + 1:t + 1, :, :]
                else:
                    anchor_traj[-(t + 1):] = traj[:t + 1]
                    for _ in range(warm_len - t - 1):
                        anchor_traj[_] = traj[0]
                anchor_encodings = self.model.encode(anchor_traj)[
                    0]  # always assume we use the mean as the encoding

                # get the positive encodings
                pos_traj = ptu.zeros((warm_len, batch_size, imlen))
                for b_idx in range(batch_size):
                    p_idx = pos_idx[b_idx]
                    if p_idx - warm_len + 1 >= 0:
                        pos_traj[:, b_idx, :] = traj[p_idx - warm_len +
                                                     1:p_idx + 1, b_idx, :]
                    else:
                        pos_traj[-(p_idx + 1):, b_idx, :] = traj[:p_idx + 1,
                                                                 b_idx, :]
                        for _ in range(warm_len - p_idx - 1):
                            pos_traj[_, b_idx, :] = traj[0, b_idx, :]
                pos_encodings = self.model.encode(pos_traj)[0]

                # get the negative encodings
                neg_traj = ptu.zeros((warm_len, batch_size, imlen))
                for b_idx in range(batch_size):
                    n_idx = neg_idx[b_idx]
                    if n_idx - warm_len + 1 >= 0:
                        neg_traj[:, b_idx, :] = traj[n_idx - warm_len +
                                                     1:n_idx + 1, b_idx, :]
                    else:
                        neg_traj[-(n_idx + 1):, b_idx, :] = traj[:n_idx + 1,
                                                                 b_idx, :]
                        for _ in range(warm_len - n_idx - 1):
                            neg_traj[_, b_idx, :] = traj[0, b_idx, :]
                neg_encodings = self.model.encode(neg_traj)[0]

                anchors.append(anchor_encodings)
                positives.append(pos_encodings)
                negatives.append(neg_encodings)

        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        positive_distances = (anchors - positives).pow(2).sum(dim=1)
        negative_distances = (anchors - negatives).pow(2).sum(dim=1)
        losses = F.relu(
            positive_distances - negative_distances + self.triplet_loss_margin)
        return losses.mean()

    def triplet_loss(self, encodings):
        '''
        encodings: [seq_len, batch_size, feature_size]
        '''
        seq_len, batch_size, feature_size = encodings.shape

        anchors, positives, negatives = [], [], []

        for t in range(seq_len):
            pos_range_beg, pos_range_end = max(0, t - self.positive_range), min(
                seq_len - 1, t + self.positive_range)
            neg_range_prev_end, neg_range_after_beg = max(
                0, t - self.negative_range), min(seq_len - 1,
                                                 t + self.negative_range)
            for _ in range(self.triplet_sample_num):
                pos_indices = np.array(
                    [x for x in range(pos_range_beg, t)] +
                    [x for x in range(t + 1, pos_range_end + 1)],
                    dtype=np.int32)
                pos_idx = np.random.randint(0, len(pos_indices), batch_size)
                pos_idx = pos_indices[pos_idx]

                neg_indices = np.array(
                    [x for x in range(neg_range_prev_end)] +
                    [x for x in range(neg_range_after_beg + 1, seq_len)],
                    dtype=np.int32)
                neg_idx = np.random.randint(0, len(neg_indices), batch_size)
                neg_idx = neg_indices[neg_idx]

                anchor_samples = encodings[t]  # batch_size, feature_size
                positive_samples = encodings[pos_idx, np.arange(batch_size)]
                negative_samples = encodings[neg_idx, np.arange(batch_size)]

                anchors.append(anchor_samples)
                positives.append(positive_samples)
                negatives.append(negative_samples)

        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        positive_distances = (anchors - positives).pow(2).sum(dim=1)
        negative_distances = (anchors - negatives).pow(2).sum(dim=1)
        losses = F.relu(
            positive_distances - negative_distances + self.triplet_loss_margin)
        return losses.mean()
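
    # Illustration, not part of the original trainer: triplet_loss expects
    # per-timestep encodings shaped [seq_len, batch_size, feature_size];
    # anchors are taken at each time t, positives from within
    # `positive_range` of t, and negatives from outside `negative_range`,
    # e.g. a dummy input of torch.randn(20, 8, 16) yields a scalar loss.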

    def train_epoch(self,
                    epoch,
                    sample_batch=None,
                    batches=25,
                    from_rl=False,
                    key=None,
                    only_train_vae=False):
        self.model.train()
        losses = []
        log_probs = []
        triplet_losses = []
        kles = []
        ae_losses = []
        matching_losses = []
        vae_matching_losses = []
        contrastive_losses = []
        lstm_kles = []
        # zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(batches):
            if sample_batch is not None:
                data = sample_batch(self.batch_size, key=key)
                next_obs = data['next_obs']
            else:
                next_obs = self.get_batch(epoch=epoch)

            reconstructions, obs_distribution_params, vae_latent_distribution_params, lstm_latent_encodings = self.model(
                next_obs)
            latent_encodings = lstm_latent_encodings
            vae_mu = vae_latent_distribution_params[0]
            latent_distribution_params = vae_latent_distribution_params

            triplet_loss = ptu.zeros(1)
            for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
                if triplet_type == 1 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss(latent_encodings)
                elif triplet_type == 2 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_2(next_obs)
                elif triplet_type == 3 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_3(next_obs)

            if self.matching_loss_coef > 0 and not only_train_vae:
                matching_loss = self.matching_loss(next_obs)
            else:
                matching_loss = ptu.zeros(1)

            if self.vae_matching_loss_coef > 0:
                matching_loss_vae = self.matching_loss_vae(next_obs)
            else:
                matching_loss_vae = ptu.zeros(1)

            if self.contrastive_loss_coef > 0 and not only_train_vae:
                contrastive_loss = self.contrastive_loss(next_obs)
            else:
                contrastive_loss = ptu.zeros(1)

            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            lstm_kle = ptu.zeros(1)

            ae_loss = F.mse_loss(
                latent_encodings.view((-1, self.model.representation_size)),
                vae_mu.detach())
            ae_losses.append(ae_loss.item())

            loss = -self.recon_loss_coef * log_prob + \
                    beta * kle + \
                    self.matching_loss_coef * matching_loss + \
                    self.ae_loss_coef * ae_loss + \
                    triplet_loss + \
                    self.vae_matching_loss_coef * matching_loss_vae + \
                    self.contrastive_loss_coef * contrastive_loss
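            # (The composite objective above: reconstruction log-likelihood
            # and beta-weighted KL, plus the matching, autoencoding, triplet,
            # and contrastive auxiliary terms, each scaled by its
            # coefficient.)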

            self.optimizer.zero_grad()
            loss.backward()
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())
            lstm_kles.append(lstm_kle.item())
            triplet_losses.append(triplet_loss.item())
            matching_losses.append(matching_loss.item())
            vae_matching_losses.append(matching_loss_vae.item())
            contrastive_losses.append(contrastive_loss.item())
            self.optimizer.step()

            if self.log_interval and batch_idx % self.log_interval == 0:
                # `data` is undefined when sampling from the dataset and
                # `self.train_loader` does not exist on this trainer, so
                # report progress via the batch counter instead.
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx, batches,
                    100. * batch_idx / batches, loss.item()))

            # dump a batch of training images for debugging
            # if batch_idx == 0 and epoch % 25 == 0:
            #     n = min(next_obs.size(0), 8)
            #     comparison = torch.cat([
            #         next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
            #             .contiguous().view(
            #             -1, self.input_channels, self.imsize, self.imsize
            #         ).transpose(2, 3),
            #         reconstructions.view(
            #             self.batch_size,
            #             self.input_channels,
            #             self.imsize,
            #             self.imsize,
            #         )[:n].transpose(2, 3)
            #     ])
            #     save_dir = osp.join(logger.get_snapshot_dir(),
            #                         'vae_train_{}_{}.png'.format(key, epoch))
            #     save_image(comparison.data.cpu(), save_dir, nrow=n)

        # if not from_rl:
        #     zs = np.array(zs)
        #     self.model.dist_mu = zs.mean(axis=0)
        #     self.model.dist_std = zs.std(axis=0)

        self.eval_statistics['train/log prob'] = np.mean(log_probs)
        self.eval_statistics['train/triplet loss'] = np.mean(triplet_losses)
        self.eval_statistics['train/matching loss'] = np.mean(matching_losses)
        self.eval_statistics['train/vae matching loss'] = np.mean(
            vae_matching_losses)
        self.eval_statistics['train/KL'] = np.mean(kles)
        self.eval_statistics['train/lstm KL'] = np.mean(lstm_kles)
        self.eval_statistics['train/loss'] = np.mean(losses)
        self.eval_statistics['train/contrastive loss'] = np.mean(
            contrastive_losses)
        self.eval_statistics['train/ae loss'] = np.mean(ae_losses)

        torch.cuda.empty_cache()

    def get_diagnostics(self):
        return self.eval_statistics

    def test_epoch(
        self,
        epoch,
        sample_batch=None,
        key=None,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
        save_prefix='r',
        only_train_vae=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        triplet_losses = []
        matching_losses = []
        vae_matching_losses = []
        kles = []
        lstm_kles = []
        ae_losses = []
        contrastive_losses = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            # print(batch_idx)
            if sample_batch is not None:
                data = sample_batch(self.batch_size, key=key)
                next_obs = data['next_obs']
            else:
                next_obs = self.get_batch(epoch=epoch)

            reconstructions, obs_distribution_params, vae_latent_distribution_params, lstm_latent_encodings = self.model(
                next_obs)
            latent_encodings = lstm_latent_encodings
            vae_mu = vae_latent_distribution_params[0]  # this is lstm inputs
            latent_distribution_params = vae_latent_distribution_params

            triplet_loss = ptu.zeros(1)
            for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
                if triplet_type == 1 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss(latent_encodings)
                elif triplet_type == 2 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_2(next_obs)
                elif triplet_type == 3 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_3(next_obs)

            if self.matching_loss_coef > 0 and not only_train_vae:
                matching_loss = self.matching_loss(next_obs)
            else:
                matching_loss = ptu.zeros(1)

            if self.vae_matching_loss_coef > 0:
                matching_loss_vae = self.matching_loss_vae(next_obs)
            else:
                matching_loss_vae = ptu.zeros(1)

            if self.contrastive_loss_coef > 0 and not only_train_vae:
                contrastive_loss = self.contrastive_loss(next_obs)
            else:
                contrastive_loss = ptu.zeros(1)

            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            lstm_kle = ptu.zeros(1)

            ae_loss = F.mse_loss(
                latent_encodings.view((-1, self.model.representation_size)),
                vae_mu.detach())
            ae_losses.append(ae_loss.item())

            loss = -self.recon_loss_coef * log_prob + \
                    beta * kle + \
                    self.matching_loss_coef * matching_loss + \
                    self.ae_loss_coef * ae_loss + \
                    triplet_loss + \
                    self.vae_matching_loss_coef * matching_loss_vae + \
                    self.contrastive_loss_coef * contrastive_loss

            losses.append(loss.item())
            log_probs.append(log_prob.item())
            triplet_losses.append(triplet_loss.item())
            matching_losses.append(matching_loss.item())
            vae_matching_losses.append(matching_loss_vae.item())
            kles.append(kle.item())
            lstm_kles.append(lstm_kle.item())
            contrastive_losses.append(contrastive_loss.item())

            if batch_idx == 0 and save_reconstruction:
                seq_len, batch_size, feature_size = next_obs.shape
                show_obs = next_obs[0][:8]
                reconstructions = reconstructions.view(
                    (seq_len, batch_size, feature_size))[0][:8]
                comparison = torch.cat([
                    show_obs.narrow(start=0, length=self.imlength,
                                    dim=1).contiguous().view(
                                        -1, self.input_channels, self.imsize,
                                        self.imsize).transpose(2, 3),
                    reconstructions.view(
                        -1,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    ).transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    '{}{}.png'.format(save_prefix, epoch))
                save_image(comparison.data.cpu(), save_dir, nrow=8)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/triplet loss'] = np.mean(triplet_losses)
        self.eval_statistics['test/vae matching loss'] = np.mean(
            vae_matching_losses)
        self.eval_statistics['test/matching loss'] = np.mean(matching_losses)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/lstm KL'] = np.mean(lstm_kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['test/contrastive loss'] = np.mean(
            contrastive_losses)
        self.eval_statistics['beta'] = beta
        self.eval_statistics['test/ae loss'] = np.mean(ae_losses)

        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)

        torch.cuda.empty_cache()

    def debug_statistics(self):
        """
        Given an image $$x$$, samples a bunch of latents from the prior
        $$z_i$$ and decodes them $$\hat x_i$$.
        Compare this to $$\hat x$$, the reconstruction of $$x$$.
        Ideally
         - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure VAE
           isn’t ignoring the latent)
         - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
           coverage)
        """
        debug_batch_size = 64
        data = self.get_batch(train=False)
        reconstructions, *_ = self.model(data)  # the LSTM model returns 4 values
        img = data[0]
        recon_mse = ((reconstructions[0] - img)**2).mean().view(-1)
        img_repeated = img.expand((debug_batch_size, img.shape[0]))

        samples = ptu.randn(debug_batch_size, self.representation_size)
        random_imgs, _ = self.model.decode(samples)
        random_mses = (random_imgs - img_repeated)**2
        mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
        stats = create_stats_ordered_dict(
            'debug/MSE improvement over random',
            mse_improvement,
        )
        stats.update(
            create_stats_ordered_dict(
                'debug/MSE of random decoding',
                ptu.get_numpy(random_mses),
            ))
        stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
        if self.skew_dataset:
            stats.update(
                create_stats_ordered_dict('train weight', self._train_weights))
        return stats

    def dump_samples(self, epoch, save_prefix='s'):
        self.model.eval()
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample)[0].cpu()
        save_dir = osp.join(logger.get_snapshot_dir(),
                            '{}{}.png'.format(save_prefix, epoch))
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize).transpose(2, 3), save_dir)

    def _dump_imgs_and_reconstructions(self, idxs, filename):
        imgs = []
        recons = []
        for i in idxs:
            img_np = self.train_dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=len(idxs),
        )

    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i:min(data.shape[0], i +
                                             self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img,
                             reconstructions,
                             reduction='mean')  # 'elementwise_mean' is removed
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular("Uniform Data Log Prob (True Prior)",
                              np.mean(log_probs_prior))
        logger.record_tabular("Uniform Data Log Prob (Biased)",
                              np.mean(log_probs_biased))
        logger.record_tabular("Uniform Data Log Prob (Importance)",
                              np.mean(log_probs_importance))
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))

    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
        idxs = np.random.choice(range(dataset.shape[0]), 4)
        filename = 'uniform{}.png'.format(epoch)
        imgs = []
        recons = []
        for i in idxs:
            img_np = dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=4,
        )
Example No. 4
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        positive_range=2,
        negative_range=10,
        triplet_sample_num=8,
        triplet_loss_margin=0.5,
        batch_size=128,
        log_interval=0,
        recon_loss_coef=1,
        triplet_loss_coef=[],
        triplet_loss_type=[],
        ae_loss_coef=1,
        matching_loss_coef=1,
        vae_matching_loss_coef=1,
        matching_loss_one_side=False,
        contrastive_loss_coef=0,
        lstm_kl_loss_coef=0,
        adaptive_margin=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):

        print("In LSTM trainer, ae_loss_coef is: ", ae_loss_coef)
        print("In LSTM trainer, matching_loss_coef is: ", matching_loss_coef)
        print("In LSTM trainer, vae_matching_loss_coef is: ",
              vae_matching_loss_coef)

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        self.recon_loss_coef = recon_loss_coef
        self.triplet_loss_coef = triplet_loss_coef
        self.ae_loss_coef = ae_loss_coef
        self.matching_loss_coef = matching_loss_coef
        self.vae_matching_loss_coef = vae_matching_loss_coef
        self.contrastive_loss_coef = contrastive_loss_coef
        self.lstm_kl_loss_coef = lstm_kl_loss_coef
        self.matching_loss_one_side = matching_loss_one_side

        # triplet loss range
        self.positive_range = positive_range
        self.negative_range = negative_range
        self.triplet_sample_num = triplet_sample_num
        self.triplet_loss_margin = triplet_loss_margin
        self.triplet_loss_type = triplet_loss_type
        self.adaptive_margin = adaptive_margin

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,  # honor the weighted sampler when skewing
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
Example No. 5
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):
        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,  # honor the weighted sampler when skewing
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
Example No. 6
class ConvVAETrainer(object):
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):
        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,  # honor the weighted sampler when skewing
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None

    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std

    def update_train_weights(self):
        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
            if self.use_parallel_dataloading:
                self.train_dataloader = DataLoader(
                    self.train_dataset_pt,
                    sampler=InfiniteWeightedRandomSampler(
                        self.train_dataset, self._train_weights),
                    batch_size=self.batch_size,
                    drop_last=False,
                    num_workers=self.train_data_workers,
                    pin_memory=True,
                )
                self.train_dataloader = iter(self.train_dataloader)

    def _compute_train_weights(self):
        method = self.skew_config.get('method', 'squared_error')
        power = self.skew_config.get('power', 1)
        batch_size = 512
        size = self.train_dataset.shape[0]
        next_idx = min(batch_size, size)
        cur_idx = 0
        weights = np.zeros(size)
        while cur_idx < self.train_dataset.shape[0]:
            idxs = np.arange(cur_idx, next_idx)
            data = self.train_dataset[idxs, :]
            if method == 'vae_prob':
                data = normalize_image(data)
                weights[idxs] = compute_p_x_np_to_np(
                    self.model,
                    data,
                    power=power,
                    **self.priority_function_kwargs)
            else:
                raise NotImplementedError(
                    'Method {} not supported'.format(method))
            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, size)

        if method == 'vae_prob':
            weights = relative_probs_from_log_probs(weights)
        return weights
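
    # Illustration, not part of the original trainer: for 'vae_prob', the
    # weights are filled with (log-)likelihood scores and then rescaled into
    # nonnegative sampling weights. A numerically stable sketch of one such
    # rescaling; this is an assumption for illustration, not necessarily the
    # verbatim implementation of relative_probs_from_log_probs:
    #
    #     import numpy as np
    #     def relative_probs(log_probs):
    #         return np.exp(log_probs - np.max(log_probs))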

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, train=True, epoch=None):
        if self.use_parallel_dataloading:
            if not train:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.train_dataset if train else self.test_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if train and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)

    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)

    def train_epoch(self,
                    epoch,
                    sample_batch=None,
                    batches=100,
                    from_rl=False):
        self.model.train()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(batches):
            if sample_batch is not None:
                data = sample_batch(self.batch_size, epoch)
                # obs = data['obs']
                next_obs = data['next_obs']
                # actions = data['actions']
            else:
                next_obs = self.get_batch(epoch=epoch)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                next_obs)
            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)

            encoder_mean = self.model.get_encoding_from_latent_distribution_params(
                latent_distribution_params)
            z_data = ptu.get_numpy(encoder_mean.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])

            loss = -1 * log_prob + beta * kle
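            # (The line above is the standard beta-VAE objective: negative
            # reconstruction log-likelihood plus beta-weighted KL to the
            # prior, with beta read from self.beta_schedule each epoch.)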

            self.optimizer.zero_grad()
            loss.backward()
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())

            self.optimizer.step()
            if self.log_interval and batch_idx % self.log_interval == 0:
                # `data` is undefined when sampling from the dataset and
                # `self.train_loader` does not exist on this trainer, so
                # report progress via the batch counter instead.
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx, batches,
                    100. * batch_idx / batches, loss.item()))
        if not from_rl:
            zs = np.array(zs)
            self.model.dist_mu = zs.mean(axis=0)
            self.model.dist_std = zs.std(axis=0)

        self.eval_statistics['train/log prob'] = np.mean(log_probs)
        self.eval_statistics['train/KL'] = np.mean(kles)
        self.eval_statistics['train/loss'] = np.mean(losses)

    def get_diagnostics(self):
        return self.eval_statistics

    def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            next_obs = self.get_batch(train=False)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                next_obs)
            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            loss = -1 * log_prob + beta * kle

            encoder_mean = latent_distribution_params[0]
            z_data = ptu.get_numpy(encoder_mean.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())

            if batch_idx == 0 and save_reconstruction:
                n = min(next_obs.size(0), 8)
                comparison = torch.cat([
                    next_obs[:n].narrow(start=0, length=self.imlength,
                                        dim=1).contiguous().view(
                                            -1, self.input_channels,
                                            self.imsize,
                                            self.imsize).transpose(2, 3),
                    reconstructions.view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n].transpose(2, 3)
                ])
                # test = str(logger.get_snapshot_dir()) + 'r{}.png'.format(epoch)
                # print('epoch number: ', test)
                save_dir = osp.join(
                    '/mnt/manh/project/visual_RL_imaged_goal/experiment_result',
                    'test_r%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)

        zs = np.array(zs)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['beta'] = beta
        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)

    def debug_statistics(self):
        """
        Given an image $$x$$, samples a bunch of latents from the prior
        $$z_i$$ and decodes them $$\hat x_i$$.
        Compare this to $$\hat x$$, the reconstruction of $$x$$.
        Ideally
         - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure VAE
           isn’t ignoring the latent)
         - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
           coverage)
        """
        debug_batch_size = 64
        data = self.get_batch(train=False)
        reconstructions, _, _ = self.model(data)
        img = data[0]
        recon_mse = ((reconstructions[0] - img)**2).mean().view(-1)
        img_repeated = img.expand((debug_batch_size, img.shape[0]))

        samples = ptu.randn(debug_batch_size, self.representation_size)
        random_imgs, _ = self.model.decode(samples)
        random_mses = (random_imgs - img_repeated)**2
        mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
        stats = create_stats_ordered_dict(
            'debug/MSE improvement over random',
            mse_improvement,
        )
        stats.update(
            create_stats_ordered_dict(
                'debug/MSE of random decoding',
                ptu.get_numpy(random_mses),
            ))
        stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
        if self.skew_dataset:
            stats.update(
                create_stats_ordered_dict('train weight', self._train_weights))
        return stats

    def dump_samples(self, epoch):
        self.model.eval()
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample)[0].cpu()
        # save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
        # save_dir = osp.join('/mnt/manh/project/visual_RL_imaged_goal', 's%d.png' % epoch)
        project_path = osp.abspath(os.curdir)
        save_dir = osp.join(project_path, 'result_image', 's%d.png' % epoch)
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize).transpose(2, 3), save_dir)

    def _dump_imgs_and_reconstructions(self, idxs, filename):
        imgs = []
        recons = []
        for i in idxs:
            img_np = self.train_dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        # save_file = osp.join(logger.get_snapshot_dir(), filename)
        #save_file = osp.join('/mnt/manh/project/visual_RL_imaged_goal', filename)
        project_path = osp.abspath(os.curdir)
        # Assign to `save_file` so the `save_image` call below resolves; the
        # original assigned to `save_dir`, leaving `save_file` undefined.
        save_file = osp.join(project_path, 'result_image', filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=len(idxs),
        )

    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i:min(data.shape[0], i +
                                             self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img, reconstructions, reduction='mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular("Uniform Data Log Prob (True Prior)",
                              np.mean(log_probs_prior))
        logger.record_tabular("Uniform Data Log Prob (Biased)",
                              np.mean(log_probs_biased))
        logger.record_tabular("Uniform Data Log Prob (Importance)",
                              np.mean(log_probs_importance))
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))

    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
        idxs = np.random.choice(range(dataset.shape[0]), 4)
        filename = 'uniform{}.png'.format(epoch)
        imgs = []
        recons = []
        for i in idxs:
            img_np = dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        project_path = osp.abspath(os.curdir)
        save_file = osp.join(project_path, 'result_image', filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=4,
        )
Example No. 8
class VAETrainer(LossFunction):
    def __init__(
        self,
        model,
        train_dataset=None,  # referenced by the get_batch()/normalize paths below
        test_dataset=None,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct='observations',
        num_epochs=None,
    ):
        #TODO:steven fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        self.num_epochs = num_epochs
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.key_to_reconstruct = key_to_reconstruct
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,  # use the weighted sampler when skewing
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)
        self.num_batches = 0

    @property
    def log_dir(self):
        return logger.get_snapshot_dir()

    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std

    def _kl_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        mu, log_var = self.model.encode(torch_input)
        return ptu.get_numpy(
            -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
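    # Note on the formula above: for a diagonal Gaussian posterior
    # N(mu, diag(sigma^2)) against a standard-normal prior,
    #     KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    # The value returned here omits the 0.5 factor, so it is twice the KL;
    # that is harmless as long as it is only used for relative comparisons.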

    def _reconstruction_squared_error_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        recons, *_ = self.model(torch_input)
        error = torch_input - recons
        return ptu.get_numpy((error**2).sum(dim=1))

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, test_data=False, epoch=None):
        if self.use_parallel_dataloading:
            if test_data:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.test_dataset if test_data else self.train_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if not test_data and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)
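    # Commentary on get_batch: when skewing is active, per-image weights are
    # renormalized into a categorical distribution and indices drawn with
    # np.random.choice; otherwise indices are sampled uniformly. The normalize
    # branch maps mean-subtracted pixels back into [0, 1] via (x - mean + 1) / 2.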

    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)

    def train_epoch(self, epoch, dataset, batches=100):
        start_time = time.time()
        for b in range(batches):
            self.train_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["train/epoch_duration"].append(time.time() -
                                                            start_time)

    def test_epoch(self, epoch, dataset, batches=10):
        start_time = time.time()
        for b in range(batches):
            self.test_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["test/epoch_duration"].append(time.time() -
                                                           start_time)

    def compute_loss(self, batch, epoch=-1, test=False):
        prefix = "test/" if test else "train/"

        beta = float(self.beta_schedule.get_value(epoch))
        obs = batch[self.key_to_reconstruct]
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            obs)
        log_prob = self.model.logprob(obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = -1 * log_prob + beta * kle

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['beta'] = beta
        self.eval_statistics[prefix + "losses"].append(loss.item())
        self.eval_statistics[prefix + "log_probs"].append(log_prob.item())
        self.eval_statistics[prefix + "kles"].append(kle.item())
        self.eval_statistics["num_train_batches"].append(self.num_batches)

        encoder_mean = self.model.get_encoding_from_latent_distribution_params(
            latent_distribution_params)
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            self.eval_data[prefix + "zs"].append(z_data[i, :])
        self.eval_data[prefix + "last_batch"] = (obs, reconstructions)

        return loss
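    # Commentary on the objective above: this is the negative beta-weighted
    # ELBO,
    #     loss = -E_q[log p(x|z)] + beta * KL(q(z|x) || p(z)),
    # with beta taken from beta_schedule each epoch; beta = 1 recovers the
    # usual VAE bound, and the is_auto_encoder path (beta = 0) reduces it to
    # a pure reconstruction objective.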

    def train_batch(self, epoch, batch):
        self.num_batches += 1
        self.model.train()
        self.optimizer.zero_grad()

        loss = self.compute_loss(batch, epoch, False)
        loss.backward()

        self.optimizer.step()

    def test_batch(
        self,
        epoch,
        batch,
    ):
        self.model.eval()
        # the loss is computed only for its logging side effects; no backward
        # pass happens at test time, so skip building the autograd graph
        with torch.no_grad():
            self.compute_loss(batch, epoch, True)

    def end_epoch(self, epoch):
        self.eval_statistics = collections.defaultdict(list)
        self.test_last_batch = None

    def get_diagnostics(self):
        stats = OrderedDict()
        for k in sorted(self.eval_statistics.keys()):
            stats[k] = np.mean(self.eval_statistics[k])
        return stats

    def dump_scatterplot(self, z, epoch):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.log(__file__ + ": Unable to load matplotlib. Consider "
                       "setting do_scatterplot to False")
            return
        dim_and_stds = [(i, np.std(z[:, i])) for i in range(z.shape[1])]
        dim_and_stds = sorted(dim_and_stds, key=lambda x: x[1])
        dim1 = dim_and_stds[-1][0]
        dim2 = dim_and_stds[-2][0]
        plt.figure(figsize=(8, 8))
        plt.scatter(z[:, dim1], z[:, dim2], marker='o', edgecolor='none')
        if self.model.dist_mu is not None:
            x1 = self.model.dist_mu[dim1:dim1 + 1]
            y1 = self.model.dist_mu[dim2:dim2 + 1]
            x2 = (self.model.dist_mu[dim1:dim1 + 1] +
                  self.model.dist_std[dim1:dim1 + 1])
            y2 = (self.model.dist_mu[dim2:dim2 + 1] +
                  self.model.dist_std[dim2:dim2 + 1])
            # only draw the mean/std segment when dist_mu is available;
            # otherwise x1..y2 would be undefined here
            plt.plot([x1, x2], [y1, y2], color='k', linestyle='-', linewidth=2)
        axes = plt.gca()
        axes.set_xlim([-6, 6])
        axes.set_ylim([-6, 6])
        axes.set_title('dim {} vs dim {}'.format(dim1, dim2))
        plt.grid(True)
        save_file = osp.join(self.log_dir, 'scatter%d.png' % epoch)
        plt.savefig(save_file)
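# Usage sketch (not from the original source): driving VAETrainer directly,
# assuming `model` is a ConvVAE and `train_dataset`/`test_dataset` expose
# random_batch(batch_size) -> {'observations': ...} as train_epoch expects.
trainer = VAETrainer(model, batch_size=128, beta=0.5)
for epoch in range(10):
    trainer.train_epoch(epoch, train_dataset)
    trainer.test_epoch(epoch, test_dataset)
    for k, v in trainer.get_diagnostics().items():
        logger.record_tabular(k, v)
    logger.dump_tabular()
    trainer.end_epoch(epoch)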
Example No. 9
def train_vae(variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant.get(
        "representation_size",
        variant.get("latent_sizes", variant.get("embedding_dim", None)))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    variant['algo_kwargs']['num_epochs'] = variant['num_epochs']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    elif variant.get('only_kwargs', False):
        vae_class = variant.get('vae_class', ConvVAE)
        model = vae_class(**variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])

    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'model', mode='pickle')

    if return_data:
        return model, train_dataset, test_dataset

    return model
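# Usage sketch (not from the original source): a minimal variant dict for
# train_vae. The top-level keys mirror those read in the function body; the
# contents of generate_vae_dataset_kwargs and vae_kwargs depend on the
# generate_vae_dataset function and VAE class actually used, so the values
# here are illustrative only.
variant = dict(
    beta=0.5,
    representation_size=16,
    num_epochs=100,
    save_period=10,
    imsize=48,
    generate_vae_dataset_kwargs=dict(),
    algo_kwargs=dict(batch_size=128),
    vae_kwargs=dict(),
)
model = train_vae(variant)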