Example #1
File: demos.py Project: qxcv/mtil
    def __len__(self):
        first_val = next(iter(self.tensor_dict.values()))
        if isinstance(first_val, torch.Tensor):
            return first_val.size(0)
        elif isinstance(first_val, tuple):
            sizes = set()
            tree_map(lambda t: sizes.add(t.size(0)), first_val)
            assert len(sizes) == 1, sizes
            return next(iter(sizes))
        raise TypeError(f"can't handle value of type '{type(first_val)}'")
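All of these snippets lean on a tree_map helper that applies a function to every leaf of a nested container (dicts, tuples, namedtuples/namedarraytuples). Its implementation is not shown in these excerpts; the following is a minimal sketch of the assumed behaviour, inferred from how it is called above and below (including the multi-structure form used in Example #5):

def tree_map(f, *structures):
    # Minimal sketch (an assumption, not the project's implementation): walk
    # one or more parallel nested structures, apply `f` to matching leaves,
    # and rebuild the structure of the first argument.
    first = structures[0]
    if isinstance(first, dict):
        return {k: tree_map(f, *(s[k] for s in structures)) for k in first}
    if isinstance(first, tuple):
        mapped = [tree_map(f, *children) for children in zip(*structures)]
        # namedtuples/namedarraytuples rebuild via _make; plain tuples directly
        return first._make(mapped) if hasattr(first, '_make') else tuple(mapped)
    # anything else (e.g. a torch.Tensor or np.ndarray) is treated as a leaf
    return f(*structures)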
Example #2
File: demos.py Project: qxcv/mtil
    def __init__(self, tensor_dict):
        # need at least one tensor
        assert len(tensor_dict) > 0

        # make sure batch size is uniform
        batch_sizes = set()
        tree_map(lambda t: batch_sizes.add(t.size(0)), tensor_dict)
        assert len(batch_sizes) == 1, batch_sizes

        self.tensor_dict = tensor_dict
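Examples #1 and #2 appear to be methods of the DictTensorDataset class that Example #5 constructs. A hedged usage sketch (tensor shapes here are illustrative, not taken from the project):

import torch

dataset = DictTensorDataset({
    'obs': torch.zeros((100, 3, 96, 96), dtype=torch.uint8),
    'acts': torch.zeros((100, ), dtype=torch.int64),
})
assert len(dataset) == 100  # batch size is read off the leading dimension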
Example #3
File: demos.py Project: qxcv/mtil
def insert_task_source_ids(obs_seq, task_id, variant_id, source_id):
    # figuring out the length of the sequence requires a bit of judo because
    # the observations are stored as namedarraytuples
    lens = []
    tree_map(lambda t: lens.append(len(t)), obs_seq.obs)
    nobs, = set(lens)

    task_id_array = np.full((nobs, ), task_id, dtype='int64')
    variant_id_array = np.full((nobs, ), variant_id, dtype='int64')
    source_id_array = np.full((nobs, ), source_id, dtype='int64')
    eio_arr = EnvIDObsArray(observation=obs_seq.obs,
                            task_id=task_id_array,
                            variant_id=variant_id_array,
                            source_id=source_id_array)
    return obs_seq._replace(obs=eio_arr)
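EnvIDObsArray is consumed like an rlpyt namedarraytuple (the project builds on rlpyt, judging by torchify_buffer and infer_leading_dims elsewhere in these examples). A plausible declaration, offered as an assumption rather than the project's actual code:

from rlpyt.utils.collections import namedarraytuple

# hypothetical declaration matching the fields used above
EnvIDObsArray = namedarraytuple(
    'EnvIDObsArray', ['observation', 'task_id', 'variant_id', 'source_id'])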
Example #4
File: mtgail.py Project: qxcv/mtil
def hacky_interpolate(structure, eps, sub_batch_size):
    def inner_map(tens):
        tens_l = tens[:sub_batch_size]
        tens_r = tens[sub_batch_size:]
        if tens.ndim == 4 and tens.dtype in (torch.uint8, torch.float):
            # Float interpolation is easy. Byte images are interpolated by
            # (implicitly) converting to float, then converting back to byte
            # afterwards.
            eps_exp = eps[:, None, None, None]
            interp = eps_exp * tens_l + (1 - eps_exp) * tens_r
            if tens.dtype == torch.uint8:
                interp = interp.to(torch.uint8)
        elif tens.ndim == 1 and not torch.is_floating_point(tens):
            # "interpolate" some integers (assume action label or something, so
            # the output must be discrete)

            # FIXME(sam): what we really want to do is interpolate the
            # one-hot representation or something, but instead of that I'm
            # just going to do this. (whatever, Lipschitzness of discrete
            # functions doesn't make sense anyway.)

            tens_elems = []
            for eps_i, elem_l, elem_r in zip(eps, tens_l, tens_r):
                tens_elems.append(elem_l if eps_i > 0.5 else elem_r)
            interp = torch.stack(tens_elems, dim=0)

        else:
            raise ValueError("cannot handle shape/type of tensor", tens)

        return interp

    new_structure = tree_map(inner_map, structure)

    return new_structure
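A hedged usage sketch with synthetic shapes (Example #7 calls this helper the same way when computing its gradient penalty):

import torch

B = 4  # illustrative sub-batch size
obs = torch.randint(0, 256, (2 * B, 3, 96, 96), dtype=torch.uint8)
acts = torch.randint(0, 5, (2 * B, ), dtype=torch.int64)
eps = torch.rand((B, ))                        # one mixing weight per pair
interp_obs = hacky_interpolate(obs, eps, B)    # blended images, back to uint8
interp_acts = hacky_interpolate(acts, eps, B)  # picks one label from each pair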
Example #5
File: demos.py Project: qxcv/mtil
def make_tensor_dict_dataset(demo_trajs_by_env, omit_noop=False):
    """Re-format multi-task trajectories into a Torch dataset of dicts."""

    cpu_dev = torch.device("cpu")

    # make big list of trajectories in a deterministic order
    all_obs = []
    all_acts = []
    for env_name in sorted(demo_trajs_by_env.keys()):
        demo_trajs = demo_trajs_by_env[env_name]
        n_samples = 0
        for traj in demo_trajs:
            # The observation trajectories are one elem longer than the act
            # trajectories because they include terminal obs. We lop that off
            # here.
            all_obs.append(
                tree_map(lambda t: torch.as_tensor(t, device=cpu_dev)[:-1],
                         traj.obs))
            all_acts.append(torch.as_tensor(traj.acts, device=cpu_dev))
            n_samples += len(traj.acts)

        # weight things inversely proportional to their frequency
        assert n_samples > 0, demo_trajs

    # join together trajectories into Torch dataset
    all_obs = tree_map(lambda *t: torch.cat(t), *all_obs)
    all_acts = torch.cat(all_acts)

    if omit_noop:
        # omit action 0 (helps avoid the "agent does nothing initially" problem
        # for MAGICAL)
        valid_inds = torch.squeeze(torch.nonzero(all_acts), 1)
        all_obs = tree_map(lambda t: t[valid_inds], all_obs)
        all_acts = all_acts[valid_inds]

    dataset = DictTensorDataset({
        'obs': all_obs,
        'acts': all_acts,
    })

    return dataset
Example #6
def do_training_mt(loader, model, opt, dev, aug_model, min_bc_module,
                   n_batches):
    # @torch.jit.script
    def do_loss_forward_back(obs_batch_obs, obs_batch_task, obs_batch_var,
                             obs_batch_source, acts_batch):
        # we don't use the value output
        logits_flat, _ = model(obs_batch_obs, task_ids=obs_batch_task)
        losses = F.cross_entropy(logits_flat,
                                 acts_batch.long(),
                                 reduction='none')
        if min_bc_module is not None:
            # weight using a model-dependent strategy
            mbc_weights = min_bc_module(obs_batch_task, obs_batch_var,
                                        obs_batch_source)
            assert mbc_weights.shape == losses.shape, (mbc_weights.shape,
                                                       losses.shape)
            loss = (losses * mbc_weights).sum()
        else:
            # no weighting
            loss = losses.mean()
        loss.backward()
        return losses.detach().cpu().numpy()

    # make sure we're in train mode
    model.train()

    # for logging
    loss_ewma = None
    losses = []
    per_task_losses = collections.defaultdict(lambda: [])
    progress = ProgBarCounter(n_batches)
    inf_batch_iter = repeat_dataset(loader)
    ctr_batch_iter = zip(range(1, n_batches + 1), inf_batch_iter)
    for batches_done, loader_batch in ctr_batch_iter:
        # (task_ids_batch, obs_batch, acts_batch)
        # copy to GPU
        obs_batch = loader_batch['obs']
        acts_batch = loader_batch['acts']
        # reminder: attributes are .observation, .task_id, .variant_id,
        # .source_id
        obs_batch = tree_map(lambda t: t.to(dev), obs_batch)
        acts_batch = acts_batch.to(dev)

        if aug_model is not None:
            # apply augmentations
            obs_batch = obs_batch._replace(
                observation=aug_model(obs_batch.observation))

        # compute loss & take opt step
        opt.zero_grad()
        batch_losses = do_loss_forward_back(obs_batch.observation,
                                            obs_batch.task_id,
                                            obs_batch.variant_id,
                                            obs_batch.source_id, acts_batch)
        opt.step()

        # for logging
        progress.update(batches_done)
        f_loss = np.mean(batch_losses)
        loss_ewma = f_loss if loss_ewma is None \
            else 0.9 * loss_ewma + 0.1 * f_loss
        losses.append(f_loss)

        # also track separately for each task
        tv_ids = torch.stack((obs_batch.task_id, obs_batch.variant_id), dim=1)
        np_tv_ids = tv_ids.cpu().numpy()
        assert len(np_tv_ids.shape) == 2 and np_tv_ids.shape[1] == 2, \
            np_tv_ids.shape
        for tv_id in np.unique(np_tv_ids, axis=0):
            tv_mask = np.all(np_tv_ids == tv_id[None], axis=-1)
            rel_losses = batch_losses[tv_mask]
            if len(rel_losses) > 0:
                task_id, variant_id = tv_id
                per_task_losses[(task_id, variant_id)] \
                    .append(np.mean(rel_losses))

    progress.stop()

    return loss_ewma, losses, per_task_losses
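repeat_dataset is not shown in these excerpts; given how it is consumed above (an endless stream of loader batches), a minimal sketch of the assumed behaviour would be:

def repeat_dataset(loader):
    # cycle over the data loader forever so the training loop can draw an
    # arbitrary number of batches (assumed behaviour, not the project's code)
    while True:
        yield from loader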
Example #7
File: mtgail.py Project: qxcv/mtil
    def optim_disc(self, itr, n_itr, samples):
        # TODO: refactor this method. Makes sense to split code that sets up
        # the replay buffer(s) from code that sets up each batch (and from the
        # code that computes losses and records metrics, etc.)

        # store new samples in replay buffer
        if self.pol_replay_buffer is None:
            self.pol_replay_buffer = DiscrimTrainBuffer(
                self.buffer_num_samples, samples)
        # keep ONLY the demo env samples
        sample_variant_ids = samples.env.observation.variant_id
        train_variant_mask = sample_variant_ids == 0
        # check that each batch index is "pure", in the sense that e.g. all
        # elements at index k are always for the same task ID
        assert (train_variant_mask[:1] == train_variant_mask).all(), \
            train_variant_mask
        filtered_samples = samples[:, train_variant_mask[0]]
        self.pol_replay_buffer.append_samples(filtered_samples)

        if self.xfer_adv_model is not None:
            # if we have an adversarial domain adaptation model for transfer
            # learning, then we also keep samples that *don't* come from the
            # train variant so we can use them for the transfer loss
            if self.xfer_replay_buffer is None:
                # second replay buffer for off-task samples
                self.xfer_replay_buffer = DiscrimTrainBuffer(
                    self.buffer_num_samples, samples)
            xfer_variant_mask = ~train_variant_mask
            assert torch.any(xfer_variant_mask), \
                "xfer_adv_weight>0 supplied, but no xfer variants in batch?"
            assert (xfer_variant_mask[:1] == xfer_variant_mask).all()
            filtered_samples_xfer = samples[:, xfer_variant_mask[0]]
            self.xfer_replay_buffer.append_samples(filtered_samples_xfer)

        if self.final_layer_only_mode:
            print("Snapping final layer to its optimal value")
            _optimal_final_layer(
                model=self.model,
                expert_batch_iter=self.expert_batch_iter,
                pol_replay_buffer=self.pol_replay_buffer,
                batch_size=self.batch_size,
                n_eval_batches=int(
                    np.ceil(self.final_layer_only_mode_n_samples /
                            self.batch_size)),
                aug_model=self.aug_model,
                device=self.dev)

        # switch to train mode before taking any steps
        self.model.train()
        info_dicts = []
        for _ in range(self.updates_per_itr):
            self.opt.zero_grad()

            expert_data = next(self.expert_batch_iter)
            expert_obs = expert_data['obs']
            expert_acts = expert_data['acts']
            # grep rlpyt source for "SamplesFromReplay" to see what fields
            # pol_replay_samples has
            sub_batch_size = self.batch_size // 2
            pol_replay_samples = self.pol_replay_buffer.sample_batch(
                sub_batch_size)
            pol_replay_samples = torchify_buffer(pol_replay_samples)
            novice_obs = pol_replay_samples.all_observation

            if self.xfer_adv_model:
                # add a bunch of domain transfer samples
                xfer_replay_samples = self.xfer_replay_buffer.sample_batch(
                    sub_batch_size)
                xfer_replay_samples = torchify_buffer(xfer_replay_samples)
                xfer_replay_obs = xfer_replay_samples.all_observation
                all_obs = tree_map(
                    lambda *args: torch.cat(args, 0).to(self.dev), expert_obs,
                    novice_obs, xfer_replay_obs)
                all_acts = torch.cat([expert_acts.to(torch.int64),
                                      pol_replay_samples.all_action,
                                      xfer_replay_samples.all_action],
                                     dim=0) \
                    .to(self.dev)
            else:
                all_obs = tree_map(
                    lambda *args: torch.cat(args, 0).to(self.dev), expert_obs,
                    novice_obs)
                all_acts = torch.cat([expert_acts.to(torch.int64),
                                      pol_replay_samples.all_action],
                                     dim=0) \
                    .to(self.dev)

            if self.aug_model is not None:
                # augmentations
                aug_frames = self.aug_model(all_obs.observation)
                all_obs = all_obs._replace(observation=aug_frames)

            make_ones = functools.partial(torch.ones,
                                          dtype=torch.float32,
                                          device=self.dev)
            make_zeros = functools.partial(torch.zeros,
                                           dtype=torch.float32,
                                           device=self.dev)
            is_real_label = torch.cat(
                [
                    # expert samples
                    make_ones(sub_batch_size),
                    # novice samples
                    make_zeros(sub_batch_size),
                ],
                dim=0)

            if self.xfer_adv_model:
                # apply domain transfer loss to the transfer samples, then
                # remove them
                logits_all, disc_feats_all = self.model(all_obs,
                                                        all_acts,
                                                        return_feats=True)
                # cut the expert samples out of the discriminator transfer
                # objective
                disc_feats_xfer = disc_feats_all[sub_batch_size:]
                xfer_labels = torch.cat(
                    [make_ones(sub_batch_size),
                     make_zeros(sub_batch_size)],
                    dim=0)
                xfer_loss, xfer_acc = self.xfer_adv_model(
                    disc_feats_xfer, xfer_labels)
                # cut the transfer env samples out of the logits
                logits = logits_all[:-sub_batch_size]
            else:
                if self.final_layer_only_mode:
                    # cutting out gradients for the main model evaluation
                    # hopefully makes this a bit less memory-intensive
                    with torch.no_grad():
                        logits = self.model(all_obs, all_acts)
                else:
                    logits = self.model(all_obs, all_acts)

            if self.use_wgan:
                # Now assign label +1 for novice, -1 for expert. This means we
                # are trying to push novice scores down, and expert scores up
                # (I don't know whether this matches the original paper).
                pm_labels = 1 - 2 * is_real_label
                wgan_loss = main_loss = torch.mean(pm_labels * logits)
            else:
                # GAIL discriminator *objective* is E_fake[log D(s,a)] +
                # E_expert[log(1-D(s,a))]. You actually want to maximise this;
                # in reality the "loss" to be minimised is -E_fake[log D(s,a)]
                # - E_expert[log(1-D(s,a))].
                #
                # binary_cross_entropy_with_logits computes -labels *
                # log(sigmoid(logits)) - (1 - labels) * log(1-sigmoid(logits))
                # (per PyTorch docs). In other words: -y * log D(s,a) - (1 - y)
                # * log(1 - D(s, a))
                #
                # Hence, GAIL is like logistic regression with label 1 for the
                # novice and 0 for the expert. This is kind of weird, because
                # you actually want to *minimise* the discriminator's output.
                # Indeed, in the actual implementation, they flip this & use 1
                # for the expert and 0 for the novice.

                # In light of all the above, I'm using the OPPOSITE convention
                # to the paper, but the same convention as the implementation.
                # To wit:
                #
                # - 1 is the *expert* label, and high = more expert-like
                # - 0 is the *novice* label, and low = less expert-like
                xent_loss = main_loss = F.binary_cross_entropy_with_logits(
                    logits, is_real_label, reduction='mean')

            loss = main_loss

            if self.gp_weight:
                eps = torch.rand((sub_batch_size, )).to(self.dev)
                preproc_obs, preproc_acts, _ = self.model.preproc_obs_acts(
                    all_obs, all_acts)
                interp_obs = hacky_interpolate(preproc_obs, eps,
                                               sub_batch_size)
                interp_acts = hacky_interpolate(preproc_acts, eps,
                                                sub_batch_size)
                interp_task_ids = hacky_interpolate(all_obs.task_id, eps,
                                                    sub_batch_size)

                # using the strategy from
                # https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py
                # (also we only do this w.r.t images, not scalar inputs)
                interp_obs_var = torch.autograd.Variable(interp_obs,
                                                         requires_grad=True)
                interp_logits, _, _ = self.model.forward_no_preproc(
                    interp_obs_var, interp_acts, interp_task_ids)
                start_grad = torch.ones((sub_batch_size, ), device=self.dev)
                # FIXME(sam): I suspect retain_graph and grad_outputs are not
                # actually required. Should dig deeper some time.
                grads_wrt_inputs, = torch.autograd.grad(
                    outputs=interp_logits,
                    inputs=[interp_obs_var],
                    grad_outputs=start_grad,
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)
                grads_wrt_inputs_flat = torch.flatten(grads_wrt_inputs,
                                                      start_dim=1)
                grad_norms = torch.sqrt(
                    torch.sum(grads_wrt_inputs_flat**2, dim=1))
                grad_penalty = torch.mean((grad_norms - 1.0)**2)
                loss = loss + self.gp_weight * grad_penalty

            if self.xfer_adv_model:
                if self.xfer_disc_anneal:
                    progress = min(1, max(0, itr / float(n_itr)))
                    xfer_adv_weight = progress * self.xfer_adv_weight
                else:
                    xfer_adv_weight = self.xfer_adv_weight
                loss = loss + xfer_adv_weight * xfer_loss

            if self.final_layer_only_mode:
                print("Skipping actual optimiser step")
            else:
                loss.backward()
                self.opt.step()

            if self.use_wgan or self.final_layer_only_mode:
                # center logits so that 0 = uncertain
                stat_logits = logits - logits.mean()
            else:
                # for vanilla GAN, it's fine not to
                stat_logits = logits

            # for logging; we'll average these later
            info_dict = _compute_gail_stats(stat_logits, is_real_label,
                                            self.use_wgan)
            info_dict['discLoss'] = loss.item()
            if self.use_wgan:
                info_dict['discXentLoss'] = float('nan')
                info_dict['discWGANLoss'] = wgan_loss.item()
            else:
                info_dict['discXentLoss'] = xent_loss.item()
                info_dict['discWGANLoss'] = float('nan')
            if self.gp_weight:
                info_dict['discGPGradNorm'] = grad_norms.mean().item()
                info_dict['discGPLoss'] = grad_penalty.item()
            else:
                info_dict['discGPGradNorm'] = float('nan')
                info_dict['discGPLoss'] = float('nan')
            if self.xfer_adv_model:
                info_dict['xferLoss'] = xfer_loss.item()
                info_dict['xferAcc'] = xfer_acc.item()
            else:
                info_dict['xferLoss'] = 0.0
                info_dict['xferAcc'] = 0.0
            info_dicts.append(info_dict)

        # switch back to eval mode
        self.model.eval()

        opt_info = GAILInfo(**{
            k: np.mean([d[k] for d in info_dicts])
            for k in info_dicts[0].keys()
        })

        return opt_info
Example #8
File: mtgail.py Project: qxcv/mtil
def _optimal_final_layer(model, expert_batch_iter, pol_replay_buffer,
                         batch_size, n_eval_batches, aug_model, device):
    """Snap the final layer weights to their optimal values under
    apprenticeship learning loss. Implicitly constrained so that ||w||<=1,
    ||b||=0 (||b||=0 is fine for app. learning)."""
    assert isinstance(model.mt_logits, SingleTaskAffineLayer), \
        "right now this only works with the single-task affine layer"
    assert not model.mt_logits.use_sn, \
        "currently this is also incompatible with spectral norm"

    all_feats_expert = []
    all_feats_novice = []

    with torch.no_grad():
        # we need the model in eval mode for this
        old_training = model.training
        model.eval()

        for _ in range(n_eval_batches):
            expert_data = next(expert_batch_iter)
            expert_obs = expert_data['obs']
            expert_acts = expert_data['acts']
            sub_batch_size = batch_size // 2
            pol_replay_samples = pol_replay_buffer.sample_batch(sub_batch_size)
            pol_replay_samples = torchify_buffer(pol_replay_samples)
            novice_obs = pol_replay_samples.all_observation

            all_obs = tree_map(lambda *args: torch.cat(args, 0).to(device),
                               expert_obs, novice_obs)
            all_acts = torch.cat(
                [expert_acts.to(torch.int64), pol_replay_samples.all_action],
                dim=0).to(device)

            if aug_model is not None:
                # augmentations
                aug_frames = aug_model(all_obs.observation)
                all_obs = all_obs._replace(observation=aug_frames)

            _, disc_feats_unproc = model(all_obs, all_acts, return_feats=True)
            # FIXME(sam): this should be done *inside* the model!
            disc_feats_all = model.postproc(disc_feats_unproc)

            disc_feats_expert = disc_feats_all[:sub_batch_size]
            disc_feats_novice = disc_feats_all[sub_batch_size:]

            all_feats_expert.extend(disc_feats_expert.cpu().numpy())
            all_feats_novice.extend(disc_feats_novice.cpu().numpy())

        # put us back into the right mode now that we're done with eval
        model.train(old_training)

    # now set the weights appropriately
    expert_mean = np.mean(all_feats_expert, axis=0)
    novice_mean = np.mean(all_feats_novice, axis=0)
    weight_diff = expert_mean - novice_mean
    weight_diff_norm = np.linalg.norm(weight_diff)
    opt_weights = weight_diff / max(weight_diff_norm, 1e-5)

    # lin_layer.bias is of shape [out_features], while lin_layer.weight is of
    # shape [out_features, in_features].
    # lin_layer.bias and lin_layer.weight are nn.Parameters, so write to them
    # under no_grad to avoid autograd rejecting in-place ops on leaf tensors
    lin_layer = model.mt_logits.lin_layer
    with torch.no_grad():
        # zero out the bias; it won't make a difference to the (expert mean -
        # novice mean) objective
        lin_layer.bias[:] = 0.0
        # the extra None is so that we have a [1,feature_dim]-shaped tensor
        lin_layer.weight[:] = lin_layer.weight.new_tensor(opt_weights[None])
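A tiny numeric illustration of the "snap to optimal" rule above: with ||w|| <= 1 and b = 0, the w that maximises mean(w·f_expert) - mean(w·f_novice) is the unit vector along (expert mean - novice mean), which is exactly what opt_weights computes (the values below are made up):

import numpy as np

expert_mean = np.array([0.9, 0.1])
novice_mean = np.array([0.1, 0.5])
diff = expert_mean - novice_mean
w_opt = diff / max(np.linalg.norm(diff), 1e-5)  # same rule as opt_weights above
# w_opt points along the mean feature difference and has unit norm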
Example #9
File: mtgail.py Project: qxcv/mtil
    def _to_dev(self, item):
        return tree_map(lambda *args: torch.cat(args, 0).to(self.dev), item)
Example #10
    def evaluate(self, obs_tuple, act_tensor, update_stats=True):
        # put model into eval mode if necessary
        old_training = self.reward_model.training
        if old_training:
            self.reward_model.eval()

        with torch.no_grad():
            # flatten observations & actions
            obs_image = obs_tuple.observation
            old_dev = obs_image.device
            lead_dim, T, B, _ = infer_leading_dims(obs_image, self.obs_dims)
            # use tree_map so we are able to handle the namedtuple directly
            obs_flat = tree_map(
                lambda t: t.view((T * B, ) + t.shape[lead_dim:]), obs_tuple)
            act_flat = act_tensor.view((T * B, ) + act_tensor.shape[lead_dim:])

            # now evaluate one batch at a time
            reward_tensors = []
            for b_start in range(0, T * B, self.batch_size):
                obs_batch = obs_flat[b_start:b_start + self.batch_size]
                act_batch = act_flat[b_start:b_start + self.batch_size]
                dev_obs = tree_map(lambda t: t.to(self.dev), obs_batch)
                dev_acts = act_batch.to(self.dev)
                dev_reward = self.reward_model(dev_obs, dev_acts)
                reward_tensors.append(dev_reward.to(old_dev))

            # join together the batch results
            new_reward_flat = torch.cat(reward_tensors, 0)
            new_reward = restore_leading_dims(new_reward_flat, lead_dim, T, B)
            task_ids = obs_tuple.task_id
            assert new_reward.shape == task_ids.shape

        # put back into training mode if necessary
        if old_training:
            self.reward_model.train(old_training)

        # normalise if necessary
        if self.normalise:
            mus = []
            stds = []
            for task_id, averager in enumerate(self.rew_running_averages):
                if update_stats:
                    id_sub = task_ids.view((-1, )) == task_id
                    if not torch.any(id_sub):
                        continue
                    rew_sub = new_reward.view((-1, ))[id_sub]
                    averager.update(rew_sub)

                mus.append(averager.mean.item())
                stds.append(averager.std.item())

            mus = new_reward.new_tensor(mus)
            stds = new_reward.new_tensor(stds)
            denom = torch.max(stds.new_tensor(1e-3), stds / self.target_std)

            denom_sub = denom[task_ids]
            mu_sub = mus[task_ids]

            # only bother applying result if we've actually seen an update
            # before (otherwise reward will be insane)
            new_reward = (new_reward - mu_sub) / denom_sub

        return new_reward