def set_flat_params(model, flat_params, trainable_only=True):
    """Copy a flat parameter vector back into a model, in model.parameters()
    order. If trainable_only is set, parameters with requires_grad=False are
    skipped, matching a flat vector that was built with the same filter."""
    idx = 0
    for p in model.parameters():
        if trainable_only and not p.requires_grad:
            continue
        flat_shape = int(np.prod(list(p.data.shape)))
        flat_params_to_assign = flat_params[idx:idx + flat_shape]

        if len(p.data.shape):
            p.data = ptu.tensor(flat_params_to_assign.reshape(*p.data.shape))
        else:
            p.data = ptu.tensor(flat_params_to_assign[0])
        idx += flat_shape
    return model
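For context, a minimal round-trip sketch (not from the source): get_flat_params is a hypothetical companion, and ptu is assumed to be an rlkit-style pytorch_util module that provides tensor().

import numpy as np
import torch.nn as nn
import rlkit.torch.pytorch_util as ptu  # assumed import path for `ptu`

def get_flat_params(model, trainable_only=True):
    # Hypothetical inverse of set_flat_params: concatenate parameters into one
    # flat numpy vector, in the same order and with the same trainable_only filter.
    params = [p for p in model.parameters()
              if p.requires_grad or not trainable_only]
    return np.concatenate([p.data.cpu().numpy().ravel() for p in params])

model = nn.Linear(4, 2)
flat = get_flat_params(model)       # shape: (4 * 2 + 2,)
set_flat_params(model, flat + 0.1)  # write a perturbed copy back into the model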
Example #2
def compute_world_model_loss(
    world_model,
    image_shape,
    image_dist,
    prior,
    post,
    prior_dist,
    post_dist,
    obs,
    forward_kl,
    free_nats,
    transition_loss_scale,
    kl_loss_scale,
    image_loss_scale,
):
    preprocessed_obs = world_model.flatten_obs(world_model.preprocess(obs),
                                               image_shape)
    image_pred_loss = -1 * image_dist.log_prob(preprocessed_obs).mean()
    post_detached_dist = world_model.get_detached_dist(post)
    prior_detached_dist = world_model.get_detached_dist(prior)
    if forward_kl:
        div = kld(post_dist, prior_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(post_detached_dist, prior_dist).mean()
        post_kld = kld(post_dist, prior_detached_dist).mean()
    else:
        div = kld(prior_dist, post_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(prior_dist, post_detached_dist).mean()
        post_kld = kld(prior_detached_dist, post_dist).mean()
    transition_loss = torch.max(prior_kld, ptu.tensor(free_nats))
    entropy_loss = torch.max(post_kld, ptu.tensor(free_nats))
    entropy_loss_scale = 1 - transition_loss_scale
    entropy_loss_scale = (1 - kl_loss_scale) * entropy_loss_scale
    transition_loss_scale = (1 - kl_loss_scale) * transition_loss_scale
    world_model_loss = (kl_loss_scale * div +
                        image_loss_scale * image_pred_loss +
                        transition_loss_scale * transition_loss +
                        entropy_loss_scale * entropy_loss)
    return world_model_loss, div, image_pred_loss, transition_loss, entropy_loss
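Reading aid (not part of the source): writing sg(.) for the detached distributions, the objective above expands, in the forward_kl branch, to

    loss = image_loss_scale * image_pred_loss
         + kl_loss_scale * max(KL(post, prior), free_nats)
         + (1 - kl_loss_scale) * transition_loss_scale * max(KL(sg(post), prior), free_nats)
         + (1 - kl_loss_scale) * (1 - transition_loss_scale) * max(KL(post, sg(prior)), free_nats)

with the reverse branch swapping the KL argument order. This is DreamerV2-style KL balancing between the prior ("transition") and posterior ("entropy") terms; kld is assumed to be torch.distributions.kl.kl_divergence.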
Example #3
 def __init__(
     self,
     hidden_size,
     obs_dim,
     num_layers=4,
     discrete_continuous_dist=False,
     discrete_action_dim=0,
     continuous_action_dim=0,
     hidden_activation=F.elu,
     min_std=0.1,
     init_std=0.0,
     mean_scale=5.0,
     use_tanh_normal=True,
     dist="trunc_normal",
     **kwargs,
 ):
     self.discrete_continuous_dist = discrete_continuous_dist
     self.discrete_action_dim = discrete_action_dim
     self.continuous_action_dim = continuous_action_dim
     if self.discrete_continuous_dist:
         self.output_size = self.discrete_action_dim + self.continuous_action_dim * 2
     else:
         self.output_size = self.continuous_action_dim * 2
     super().__init__(
         [hidden_size] * num_layers,
         input_size=obs_dim,
         output_size=self.output_size,
         hidden_activation=hidden_activation,
         hidden_init=torch.nn.init.xavier_uniform_,
         **kwargs,
     )
     self._min_std = min_std
     self._mean_scale = mean_scale
     self.use_tanh_normal = use_tanh_normal
     self._dist = dist
     self.raw_init_std = torch.log(torch.exp(ptu.tensor(init_std)) - 1)
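The last line stores the inverse softplus of init_std, presumably so that a downstream softplus recovers init_std as the initial standard deviation (an assumption about how raw_init_std is consumed); note that the identity requires init_std > 0, and the default of 0.0 makes raw_init_std -inf. A quick standalone check:

import torch
import torch.nn.functional as F

init_std = 5.0  # illustrative value > 0
raw_init_std = torch.log(torch.exp(torch.tensor(init_std)) - 1)
assert torch.allclose(F.softplus(raw_init_std), torch.tensor(init_std))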
Example #4
    def __init__(
            self,
            env,
            context_graph,
            qf1,
            target_qf1,
            policy_n,
            cactor,
            qf2,
            target_qf2,
            deterministic_cactor_in_graph=True,
            deterministic_next_action=False,
            use_entropy_loss=True,
            use_entropy_reward=True,
            sum_n_loss=False, # use sum instead of mean for n agent losses
            use_cactor_entropy_loss=True,
            use_automatic_entropy_tuning=True,
            state_dependent_alpha=False,
            target_entropy=None,
            negative_sampling=False,

            discount=0.99,
            reward_scale=1.0,

            policy_learning_rate=1e-4,
            context_graph_learning_rate=1e-3,
            qf_learning_rate=1e-3, # not used
            qf_weight_decay=0.,
            init_alpha=1.,
            cactor_learning_rate=1e-4,
            target_hard_update_period=1000,
            tau=1e-2,
            use_soft_update=False,
            qf_criterion=None,
            pre_activation_weight=0.,
            optimizer_class=optim.Adam,

            min_q_value=-np.inf,
            max_q_value=np.inf,

            context_graph_optimizer=None,
            cactor_optimizer=None,
            policy_optimizer_n=None,
            alpha_optimizer_n=None,
            calpha_optimizer_n=None,
            log_alpha_n=None,
            log_calpha_n=None,
    ):
        super().__init__()
        self.env = env
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.context_graph = context_graph
        self.qf1 = qf1
        self.target_qf1 = target_qf1
        self.qf2 = qf2
        self.target_qf2 = target_qf2
        self.policy_n = policy_n
        self.cactor = cactor

        self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
        self.deterministic_next_action = deterministic_next_action
        self.sum_n_loss = sum_n_loss
        self.negative_sampling = negative_sampling

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_learning_rate = policy_learning_rate
        self.context_graph_learning_rate = context_graph_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.cactor_learning_rate = cactor_learning_rate
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.qf_criterion = qf_criterion
        self.pre_activation_weight = pre_activation_weight
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value

        if context_graph_optimizer:
            self.context_graph_optimizer = context_graph_optimizer
        else:
            self.context_graph_optimizer = optimizer_class(
                list(self.context_graph.parameters())
                + list(self.qf1.parameters())
                + list(self.qf2.parameters()),
                lr=self.context_graph_learning_rate,
            )

        if policy_optimizer_n:
            self.policy_optimizer_n = policy_optimizer_n
        else:
            self.policy_optimizer_n = [
                optimizer_class(
                    self.policy_n[i].parameters(),
                    lr=self.policy_learning_rate,
                ) for i in range(len(self.policy_n))]
        if cactor_optimizer:
            self.cactor_optimizer = cactor_optimizer
        else:
            self.cactor_optimizer = optimizer_class(
                                        self.cactor.parameters(),
                                        lr=self.cactor_learning_rate,
                                    )

        self.init_alpha = init_alpha
        self.use_entropy_loss = use_entropy_loss
        self.use_entropy_reward = use_entropy_reward
        self.use_cactor_entropy_loss = use_cactor_entropy_loss
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        self.state_dependent_alpha = state_dependent_alpha
        if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            if self.use_entropy_loss:
                if log_alpha_n:
                    self.log_alpha_n = log_alpha_n
                else:
                    self.log_alpha_n = [ptu.tensor([np.log(self.init_alpha)], requires_grad=True, dtype=torch.float32) for i in range(len(self.policy_n))]
                if alpha_optimizer_n:
                    self.alpha_optimizer_n = alpha_optimizer_n
                else:
                    if self.state_dependent_alpha:
                        self.alpha_optimizer_n = [
                                optimizer_class(
                                    self.log_alpha_n[i].parameters(),
                                    lr=self.policy_learning_rate,
                                ) for i in range(len(self.log_alpha_n))]
                    else:
                        self.alpha_optimizer_n = [
                            optimizer_class(
                                [self.log_alpha_n[i]],
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_alpha_n))]

            if self.use_cactor_entropy_loss:
                if log_calpha_n:
                    self.log_calpha_n = log_calpha_n
                else:
                    self.log_calpha_n = [ptu.tensor([np.log(self.init_alpha)], requires_grad=True, dtype=torch.float32) for i in range(len(self.policy_n))]
                if calpha_optimizer_n:
                    self.calpha_optimizer_n = calpha_optimizer_n
                else:
                    if self.state_dependent_alpha:
                        self.calpha_optimizer_n = [
                                optimizer_class(
                                    self.log_calpha_n[i].parameters(),
                                    lr=self.policy_learning_rate,
                                ) for i in range(len(self.log_calpha_n))]
                    else:
                        self.calpha_optimizer_n = [
                            optimizer_class(
                                [self.log_calpha_n[i]],
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_calpha_n))]

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
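The log_alpha_n / alpha_optimizer_n pairs built above follow the usual per-agent SAC temperature pattern. A hypothetical sketch of how such pairs are typically stepped, reusing the alpha-loss form from the single-agent trainer later on this page (the function name, log_pi_n, and the surrounding training loop are assumptions, not this repository's code):

import torch

def step_temperatures(log_alpha_n, alpha_optimizer_n, log_pi_n, target_entropy):
    # log_pi_n: one tensor of per-sample action log-probabilities per agent.
    alphas = []
    for log_alpha, optimizer, log_pi in zip(log_alpha_n, alpha_optimizer_n, log_pi_n):
        alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
        optimizer.zero_grad()
        alpha_loss.backward()
        optimizer.step()
        alphas.append(log_alpha.exp().detach())
    return alphas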
Example #5
    def __init__(
            self,
            env,
            cg1,
            target_cg1,
            qf1_n,
            target_qf1_n,
            cg2,
            target_cg2,
            qf2_n,
            target_qf2_n,
            cgca,
            cactor_n,
            policy_n,
            deterministic_cactor_in_graph=True,
            deterministic_next_action=False,
            use_entropy_loss=True,
            use_entropy_reward=True,
            use_cactor_entropy_loss=True,
            use_automatic_entropy_tuning=True,
            target_entropy=None,

            discount=0.99,
            reward_scale=1.0,

            policy_learning_rate=1e-4,
            qf_learning_rate=1e-3,
            qf_weight_decay=0.,
            init_alpha=1.,
            cactor_learning_rate=1e-4,
            target_hard_update_period=1000,
            tau=1e-2,
            use_soft_update=False,
            qf_criterion=None,
            pre_activation_weight=0.,
            optimizer_class=optim.Adam,

            min_q_value=-np.inf,
            max_q_value=np.inf,

            qf1_optimizer=None,
            qf2_optimizer=None,
            cactor_optimizer=None,
            policy_optimizer_n=None,
            alpha_optimizer_n=None,
            calpha_optimizer=None,
            log_alpha_n=None,
            log_calpha_n=None,
    ):
        super().__init__()
        self.env = env
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.cg1 = cg1
        self.target_cg1 = target_cg1
        self.qf1_n = qf1_n
        self.target_qf1_n = target_qf1_n
        self.cg2 = cg2
        self.target_cg2 = target_cg2
        self.qf2_n = qf2_n
        self.target_qf2_n = target_qf2_n
        self.cgca = cgca
        self.cactor_n = cactor_n
        self.policy_n = policy_n

        self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
        self.deterministic_next_action = deterministic_next_action

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_learning_rate = policy_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.cactor_learning_rate = cactor_learning_rate
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.qf_criterion = qf_criterion
        self.pre_activation_weight = pre_activation_weight
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value

        if qf1_optimizer:
            self.qf1_optimizer = qf1_optimizer
        else:
            qf1_parameters = list(self.cg1.parameters())
            for qf1 in qf1_n:
                qf1_parameters += list(qf1.parameters())
            self.qf1_optimizer = optimizer_class(
                                    qf1_parameters,
                                    lr=self.qf_learning_rate,
                                 )
        if qf2_optimizer:
            self.qf2_optimizer = qf2_optimizer
        else:
            qf2_parameters = list(self.cg2.parameters())
            for qf2 in qf2_n:
                qf2_parameters += list(qf2.parameters())
            self.qf2_optimizer = optimizer_class(
                                    qf2_parameters,
                                    lr=self.qf_learning_rate,
                                 )
        if policy_optimizer_n:
            self.policy_optimizer_n = policy_optimizer_n
        else:
            self.policy_optimizer_n = [
                optimizer_class(
                    self.policy_n[i].parameters(),
                    lr=self.policy_learning_rate,
                ) for i in range(len(self.policy_n))]
        if cactor_optimizer:
            self.cactor_optimizer = cactor_optimizer
        else:
            cactor_parameters = list(self.cgca.parameters())
            for cactor in cactor_n:
                cactor_parameters += list(cactor.parameters())
            self.cactor_optimizer = optimizer_class(
                                        cactor_parameters,
                                        lr=self.cactor_learning_rate,
                                    )

        self.init_alpha = init_alpha
        self.use_entropy_loss = use_entropy_loss
        self.use_entropy_reward = use_entropy_reward
        self.use_cactor_entropy_loss = use_cactor_entropy_loss
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            if self.use_entropy_loss:
                if log_alpha_n:
                    self.log_alpha_n = log_alpha_n
                else:
                    self.log_alpha_n = [ptu.tensor([np.log(self.init_alpha)], requires_grad=True, dtype=torch.float32) for i in range(len(self.policy_n))]
                if alpha_optimizer_n:
                    self.alpha_optimizer_n = alpha_optimizer_n
                else:
                    self.alpha_optimizer_n = [
                        optimizer_class(
                            [self.log_alpha_n[i]],
                            lr=self.policy_learning_rate,
                        ) for i in range(len(self.log_alpha_n))]

            if self.use_cactor_entropy_loss:
                if log_calpha_n:
                    self.log_calpha_n = log_calpha_n
                else:
                    self.log_calpha_n = [ptu.tensor([np.log(self.init_alpha)], requires_grad=True, dtype=torch.float32) for i in range(len(self.policy_n))]
                if calpha_optimizer:
                    self.calpha_optimizer = calpha_optimizer
                else:
                    self.calpha_optimizer = optimizer_class(
                        self.log_calpha_n,
                        lr=self.policy_learning_rate,
                    )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
Example #6
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1)
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError(
                f"Unknown normalize_over_state: {self.normalize_over_state}")

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            if self.normalize_over_batch is True:  # may also be "whiten", "exp", "step_fn", or False
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif not self.normalize_over_batch:
                weights = score
            else:
                raise ValueError(
                    f"Unknown normalize_over_batch: {self.normalize_over_batch}")
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
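For reference, the weighting computed above implements advantage-weighted regression: score = q_adv - v_pi estimates the advantage A(s, a) = Q(s, a) - V(s), and, up to the chosen normalization, the policy weights are

    w(s, a) ∝ exp(A(s, a) / beta)

with beta either taken from a schedule or adapted so that the KL divergence to the buffer policy tracks beta_epsilon.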
Example #7
    def __init__(
        self,
        env,
        qf1_n,
        target_qf1_n,
        policy_n,
        cactor_n,
        online_action,
        qf2_n,
        target_qf2_n,
        deterministic_cactor_in_graph=True,
        deterministic_next_action=False,
        prg_next_action=True,
        use_entropy_loss=True,
        use_entropy_reward=True,
        use_cactor_entropy_loss=True,
        use_automatic_entropy_tuning=True,
        state_dependent_alpha=False,
        target_entropy=None,
        dec_cactor=True,  # each cactor only gets its own observation
        logit_level=1,
        discount=0.99,
        reward_scale=1.0,
        policy_learning_rate=1e-4,
        qf_learning_rate=1e-3,
        qf_weight_decay=0.,
        init_alpha=1.,
        cactor_learning_rate=1e-4,
        target_hard_update_period=1000,
        tau=1e-2,
        use_soft_update=False,
        qf_criterion=None,
        pre_activation_weight=0.,
        optimizer_class=optim.Adam,
        min_q_value=-np.inf,
        max_q_value=np.inf,
        qf1_optimizer_n=None,
        qf2_optimizer_n=None,
        policy_optimizer_n=None,
        cactor_optimizer_n=None,
        alpha_optimizer_n=None,
        calpha_optimizer_n=None,
        log_alpha_n=None,
        log_calpha_n=None,
    ):
        super().__init__()
        self.env = env
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.qf1_n = qf1_n
        self.target_qf1_n = target_qf1_n
        self.qf2_n = qf2_n
        self.target_qf2_n = target_qf2_n
        self.policy_n = policy_n
        self.cactor_n = cactor_n

        self.online_action = online_action
        self.logit_level = logit_level
        self.deterministic_cactor_in_graph = deterministic_cactor_in_graph
        self.deterministic_next_action = deterministic_next_action
        self.prg_next_action = prg_next_action
        self.dec_cactor = dec_cactor

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_learning_rate = policy_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.cactor_learning_rate = cactor_learning_rate
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.qf_criterion = qf_criterion
        self.pre_activation_weight = pre_activation_weight
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value

        if qf1_optimizer_n:
            self.qf1_optimizer_n = qf1_optimizer_n
        else:
            self.qf1_optimizer_n = [
                optimizer_class(
                    self.qf1_n[i].parameters(),
                    lr=self.qf_learning_rate,
                ) for i in range(len(self.qf1_n))
            ]
        if qf2_optimizer_n:
            self.qf2_optimizer_n = qf2_optimizer_n
        else:
            self.qf2_optimizer_n = [
                optimizer_class(
                    self.qf2_n[i].parameters(),
                    lr=self.qf_learning_rate,
                ) for i in range(len(self.qf2_n))
            ]
        if policy_optimizer_n:
            self.policy_optimizer_n = policy_optimizer_n
        else:
            self.policy_optimizer_n = [
                optimizer_class(
                    self.policy_n[i].parameters(),
                    lr=self.policy_learning_rate,
                ) for i in range(len(self.policy_n))
            ]
        if cactor_optimizer_n:
            self.cactor_optimizer_n = cactor_optimizer_n
        else:
            self.cactor_optimizer_n = [
                optimizer_class(
                    self.cactor_n[i].parameters(),
                    lr=self.cactor_learning_rate,
                ) for i in range(len(self.cactor_n))
            ]

        self.init_alpha = init_alpha
        self.use_entropy_loss = use_entropy_loss
        self.use_entropy_reward = use_entropy_reward
        self.use_cactor_entropy_loss = use_cactor_entropy_loss
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        self.state_dependent_alpha = state_dependent_alpha
        if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape
                ).item()  # heuristic value from Tuomas
            if self.use_entropy_loss:
                if log_alpha_n:
                    self.log_alpha_n = log_alpha_n
                else:
                    self.log_alpha_n = [
                        ptu.tensor([np.log(self.init_alpha)],
                                   requires_grad=True,
                                   dtype=torch.float32)
                        for i in range(len(self.policy_n))
                    ]
                if alpha_optimizer_n:
                    self.alpha_optimizer_n = alpha_optimizer_n
                else:
                    if self.state_dependent_alpha:
                        self.alpha_optimizer_n = [
                            optimizer_class(
                                self.log_alpha_n[i].parameters(),
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_alpha_n))
                        ]
                    else:
                        self.alpha_optimizer_n = [
                            optimizer_class(
                                [self.log_alpha_n[i]],
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_alpha_n))
                        ]

            if self.use_cactor_entropy_loss:
                if log_calpha_n:
                    self.log_calpha_n = log_calpha_n
                else:
                    self.log_calpha_n = [
                        ptu.tensor([np.log(self.init_alpha)],
                                   requires_grad=True,
                                   dtype=torch.float32)
                        for i in range(len(self.policy_n))
                    ]
                if calpha_optimizer_n:
                    self.calpha_optimizer_n = calpha_optimizer_n
                else:
                    if self.state_dependent_alpha:
                        self.calpha_optimizer_n = [
                            optimizer_class(
                                self.log_calpha_n[i].parameters(),
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_calpha_n))
                        ]
                    else:
                        self.calpha_optimizer_n = [
                            optimizer_class(
                                [self.log_calpha_n[i]],
                                lr=self.policy_learning_rate,
                            ) for i in range(len(self.log_calpha_n))
                        ]

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
Example #8
 def log_abs_det_jacobian(self, x, y):
     return 2.0 * (torch.log(ptu.tensor(2.0)) - x - F.softplus(-2.0 * x))
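This is the numerically stable form of the tanh log-Jacobian, log|d tanh(x)/dx| = log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)). A quick standalone check of the identity:

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, 7)
stable = 2.0 * (torch.log(torch.tensor(2.0)) - x - F.softplus(-2.0 * x))
naive = torch.log(1.0 - torch.tanh(x) ** 2)
assert torch.allclose(stable, naive, atol=1e-5)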
Example #9
def world_model_loss_rt(
    world_model,
    image_shape,
    image_dist,
    reward_dist,
    prior,
    post,
    prior_dist,
    post_dist,
    pred_discount_dist,
    obs,
    rewards,
    terminals,
    forward_kl,
    free_nats,
    transition_loss_scale,
    kl_loss_scale,
    image_loss_scale,
    reward_loss_scale,
    pred_discount_loss_scale,
    discount,
):
    preprocessed_obs = world_model.flatten_obs(world_model.preprocess(obs),
                                               image_shape)
    image_pred_loss = -1 * image_dist.log_prob(preprocessed_obs).mean()
    post_detached_dist = world_model.get_detached_dist(post)
    prior_detached_dist = world_model.get_detached_dist(prior)
    reward_pred_loss = -1 * reward_dist.log_prob(rewards).mean()
    pred_discount_target = discount * (1 - terminals.float())
    pred_discount_loss = -1 * pred_discount_dist.log_prob(
        pred_discount_target).mean()
    if forward_kl:
        div = kld(post_dist, prior_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(post_detached_dist, prior_dist).mean()
        post_kld = kld(post_dist, prior_detached_dist).mean()
    else:
        div = kld(prior_dist, post_dist).mean()
        div = torch.max(div, ptu.tensor(free_nats))
        prior_kld = kld(prior_dist, post_detached_dist).mean()
        post_kld = kld(prior_detached_dist, post_dist).mean()
    transition_loss = torch.max(prior_kld, ptu.tensor(free_nats))
    entropy_loss = torch.max(post_kld, ptu.tensor(free_nats))
    entropy_loss_scale = 1 - transition_loss_scale
    entropy_loss_scale = (1 - kl_loss_scale) * entropy_loss_scale
    transition_loss_scale = (1 - kl_loss_scale) * transition_loss_scale
    world_model_loss = (kl_loss_scale * div +
                        image_loss_scale * image_pred_loss +
                        transition_loss_scale * transition_loss +
                        entropy_loss_scale * entropy_loss +
                        reward_loss_scale * reward_pred_loss +
                        pred_discount_loss_scale * pred_discount_loss)
    return (
        world_model_loss,
        div,
        image_pred_loss,
        reward_pred_loss,
        transition_loss,
        entropy_loss,
        pred_discount_loss,
    )
Example #10
    def __init__(
        self,
        actor,
        vf,
        target_vf,
        world_model,
        image_shape,
        imagination_horizon=15,
        discount=0.99,
        actor_lr=8e-5,
        vf_lr=8e-5,
        world_model_lr=3e-4,
        world_model_gradient_clip=100.0,
        actor_gradient_clip=100.0,
        value_gradient_clip=100.0,
        adam_eps=1e-5,
        weight_decay=0.0,
        soft_target_tau=1,
        target_update_period=100,
        lam=0.95,
        free_nats=1.0,
        kl_loss_scale=0.0,
        pred_discount_loss_scale=10.0,
        image_loss_scale=1.0,
        reward_loss_scale=2.0,
        transition_loss_scale=0.8,
        detach_rewards=False,
        forward_kl=False,
        policy_gradient_loss_scale=0.0,
        actor_entropy_loss_schedule="1e-4",
        use_pred_discount=False,
        reward_scale=1,
        num_imagination_iterations=1,
        use_baseline=True,
        use_ppo_loss=False,
        ppo_clip_param=0.2,
        num_actor_value_updates=1,
        use_advantage_normalization=False,
        use_clipped_value_loss=False,
        actor_value_lr=8e-5,
        use_actor_value_optimizer=False,
        binarize_rewards=False,
    ):
        super().__init__()

        torch.backends.cudnn.benchmark = True

        self.scaler = torch.cuda.amp.GradScaler()
        self.use_pred_discount = use_pred_discount
        self.actor = actor.to(ptu.device)
        self.world_model = world_model.to(ptu.device)
        self.vf = vf.to(ptu.device)
        self.target_vf = target_vf.to(ptu.device)

        optimizer_class = optim.Adam

        self.actor_lr = actor_lr
        self.adam_eps = adam_eps
        self.weight_decay = weight_decay
        self.vf_lr = vf_lr
        self.world_model_lr = world_model_lr

        self.actor_optimizer = optimizer_class(
            self.actor.parameters(),
            lr=actor_lr,
            eps=adam_eps,
            weight_decay=weight_decay,
        )
        self.vf_optimizer = optimizer_class(
            self.vf.parameters(),
            lr=vf_lr,
            eps=adam_eps,
            weight_decay=weight_decay,
        )
        self.world_model_optimizer = optimizer_class(
            self.world_model.parameters(),
            lr=world_model_lr,
            eps=adam_eps,
            weight_decay=weight_decay,
        )
        self.use_actor_value_optimizer = use_actor_value_optimizer
        self.actor_value_optimizer = optimizer_class(
            list(self.actor.parameters()) + list(self.vf.parameters()),
            lr=actor_value_lr,
            eps=adam_eps,
            weight_decay=weight_decay,
        )

        self.discount = discount
        self.lam = lam
        self.imagination_horizon = imagination_horizon
        self.free_nats = ptu.tensor(free_nats)
        self.kl_loss_scale = kl_loss_scale
        self.pred_discount_loss_scale = pred_discount_loss_scale
        self.image_loss_scale = image_loss_scale
        self.reward_loss_scale = reward_loss_scale
        self.transition_loss_scale = transition_loss_scale
        self.policy_gradient_loss_scale = policy_gradient_loss_scale
        self.actor_entropy_loss_schedule = actor_entropy_loss_schedule
        self.actor_entropy_loss_scale = lambda x=actor_entropy_loss_schedule: schedule(
            x, self._n_train_steps_total)
        self.forward_kl = forward_kl
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period
        self.image_shape = image_shape
        self.use_baseline = use_baseline
        self.use_ppo_loss = use_ppo_loss
        self.ppo_clip_param = ppo_clip_param
        self.num_actor_value_updates = num_actor_value_updates
        self.world_model_gradient_clip = world_model_gradient_clip
        self.actor_gradient_clip = actor_gradient_clip
        self.value_gradient_clip = value_gradient_clip
        self.use_advantage_normalization = use_advantage_normalization
        self.detach_rewards = detach_rewards
        self.num_imagination_iterations = num_imagination_iterations
        self.use_clipped_value_loss = use_clipped_value_loss
        self.reward_scale = reward_scale
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()
        self.use_dynamics_backprop = self.policy_gradient_loss_scale < 1.0
        self.binarize_rewards = binarize_rewards
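The lam / discount pair above parameterizes TD(lambda)-style value targets for the imagined rollouts. A minimal sketch of the standard recursion they imply (the function, names, and shapes are assumptions, not this trainer's own implementation):

import torch

def lambda_returns(rewards, values, discounts, lam):
    # rewards, discounts: [horizon, batch]; values: [horizon + 1, batch],
    # where the final entry of `values` is the bootstrap value.
    horizon = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    last = values[-1]
    for t in reversed(range(horizon)):
        last = rewards[t] + discounts[t] * ((1 - lam) * values[t + 1] + lam * last)
        returns[t] = last
    return returns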
Example #11
    def __init__(
            self,
            env,
            qf1,
            target_qf1,
            qf2,
            target_qf2,
            policy_n,

            shared_gnn=None,

            discount=0.99,
            reward_scale=1.0,

            policy_learning_rate=1e-4,
            qf_learning_rate=1e-3,
            qf_weight_decay=0.,
            log_alpha_n=None,
            init_alpha=1.,
            target_hard_update_period=1000,
            tau=1e-2,
            use_soft_update=False,
            qf_criterion=None,
            deterministic_next_action=False,
            use_entropy_reward=False,
            use_automatic_entropy_tuning=True,
            target_entropy=None,
            optimizer_class=optim.Adam,
            log_grad=False,
            shared_obs=False,

            min_q_value=-np.inf,
            max_q_value=np.inf,

            qf1_optimizer=None,
            qf2_optimizer=None,
            policy_optimizer_n=None,
            alpha_optimizer_n=None,
            shared_gnn_optimizer=None,
    ):
        super().__init__()
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.env = env
        self.qf1 = qf1
        self.target_qf1 = target_qf1
        self.qf2 = qf2
        self.target_qf2 = target_qf2
        self.policy_n = policy_n
        self.shared_gnn = shared_gnn
        self.deterministic_next_action = deterministic_next_action
        self.use_entropy_reward = use_entropy_reward

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_learning_rate = policy_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.qf_criterion = qf_criterion
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value
        self.log_grad = log_grad
        self.shared_obs = shared_obs

        self.init_alpha = init_alpha
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            if log_alpha_n:
                self.log_alpha_n = log_alpha_n
            else:
                self.log_alpha_n = [ptu.tensor([np.log(self.init_alpha)], requires_grad=True, dtype=torch.float32) for i in range(len(self.policy_n))]
            if alpha_optimizer_n:
                self.alpha_optimizer_n = alpha_optimizer_n
            else:
                self.alpha_optimizer_n = [
                    optimizer_class(
                        [self.log_alpha_n[i]],
                        lr=self.policy_learning_rate,
                    ) for i in range(len(self.log_alpha_n))]

        if qf1_optimizer:
            self.qf1_optimizer = qf1_optimizer
        else:
            self.qf1_optimizer = optimizer_class(
                                    self.qf1.parameters(),
                                    lr=self.qf_learning_rate,
                                )
        if qf2_optimizer:
            self.qf2_optimizer = qf2_optimizer
        else:
            self.qf2_optimizer = optimizer_class(
                                    self.qf2.parameters(),
                                    lr=self.qf_learning_rate,
                                )
        if policy_optimizer_n:
            self.policy_optimizer_n = policy_optimizer_n
        else:
            self.policy_optimizer_n = [
                optimizer_class(
                    self.policy_n[i].parameters(),
                    lr=self.policy_learning_rate,
                ) for i in range(len(self.policy_n))]
        if shared_gnn:
            if shared_gnn_optimizer:
                self.shared_gnn_optimizer = shared_gnn_optimizer
            else:
                self.shared_gnn_optimizer = optimizer_class(
                                            self.shared_gnn.parameters(),
                                            lr=self.policy_learning_rate/len(self.policy_n),
                                        )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True