Code example #1
    def select_actions(self, obs, raw_context):

        # Repeat the obs as BCQ does; candidate_size indicates
        # how many candidate actions we generate per observation.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        if len(raw_context) == 0:
            # Before any context has been collected, the inferred_mdp is set to a zero vector.
            inferred_mdp = ptu.zeros((1, self.f.latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.f(context)
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0),
                                             self.vae_latent_dim))).clamp(
                                                 -0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
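
For reference, a minimal self-contained sketch of the BCQ-style selection pattern used above, with random tensors standing in for the VAE decoder, perturbation network, and Q-ensemble (all values below are illustrative, not part of the original code):

import numpy as np
import torch

candidate_size = 10
obs = np.random.randn(4)                            # a single observation
# Tile the observation so every candidate action is scored against it;
# obs_t would be fed to the decoder, perturbation net, and Q-networks.
obs_t = torch.from_numpy(np.tile(obs.reshape(1, -1), (candidate_size, 1))).float()

candidate_actions = torch.randn(candidate_size, 2)  # stand-in for the VAE decoder output
q_values = torch.randn(candidate_size, 1)           # stand-in for the Q-network output
ind = q_values.max(0)[1]                            # index of the highest-valued candidate
action = candidate_actions[ind].numpy()             # shape (1, action_dim), as in the code above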
Code example #2
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    elif isinstance(elem_or_tuple, OrderedDict) or isinstance(
            elem_or_tuple, dict):
        return {k: ptu.from_numpy(v).float() for k, v in elem_or_tuple.items()}

    return ptu.from_numpy(elem_or_tuple).float()
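
A minimal, self-contained sketch of the same conversion pattern, using plain torch in place of ptu.from_numpy (assumed here to wrap torch.from_numpy plus device placement); the batch contents are made up for illustration:

import numpy as np
import torch

def to_tensor(elem_or_tuple):
    # Recurse into tuples, map over dicts, convert leaf ndarrays to float tensors
    if isinstance(elem_or_tuple, tuple):
        return tuple(to_tensor(e) for e in elem_or_tuple)
    if isinstance(elem_or_tuple, dict):
        return {k: torch.from_numpy(v).float() for k, v in elem_or_tuple.items()}
    return torch.from_numpy(elem_or_tuple).float()

batch = {'observations': np.zeros((8, 4)), 'actions': np.ones((8, 2))}
tensors = to_tensor(batch)
assert tensors['observations'].dtype == torch.float32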
Code example #3
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    elif isinstance(np_array_or_other, OrderedDict) or isinstance(
            np_array_or_other, dict):
        return {k: ptu.from_numpy(v) for k, v in np_array_or_other.items()}
    else:
        return np_array_or_other
Code example #4
def _get_prod_of_gauss_mask(num_selected, desired_len):

    # Taken from
    # https://discuss.pytorch.org/t/create-a-2d-tensor-with-varying-lengths-of-one-in-each-row/25359

    # desired_len is the desired size of the second dimension of the returned mask

    seq_lens = ptu.from_numpy(np.array(num_selected)).unsqueeze(-1)
    max_len = torch.max(seq_lens)

    # create a range tensor of suitable shape with the same number of dimensions
    range_tensor = torch.arange(max_len).unsqueeze(0)
    range_tensor = range_tensor.to(ptu.device)
    range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))

    # Up to this point we have only created auxiliary tensors;
    # the actual mask tensor is created with a binary comparison:
    mask_tensor = (range_tensor < seq_lens)

    mask_tensor = mask_tensor.type(torch.float)

    current_len = mask_tensor.shape[1]

    pad = ptu.zeros(mask_tensor.shape[0], desired_len - current_len)

    mask_tensor = torch.cat((mask_tensor, pad), dim=1)

    return mask_tensor
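
A self-contained sketch of the masking trick above using plain torch (device handling omitted); num_selected = [2, 3] and desired_len = 5 are made-up values for illustration:

import torch

num_selected = [2, 3]
desired_len = 5

seq_lens = torch.tensor(num_selected).unsqueeze(-1)            # (2, 1)
max_len = int(seq_lens.max())
range_tensor = torch.arange(max_len).unsqueeze(0)              # (1, 3)
range_tensor = range_tensor.expand(seq_lens.size(0), max_len)  # (2, 3)

mask = (range_tensor < seq_lens).float()
pad = torch.zeros(mask.shape[0], desired_len - mask.shape[1])
mask = torch.cat((mask, pad), dim=1)
# mask == [[1., 1., 0., 0., 0.],
#          [1., 1., 1., 0., 0.]]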
Code example #5
    def select_actions(self, obs, inferred_mdp):

        # Repeat the obs as BCQ does; candidate_size indicates
        # how many candidate actions we generate per observation.
        obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
        with torch.no_grad():
            inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
            z = from_numpy(
                np.random.normal(0, 1, size=(obs.size(0),
                                             self.vae_latent_dim))).clamp(
                                                 -0.5, 0.5).to(ptu.device)
            candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
            perturbed_actions = self.perturbation_generator.get_perturbed_actions(
                obs, candidate_actions, inferred_mdp)
            qv = self.Qs(obs, perturbed_actions, inferred_mdp)
            ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
Code example #6
def calculate_rewards(next_obs, goals):
    pos = next_obs[:,:2][None]
    pos = pos.expand(len(goals), pos.shape[1], pos.shape[2])

    goals_np = np.array(goals)
    goals_pt = ptu.from_numpy(goals_np)
    goals_pt = goals_pt.unsqueeze(1)
    reward = torch.exp(-torch.norm(pos - goals_pt, dim=-1))
    return reward
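
A standalone sketch of the broadcasting in calculate_rewards, with plain torch in place of ptu.from_numpy; the observation and goal shapes are illustrative assumptions:

import numpy as np
import torch

next_obs = torch.zeros(7, 10)                   # batch of 7 observations
goals = [np.array([1.0, 0.0]), np.array([0.0, 2.0])]

pos = next_obs[:, :2][None]                                          # (1, 7, 2)
pos = pos.expand(len(goals), pos.shape[1], pos.shape[2])             # (2, 7, 2)
goals_pt = torch.from_numpy(np.array(goals)).float().unsqueeze(1)    # (2, 1, 2)

reward = torch.exp(-torch.norm(pos - goals_pt, dim=-1))              # (2, 7)
# Each row holds exp(-distance) from every position to one of the goals.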
Code example #7
    def async_evaluate(self, goal):
        self.env.set_goal(goal)
        self.policy.context_encoder.clear_z()

        avg_reward = 0.
        avg_achieved = []
        final_achieved = []

        raw_context = deque()
        for i in range(self.num_evals):
            # Sample MDP identity
            self.policy.context_encoder.sample_z()
            inferred_mdp = self.policy.context_encoder.z

            obs = self.env.reset()
            done = False
            path_length = 0

            while not done and path_length < self.max_path_length:
                action = self.select_actions(np.array(obs), inferred_mdp)
                next_obs, reward, done, env_info = self.env.step(action)
                avg_achieved.append(env_info['achieved'])

                new_context = np.concatenate([
                    obs.reshape(1, -1),
                    action.reshape(1, -1),
                    next_obs.reshape(1, -1),
                    np.array(reward).reshape(1, -1)
                ],
                                             axis=1)

                raw_context.append(new_context)
                obs = next_obs.copy()
                # Episodes 0 and 1 are only used to accumulate context,
                # so exclude their rewards from the running average.
                if i > 1:
                    avg_reward += reward
                path_length += 1

            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            self.policy.context_encoder.infer_posterior(context)

            if i > 1:
                final_achieved.append(env_info['achieved'])

        avg_reward /= (self.num_evals - 2)
        if np.isscalar(env_info['achieved']):
            avg_achieved = np.mean(avg_achieved)
            final_achieved = np.mean(final_achieved)

        else:
            avg_achieved = np.stack(avg_achieved)
            avg_achieved = np.mean(avg_achieved, axis=0)

            final_achieved = np.stack(final_achieved)
            final_achieved = np.mean(final_achieved, axis=0)
        print(avg_reward)
        return avg_reward, (final_achieved.tolist(), self.env._goal.tolist())
Code example #8
    def async_evaluate_test(self, goal):
        self.env.set_goal(goal)
        self.context_encoder.clear_z()

        avg_reward_list = []
        online_achieved_list = []

        raw_context = deque()
        for _ in range(self.num_evals):
            # Sample MDP identity
            self.context_encoder.sample_z()
            inferred_mdp = self.context_encoder.z

            obs = self.env.reset()
            done = False
            path_length = 0
            avg_reward = 0.
            online_achieved = []
            while not done and path_length < self.max_path_length:
                action = self.select_actions(np.array(obs), inferred_mdp)
                next_obs, reward, done, env_info = self.env.step(action)
                achieved = env_info['achieved']
                online_achieved.append(np.arctan(achieved[1] / achieved[0]))
                if self.use_next_obs_in_context:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        next_obs.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ],
                                                 axis=1)
                else:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ],
                                                 axis=1)
                raw_context.append(new_context)
                obs = next_obs.copy()
                avg_reward += reward
                path_length += 1

            avg_reward_list.append(avg_reward)
            online_achieved = np.array(online_achieved)
            online_achieved_list.append([
                online_achieved.mean(),
                online_achieved.std(), self.env._goal
            ])

            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            self.context_encoder.infer_posterior(context)

        return online_achieved_list
Code example #9
    def check_q_funct_estimate(self, paths):

        s0 = np.stack([path["observations"][0] for path in paths])
        s0 = ptu.from_numpy(s0)

        a0 = np.stack([path["actions"][0] for path in paths])
        a0 = ptu.from_numpy(a0)

        inferred_mdps = torch.repeat_interleave(self.trainer._inferred_mdp,
                                                s0.shape[0],
                                                dim=0)

        q_values = torch.min(
            self.trainer.qf1(s0, a0, inferred_mdps),
            self.trainer.qf2(s0, a0, inferred_mdps),
        )

        q_values = ptu.get_numpy(q_values)

        discount_returns = []
        for path in paths:

            discount_cof = self.trainer.discount**np.arange(
                len(path["rewards"]))

            discount_return = np.sum(path["rewards"].flatten() * discount_cof)
            discount_returns.append(discount_return)

        q_values_mean = np.mean(q_values)
        q_values_std = np.std(q_values)

        discount_returns_mean = np.mean(discount_returns)
        discount_returns_std = np.std(discount_returns)

        return dict(
            q_values_mean=q_values_mean,
            q_values_std=q_values_std,
            discount_returns_mean=discount_returns_mean,
            discount_returns_std=discount_returns_std,
        )
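
A tiny numpy check of the discounted-return computation used above, with an illustrative discount factor and reward sequence:

import numpy as np

rewards = np.array([1.0, 1.0, 1.0])
discount = 0.9
discount_cof = discount ** np.arange(len(rewards))    # [1.0, 0.9, 0.81]
discounted_return = np.sum(rewards * discount_cof)    # 2.71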
Code example #10
    def select_actions(self, obs, raw_context):

        # Infer the MDP identity from the raw context, then
        # let the underlying policy select the action.
        if len(raw_context) == 0:
            # Before any context has been collected, the inferred_mdp is set to a zero vector.
            inferred_mdp = ptu.zeros(
                (1, self.policy.mlp_encoder.encoder_latent_dim))
        else:
            # Construct the context from raw context
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.policy.mlp_encoder(context)

        # obs = torch.cat([obs, inferred_mdp], dim=1)
        action = self.policy.select_action(obs, get_numpy(inferred_mdp))

        return action
Code example #11
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(
            _elem_or_tuple_to_variable(e) for e in elem_or_tuple
        )
    return ptu.from_numpy(elem_or_tuple).float()
Code example #12
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    else:
        return np_array_or_other
Code example #13
def get_optimistic_exploration_action(ob_np,
                                      policy=None,
                                      qfs=None,
                                      hyper_params=None):

    #assert ob_np.ndim == 1

    beta_UB = hyper_params['beta_UB']
    delta = hyper_params['delta']

    #ob = ptu.from_numpy(ob_np)
    ob = {k: ptu.from_numpy(v[None]) for k, v in ob_np.items()}

    # Ensure that ob is not batched
    # assert len(list(ob.shape)) == 1

    _, pre_tanh_mu_T, _, _, std, _ = policy(ob)
    #print(pre_tanh_mu_T.shape)
    pre_tanh_mu_T = pre_tanh_mu_T[0]
    std = std[0]

    # Ensure that pretanh_mu_T is not batched
    assert len(list(pre_tanh_mu_T.shape)) == 1, pre_tanh_mu_T
    assert len(list(std.shape)) == 1

    pre_tanh_mu_T.requires_grad_()
    tanh_mu_T = torch.tanh(pre_tanh_mu_T)

    # Get the upper bound of the Q estimate
    args = [ob, torch.unsqueeze(tanh_mu_T, dim=0)
            ]  #list(torch.unsqueeze(i, dim=0) for i in (ob, tanh_mu_T))
    Q1 = qfs[0](*args)
    Q2 = qfs[1](*args)

    mu_Q = (Q1 + Q2) / 2.0

    sigma_Q = torch.abs(Q1 - Q2) / 2.0

    Q_UB = mu_Q + beta_UB * sigma_Q

    # Obtain the gradient of Q_UB w.r.t. a,
    # with a evaluated at mu_T
    grad = torch.autograd.grad(Q_UB, pre_tanh_mu_T)
    grad = grad[0]

    assert grad is not None
    assert pre_tanh_mu_T.shape == grad.shape

    # Obtain Sigma_T (the covariance of the normal distribution)
    Sigma_T = torch.pow(std, 2)

    # The divisor is (g^T Sigma g) ** 0.5
    # Sigma is diagonal, so this works out to be
    # ( sum_{i=1}^k (g^(i))^2 (sigma^(i))^2 ) ** 0.5
    denom = torch.sqrt(torch.sum(torch.mul(torch.pow(grad, 2),
                                           Sigma_T))) + 10e-6

    # Obtain the change in mu
    mu_C = math.sqrt(2.0 * delta) * torch.mul(Sigma_T, grad) / denom

    assert mu_C.shape == pre_tanh_mu_T.shape

    mu_E = pre_tanh_mu_T + mu_C

    # Construct the tanh normal distribution and sample the exploratory action from it
    assert mu_E.shape == std.shape

    dist = TanhNormal(mu_E, std)

    ac = dist.sample()

    ac_np = ptu.get_numpy(ac)

    # mu_T_np = ptu.get_numpy(pre_tanh_mu_T)
    # mu_C_np = ptu.get_numpy(mu_C)
    # mu_E_np = ptu.get_numpy(mu_E)
    # dict(
    #     mu_T=mu_T_np,
    #     mu_C=mu_C_np,
    #     mu_E=mu_E_np
    # )

    # Return an empty dict, and do not log
    # stats for now
    return ac_np, {}
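
A toy numeric check of the mean shift computed above; the gradient, standard deviation, and delta values are made up purely to illustrate the formula mu_E = mu_T + sqrt(2 * delta) * Sigma_T * grad / (grad^T Sigma_T grad)^0.5:

import torch

delta = 0.1
grad = torch.tensor([0.5, -1.0])    # stand-in for the gradient of Q_UB w.r.t. pre_tanh_mu_T
std = torch.tensor([0.2, 0.3])

Sigma_T = std.pow(2)
denom = torch.sqrt(torch.sum(grad.pow(2) * Sigma_T)) + 10e-6
mu_C = (2.0 * delta) ** 0.5 * Sigma_T * grad / denom
# The shift is along Sigma_T * grad, i.e. towards actions with a higher
# Q upper bound, and its magnitude is controlled by delta.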
Code example #14
    def train(self, batch, batch_idxes):
        """
        Unpack data from the batch
        """
        obs = batch['obs']
        actions = batch['actions']
        contexts = batch['contexts']

        num_tasks = batch_idxes.shape[0]

        gt.stamp('unpack_data_from_the_batch', unique=False)

        # Get the in_mdp_batch_size
        obs_dim = obs.shape[1]
        action_dim = actions.shape[1]
        in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
        num_trans_context = contexts.shape[0] // batch_idxes.shape[0]
        """
        Relabel the context batches for each training task
        """
        with torch.no_grad():
            contexts_obs_actions = contexts[:, :obs_dim + action_dim]

            manual_batched_rewards = self.reward_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_rewards = manual_batched_rewards.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0])

            gt.stamp('reward_ensemble_forward', unique=False)

            manual_batched_next_obs = self.transition_ensemble_predictor.forward_mul_device(
                contexts_obs_actions)

            relabeled_next_obs = manual_batched_next_obs.reshape(
                num_tasks, self.num_network_ensemble, contexts.shape[0],
                obs_dim)

            gt.stamp('transition_ensemble_forward', unique=False)

            relabeled_rewards_mean = torch.mean(relabeled_rewards,
                                                dim=1).squeeze()
            relabeled_rewards_std = torch.std(relabeled_rewards,
                                              dim=1).squeeze()

            relabeled_next_obs_mean = torch.mean(relabeled_next_obs, dim=1)

            relabeled_next_obs_std = torch.std(relabeled_next_obs, dim=1)
            relabeled_next_obs_std = torch.mean(relabeled_next_obs_std, dim=-1)

            # For each task's own transitions in the batch, replace the predicted
            # rewards and next observations with the ground-truth values
            for i in range(num_tasks):

                relabeled_rewards_mean[i, i*num_trans_context: (i+1)*num_trans_context] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, -1]

                relabeled_next_obs_mean[i, i*num_trans_context: (i+1)*num_trans_context, :] \
                    = contexts[i*num_trans_context: (i+1)*num_trans_context, obs_dim + action_dim : -1]

                if self.is_combine:
                    # Set the std to a value larger than self.std_threshold so that
                    # these transitions are initially filtered out when producing
                    # the mask, which simplifies the sampling below.
                    relabeled_rewards_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.reward_std_threshold + 1.0
                    relabeled_next_obs_std[
                        i, i * num_trans_context:(i + 1) *
                        num_trans_context] = self.next_obs_std_threshold + 1.0
                else:
                    relabeled_rewards_std[i, i * num_trans_context:(i + 1) *
                                          num_trans_context] = 0.0
                    relabeled_next_obs_std[i, i * num_trans_context:(i + 1) *
                                           num_trans_context] = 0.0

            mask_reward = relabeled_rewards_std < self.reward_std_threshold
            mask_reward = mask_reward.type(torch.float)

            mask_next_obs = relabeled_next_obs_std < self.next_obs_std_threshold
            mask_next_obs = mask_next_obs.type(torch.float)

            mask = mask_reward * mask_next_obs
            mask = mask.type(torch.uint8)

            mask_from_the_other_tasks = mask.type(torch.uint8).clone()

            num_context_candidate_each_task = torch.sum(mask, dim=1)

            mask_list = []

            for i in range(num_tasks):

                assert mask[i].dim() == 1

                mask_nonzero = torch.nonzero(mask[i])
                mask_nonzero = mask_nonzero.flatten()

                mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)

                assert num_context_candidate_each_task[i].item(
                ) == mask_nonzero.shape[0]

                np_ind = np.random.choice(mask_nonzero.shape[0],
                                          num_trans_context,
                                          replace=False)

                ind = mask_nonzero[np_ind]

                mask_i[ind] = 1

                if self.is_combine:
                    # Combine the additional relabeled context transitions with
                    # the original context transitions that have ground-truth rewards
                    mask_i[i * num_trans_context:(i + 1) *
                           num_trans_context] = 1
                    assert torch.sum(mask_i).item() == 2 * num_trans_context
                else:
                    assert torch.sum(mask_i).item() == num_trans_context

                mask_list.append(mask_i)

            mask = torch.cat(mask_list)
            mask = mask.type(torch.uint8)

            repeated_contexts = contexts.repeat(num_tasks, 1)
            context_without_next_obs_rewards = repeated_contexts[:, :obs_dim +
                                                                 action_dim]

            assert context_without_next_obs_rewards.shape[
                0] == relabeled_rewards_mean.reshape(-1, 1).shape[0]
            assert context_without_next_obs_rewards.shape[
                0] == relabeled_next_obs_mean.reshape(-1, obs_dim).shape[0]

            context_without_next_obs_rewards = context_without_next_obs_rewards[
                mask]

            context_next_obs = relabeled_next_obs_mean.reshape(-1,
                                                               obs_dim)[mask]

            context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]

            fast_contexts = torch.cat((context_without_next_obs_rewards,
                                       context_next_obs, context_rewards),
                                      dim=1)

            fast_contexts = fast_contexts.reshape(num_tasks, -1,
                                                  contexts.shape[-1])

        gt.stamp('relabel_context_transitions', unique=False)
        """
        Obtain the targets
        """
        with torch.no_grad():
            # Sample z for each state
            z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

            # critic_weights is a list with one entry per device;
            # each critic_weights[i] is a list with one entry per layer;
            # each critic_weights[i][j] is a tensor of shape
            # (num_tasks // device_count, layer_input_size, layer_output_size).
            # The same structure applies to the other weights and biases.
            critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies

            # CRITIC
            obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size,
                                           -1)

            obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

            target_q = batch_bcq(obs_acs_reshaped, critic_weights,
                                 critic_biases)
            target_q = target_q.reshape(-1)

            # VAE
            z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
            obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

            tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
            tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
            target_candidates = tc.reshape(-1, tc.shape[-1])

            # PERTURBATION
            tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                    in_mdp_batch_size, -1)

            obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

            tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
            tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
            target_perturbations = tp.reshape(-1, tp.shape[-1])

        gt.stamp('get_the_targets', unique=False)
        """
        Compute the triplet loss
        """
        # ----------------------------------Vectorized-------------------------------------------
        self.context_encoder_optimizer.zero_grad()

        anchors = []
        positives = []
        negatives = []

        num_selected_list = []

        # Pairs of tasks (i, j) where no transition from task j
        # is selected by the ensemble of task i
        exclude_tasks = []

        exclude_task_masks = []

        for i in range(num_tasks):
            # Compute the triplet loss for task i

            for j in range(num_tasks):

                if j != i:

                    # mask_for_task_j: (num_trans_context, )
                    # mask_from_the_other_tasks: (num_tasks, num_tasks * num_trans_context)
                    mask_for_task_j = mask_from_the_other_tasks[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    num_selected = int(torch.sum(mask_for_task_j).item())

                    if num_selected == 0:
                        exclude_tasks.append((i, j))
                        exclude_task_masks.append(0)
                    else:
                        exclude_task_masks.append(1)

                    # context_trans_all: (num_trans_context, context_dim)
                    context_trans_all = contexts[j *
                                                 num_trans_context:(j + 1) *
                                                 num_trans_context]
                    # context_trans_selected: (num_selected, context_dim)
                    context_trans_selected = context_trans_all[mask_for_task_j]

                    # relabel_reward_all: (num_trans_context, )
                    relabel_reward_all = relabeled_rewards_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_reward_selected: (num_selected, )
                    relabel_reward_selected = relabel_reward_all[
                        mask_for_task_j]
                    # relabel_reward_selected: (num_selected, 1)
                    relabel_reward_selected = relabel_reward_selected.reshape(
                        -1, 1)

                    # relabel_next_obs_all: (num_trans_context, obs_dim)
                    relabel_next_obs_all = relabeled_next_obs_mean[
                        i, j * num_trans_context:(j + 1) * num_trans_context]
                    # relabel_next_obs_selected: (num_selected, obs_dim)
                    relabel_next_obs_selected = relabel_next_obs_all[
                        mask_for_task_j]

                    # context_trans_selected_relabel: (num_selected, context_dim)
                    context_trans_selected_relabel = torch.cat([
                        context_trans_selected[:, :obs_dim + action_dim],
                        relabel_next_obs_selected, relabel_reward_selected
                    ],
                                                               dim=1)

                    # Sample num_selected of task i's own context transitions (c_i)
                    ind = np.random.choice(num_trans_context,
                                           num_selected,
                                           replace=False)

                    # Next 2 lines used for comparing to sequential version
                    # ind = ind_list[count]
                    # count += 1

                    # context_trans_task_i: (num_trans_context, context_dim)
                    context_trans_task_i = contexts[i *
                                                    num_trans_context:(i + 1) *
                                                    num_trans_context]
                    # context_trans_task_i_sampled: (num_selected, context_dim)
                    context_trans_task_i_sampled = context_trans_task_i[ind]

                    # Pad the contexts with a zero tensor so every block has num_trans_context rows
                    num_to_pad = num_trans_context - num_selected
                    # pad_zero_tensor: (num_to_pad, context_dim)
                    pad_zero_tensor = ptu.zeros(
                        (num_to_pad, context_trans_selected.shape[1]))

                    num_selected_list.append(num_selected)

                    # Dim: (1, num_trans_context, context_dim)
                    context_trans_selected = torch.cat(
                        [context_trans_selected, pad_zero_tensor], dim=0)
                    context_trans_selected_relabel = torch.cat(
                        [context_trans_selected_relabel, pad_zero_tensor],
                        dim=0)
                    context_trans_task_i_sampled = torch.cat(
                        [context_trans_task_i_sampled, pad_zero_tensor], dim=0)

                    anchors.append(context_trans_selected_relabel[None])
                    positives.append(context_trans_task_i_sampled[None])
                    negatives.append(context_trans_selected[None])

        # Dim: (num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        anchors = torch.cat(anchors, dim=0)
        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)

        # input_contexts: (3 * num_tasks * (num_tasks - 1), num_trans_context, context_dim)
        input_contexts = torch.cat([anchors, positives, negatives], dim=0)

        # num_selected_pt: (num_tasks * (num_tasks - 1), )
        num_selected_pt = torch.from_numpy(np.array(num_selected_list))

        # num_selected_repeat: (3 * num_tasks * (num_tasks - 1), )
        num_selected_repeat = num_selected_pt.repeat(3)

        # z_means_vec, z_vars_vec: (3 * num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec, z_vars_vec = self.context_encoder.infer_posterior_with_mean_var(
            input_contexts, num_trans_context, num_selected_repeat)

        # z_means_vec, z_vars_vec: (3, num_tasks * (num_tasks - 1), latent_dim)
        z_means_vec = z_means_vec.reshape(3, anchors.shape[0], -1)
        z_vars_vec = z_vars_vec.reshape(3, anchors.shape[0], -1)

        # Dim: (num_tasks * (num_tasks - 1), latent_dim)
        z_means_anchors, z_vars_anchors = z_means_vec[0], z_vars_vec[0]
        z_means_positives, z_vars_positives = z_means_vec[1], z_vars_vec[1]
        z_means_negatives, z_vars_negatives = z_means_vec[2], z_vars_vec[2]

        with_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                 z_vars_anchors,
                                                 z_means_positives,
                                                 z_vars_positives)
        across_task_dist = compute_kl_div_diagonal(z_means_anchors,
                                                   z_vars_anchors,
                                                   z_means_negatives,
                                                   z_vars_negatives)

        # Mask out the triplets for which num_selected is 0
        exclude_task_masks = ptu.from_numpy(np.array(exclude_task_masks))

        with_task_dist = with_task_dist * exclude_task_masks
        across_task_dist = across_task_dist * exclude_task_masks

        unscaled_triplet_loss_vec = F.relu(with_task_dist - across_task_dist +
                                           self.triplet_margin)
        unscaled_triplet_loss_vec = torch.mean(unscaled_triplet_loss_vec)

        # Assert that the triplet loss is not NaN
        assert not torch.isnan(unscaled_triplet_loss_vec).any()

        gt.stamp('get_triplet_loss', unique=False)

        unscaled_triplet_loss_vec.backward()

        check_grad_nan_nets(self.networks,
                            f'triplet: {unscaled_triplet_loss_vec}')

        gt.stamp('get_triplet_loss_gradient', unique=False)
        """
        Infer the context variables
        """
        # inferred_mdps = self.context_encoder(new_contexts)
        inferred_mdps = self.context_encoder(fast_contexts)
        inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                                in_mdp_batch_size,
                                                dim=0)

        gt.stamp('infer_mdps', unique=False)
        """
        Obtain the KL loss
        """

        kl_div = self.context_encoder.compute_kl_div()

        kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)

        kl_loss = torch.sum(kl_loss_each_task)

        gt.stamp('get_kl_loss', unique=False)
        """
        Obtain the Q-function loss
        """
        self.Qs_optimizer.zero_grad()

        pred_q = self.Qs(obs, actions, inferred_mdps)
        pred_q = torch.squeeze(pred_q)

        qf_loss_each_task = (pred_q - target_q)**2
        qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
        qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)

        qf_loss = torch.mean(qf_loss_each_task)

        gt.stamp('get_qf_loss', unique=False)

        (kl_loss + qf_loss).backward()

        check_grad_nan_nets(self.networks, 'kl q')

        gt.stamp('get_kl_qf_gradient', unique=False)

        self.Qs_optimizer.step()
        self.context_encoder_optimizer.step()

        gt.stamp('update_Qs_encoder', unique=False)
        """
        Obtain the candidate action and perturbation loss
        """

        self.vae_decoder_optimizer.zero_grad()
        self.perturbation_generator_optimizer.zero_grad()

        pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
        pred_perturbations = self.perturbation_generator(
            obs, target_candidates, inferred_mdps.detach())

        candidate_loss_each_task = (pred_candidates - target_candidates)**2

        # averaging over action dimension
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
        candidate_loss_each_task = candidate_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over the transitions within each task
        candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)

        candidate_loss = torch.mean(candidate_loss_each_task)

        perturbation_loss_each_task = (pred_perturbations -
                                       target_perturbations)**2

        # average over action dimension
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)
        perturbation_loss_each_task = perturbation_loss_each_task.reshape(
            num_tasks, in_mdp_batch_size)

        # average over the transitions within each task
        perturbation_loss_each_task = torch.mean(perturbation_loss_each_task,
                                                 dim=1)

        perturbation_loss = torch.mean(perturbation_loss_each_task)

        gt.stamp('get_candidate_and_perturbation_loss', unique=False)

        (candidate_loss + perturbation_loss).backward()

        check_grad_nan_nets(self.networks, 'perb')

        gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

        self.vae_decoder_optimizer.step()
        self.perturbation_generator_optimizer.step()

        for net in self.networks:

            for name, m in net.named_parameters():

                if (m != m).any():

                    print(net, name)
                    print(num_selected_list)
                    print(min(num_selected_list))

                    exit()

        gt.stamp('update_vae_perturbation', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
                qf_loss_each_task)

            self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
            self.eval_statistics['triplet_loss'] = np.mean(
                ptu.get_numpy(unscaled_triplet_loss_vec))
            self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
                kl_loss_each_task)

            self.eval_statistics['candidate_loss'] = np.mean(
                ptu.get_numpy(candidate_loss))
            self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
                candidate_loss_each_task)

            self.eval_statistics['perturbation_loss'] = np.mean(
                ptu.get_numpy(perturbation_loss))
            self.eval_statistics[
                'perturbation_loss_each_task'] = ptu.get_numpy(
                    perturbation_loss_each_task)

            self.eval_statistics[
                'num_context_candidate_each_task'] = num_context_candidate_each_task
Code example #15
    def set_param_values_np(self, param_values):
        torch_dict = OrderedDict()
        for key, tensor in param_values.items():
            torch_dict[key] = ptu.from_numpy(tensor)
        self.load_state_dict(torch_dict)
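
A hypothetical round-trip sketch for set_param_values_np, with a small linear layer standing in for the actual network and plain torch.from_numpy in place of ptu.from_numpy:

from collections import OrderedDict
import torch

net = torch.nn.Linear(3, 2)
# Export the parameters as numpy arrays, then load them back
param_values = {k: v.detach().numpy() for k, v in net.state_dict().items()}
torch_dict = OrderedDict(
    (key, torch.from_numpy(tensor)) for key, tensor in param_values.items())
net.load_state_dict(torch_dict)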
Code example #16
    def collect_new_paths(
        self,
        max_path_length,
        num_steps,
        discard_incomplete_paths,
    ):

        self.context_encoder.clear_z()

        paths = []

        num_steps_collected = 0
        raw_context = deque()
        while num_steps_collected < num_steps:
            max_path_length_this_loop = min(  # Do not go over num_steps
                max_path_length,
                num_steps - num_steps_collected,
            )

            # Sample MDP identity
            self.context_encoder.sample_z()
            inferred_mdp = self.context_encoder.z

            path_length = 0
            observations = []
            actions = []
            rewards = []
            terminals = []
            agent_infos = []
            env_infos = []

            obs = self.env.reset()
            done = False

            while not done and path_length < max_path_length_this_loop:
                action = self.select_actions(np.array(obs), inferred_mdp)

                next_obs, reward, done, _ = self.env.step(action)

                if self.use_next_obs_in_context:
                    new_context = np.concatenate([
                        obs.reshape(1, -1),
                        action.reshape(1, -1),
                        next_obs.reshape(1, -1),
                        np.array(reward).reshape(1, -1)
                    ],
                                                 axis=1)
                else:
                    assert False

                observations.append(obs)
                rewards.append(reward)
                terminals.append(done)
                actions.append(action)
                agent_infos.append(-1)
                env_infos.append(-1)
                path_length += 1

                raw_context.append(new_context)
                obs = next_obs.copy()

            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            self.context_encoder.infer_posterior(context)

            actions = np.array(actions)
            if len(actions.shape) == 1:
                actions = np.expand_dims(actions, 1)

            observations = np.array(observations)
            if len(observations.shape) == 1:
                observations = np.expand_dims(observations, 1)
                next_obs = np.array([next_obs])
            next_observations = np.vstack(
                (observations[1:, :], np.expand_dims(next_obs, 0)))

            path = dict(
                observations=observations,
                actions=actions,
                rewards=np.array(rewards).reshape(-1, 1),
                next_observations=next_observations,
                terminals=np.array(terminals).reshape(-1, 1),
                agent_infos=agent_infos,
                env_infos=env_infos,
            )

            path_len = len(path['actions'])
            if (
                    # incomplete path
                    path_len != max_path_length and

                    # that did not end in a terminal state
                    not path['terminals'][-1] and

                    # and we should discard such path
                    discard_incomplete_paths):
                break
            num_steps_collected += path_len
            paths.append(path)
        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)

        return paths, inferred_mdp