def __len__(self):
    first_val = next(iter(self.tensor_dict.values()))
    if isinstance(first_val, torch.Tensor):
        return first_val.size(0)
    elif isinstance(first_val, tuple):
        sizes = set()
        tree_map(lambda t: sizes.add(t.size(0)), first_val)
        assert len(sizes) == 1, sizes
        return next(iter(sizes))
    raise TypeError(f"can't handle value of type '{type(first_val)}'")
def __init__(self, tensor_dict):
    # need at least one tensor
    assert len(tensor_dict) > 0
    # make sure batch size is uniform
    batch_sizes = set()
    tree_map(lambda t: batch_sizes.add(t.size(0)), tensor_dict)
    assert len(batch_sizes) == 1, batch_sizes
    self.tensor_dict = tensor_dict
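# Illustrative sketch, not part of the original module: the code in this file
# assumes `tree_map` (imported from the repo's utilities) applies a function
# to every leaf tensor of a nested dict / tuple / namedarraytuple structure
# and rebuilds the same structure. A hypothetical reference implementation,
# named differently so it cannot shadow the real import:
def _tree_map_sketch(fn, *structures):
    first = structures[0]
    if isinstance(first, dict):
        return {k: _tree_map_sketch(fn, *(s[k] for s in structures))
                for k in first}
    if isinstance(first, tuple):
        mapped = [_tree_map_sketch(fn, *children)
                  for children in zip(*structures)]
        # namedtuples (and rlpyt-style namedarraytuples) are rebuilt
        # field-by-field; plain tuples are rebuilt positionally
        if hasattr(first, '_fields'):
            return type(first)(*mapped)
        return tuple(mapped)
    return fn(*structures)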
def insert_task_source_ids(obs_seq, task_id, variant_id, source_id):
    # figuring out the length of the sequence requires a bit of judo because
    # the observations are stored as namedarraytuples
    lens = []
    tree_map(lambda t: lens.append(len(t)), obs_seq.obs)
    nobs, = set(lens)
    task_id_array = np.full((nobs, ), task_id, dtype='int64')
    variant_id_array = np.full((nobs, ), variant_id, dtype='int64')
    source_id_array = np.full((nobs, ), source_id, dtype='int64')
    eio_arr = EnvIDObsArray(observation=obs_seq.obs,
                            task_id=task_id_array,
                            variant_id=variant_id_array,
                            source_id=source_id_array)
    return obs_seq._replace(obs=eio_arr)
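# Illustrative usage sketch, not part of the original module: the real
# obs_seq is an rlpyt namedarraytuple, but a plain namedtuple with an `obs`
# field behaves the same way here. The scalar IDs are simply broadcast along
# the sequence dimension. Names prefixed with an underscore are hypothetical.
def _insert_ids_example():
    ToyObsSeq = collections.namedtuple('ToyObsSeq', ['obs'])
    toy_seq = ToyObsSeq(obs=np.zeros((5, 3)))
    tagged = insert_task_source_ids(toy_seq, task_id=2, variant_id=0,
                                    source_id=1)
    # one ID per observation, e.g. tagged.obs.task_id == [2, 2, 2, 2, 2]
    assert tagged.obs.task_id.shape == (5, )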
def hacky_interpolate(structure, eps, sub_batch_size):
    def inner_map(tens):
        tens_l = tens[:sub_batch_size]
        tens_r = tens[sub_batch_size:]
        if tens.ndim == 4 and tens.dtype in (torch.uint8, torch.float):
            # Float interpolation is easy. For byte images we interpolate by
            # (implicitly) converting to float, then converting back to byte.
            eps_exp = eps[:, None, None, None]
            interp = eps_exp * tens_l + (1 - eps_exp) * tens_r
            if tens.dtype == torch.uint8:
                interp = interp.to(torch.uint8)
        elif tens.ndim == 1 and not torch.is_floating_point(tens):
            # "interpolate" some integers (assume action label or something,
            # so the output must be discrete)
            # FIXME(sam): what we really want to do is interpolate the
            # one-hot representation or something, but instead of that I'm
            # just going to do this. (whatever, Lipschitzness of discrete
            # functions doesn't make sense anyway.)
            tens_elems = []
            for eps_i, elem_l, elem_r in zip(eps, tens_l, tens_r):
                tens_elems.append(elem_l if eps_i > 0.5 else elem_r)
            interp = torch.stack(tens_elems, dim=0)
        else:
            raise ValueError("cannot handle shape/type of tensor", tens)
        return interp

    new_structure = tree_map(inner_map, structure)
    return new_structure
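# Illustrative, not part of the original module: a tiny worked example of the
# float branch of hacky_interpolate, assuming tree_map treats a bare tensor
# as a single leaf. The function below is never called by the training code.
def _hacky_interpolate_example():
    imgs = torch.arange(2 * 1 * 2 * 2, dtype=torch.float).reshape(2, 1, 2, 2)
    eps = torch.tensor([0.25])
    # the first half of the batch is blended with the second half, with one
    # interpolation weight per sample
    mixed = hacky_interpolate(imgs, eps, sub_batch_size=1)
    expected = 0.25 * imgs[0] + 0.75 * imgs[1]
    assert torch.allclose(mixed[0], expected)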
def make_tensor_dict_dataset(demo_trajs_by_env, omit_noop=False):
    """Re-format multi-task trajectories into a Torch dataset of dicts."""
    cpu_dev = torch.device("cpu")

    # make big list of trajectories in a deterministic order
    all_obs = []
    all_acts = []
    for env_name in sorted(demo_trajs_by_env.keys()):
        demo_trajs = demo_trajs_by_env[env_name]
        n_samples = 0
        for traj in demo_trajs:
            # The observation trajectories are one elem longer than the act
            # trajectories because they include the terminal obs. We lop
            # that off here.
            all_obs.append(
                tree_map(lambda t: torch.as_tensor(t, device=cpu_dev)[:-1],
                         traj.obs))
            all_acts.append(torch.as_tensor(traj.acts, device=cpu_dev))
            n_samples += len(traj.acts)
        # weight things inversely proportional to their frequency
        assert n_samples > 0, demo_trajs

    # join together trajectories into Torch dataset
    all_obs = tree_map(lambda *t: torch.cat(t), *all_obs)
    all_acts = torch.cat(all_acts)

    if omit_noop:
        # omit action 0 (helps avoid the "agent does nothing initially"
        # problem for MAGICAL)
        valid_inds = torch.squeeze(torch.nonzero(all_acts), 1)
        all_obs = tree_map(lambda t: t[valid_inds], all_obs)
        all_acts = all_acts[valid_inds]

    dataset = DictTensorDataset({
        'obs': all_obs,
        'acts': all_acts,
    })

    return dataset
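# Illustrative, not part of the original module: a quick shape sanity check on
# the dataset built above. It only relies on the tensor_dict attribute set in
# __init__; how batches are actually drawn from DictTensorDataset is defined
# elsewhere in the repo.
def _dataset_shape_example(demo_trajs_by_env):
    dataset = make_tensor_dict_dataset(demo_trajs_by_env)
    # 'acts' is a flat tensor and 'obs' a (possibly nested) structure; both
    # share the leading dimension reported by __len__
    assert dataset.tensor_dict['acts'].shape[0] == len(dataset)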
def do_training_mt(loader, model, opt, dev, aug_model, min_bc_module,
                   n_batches):
    # @torch.jit.script
    def do_loss_forward_back(obs_batch_obs, obs_batch_task, obs_batch_var,
                             obs_batch_source, acts_batch):
        # we don't use the value output
        logits_flat, _ = model(obs_batch_obs, task_ids=obs_batch_task)
        losses = F.cross_entropy(logits_flat, acts_batch.long(),
                                 reduction='none')
        if min_bc_module is not None:
            # weight using a model-dependent strategy
            mbc_weights = min_bc_module(obs_batch_task, obs_batch_var,
                                        obs_batch_source)
            assert mbc_weights.shape == losses.shape, (mbc_weights.shape,
                                                       losses.shape)
            loss = (losses * mbc_weights).sum()
        else:
            # no weighting
            loss = losses.mean()
        loss.backward()
        return losses.detach().cpu().numpy()

    # make sure we're in train mode
    model.train()

    # for logging
    loss_ewma = None
    losses = []
    per_task_losses = collections.defaultdict(lambda: [])
    progress = ProgBarCounter(n_batches)
    inf_batch_iter = repeat_dataset(loader)
    ctr_batch_iter = zip(range(1, n_batches), inf_batch_iter)
    for batches_done, loader_batch in ctr_batch_iter:
        # (task_ids_batch, obs_batch, acts_batch)
        # copy to GPU
        obs_batch = loader_batch['obs']
        acts_batch = loader_batch['acts']
        # reminder: attributes are .observation, .task_id, .variant_id
        obs_batch = tree_map(lambda t: t.to(dev), obs_batch)
        acts_batch = acts_batch.to(dev)

        if aug_model is not None:
            # apply augmentations
            obs_batch = obs_batch._replace(
                observation=aug_model(obs_batch.observation))

        # compute loss & take opt step
        opt.zero_grad()
        batch_losses = do_loss_forward_back(obs_batch.observation,
                                            obs_batch.task_id,
                                            obs_batch.variant_id,
                                            obs_batch.source_id,
                                            acts_batch)
        opt.step()

        # for logging
        progress.update(batches_done)
        f_loss = np.mean(batch_losses)
        loss_ewma = f_loss if loss_ewma is None \
            else 0.9 * loss_ewma + 0.1 * f_loss
        losses.append(f_loss)

        # also track separately for each task
        tv_ids = torch.stack((obs_batch.task_id, obs_batch.variant_id),
                             axis=1)
        np_tv_ids = tv_ids.cpu().numpy()
        assert len(np_tv_ids.shape) == 2 and np_tv_ids.shape[1] == 2, \
            np_tv_ids.shape
        for tv_id in np.unique(np_tv_ids, axis=0):
            tv_mask = np.all(np_tv_ids == tv_id[None], axis=-1)
            rel_losses = batch_losses[tv_mask]
            if len(rel_losses) > 0:
                task_id, variant_id = tv_id
                per_task_losses[(task_id, variant_id)] \
                    .append(np.mean(rel_losses))

    progress.stop()

    return loss_ewma, losses, per_task_losses
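# The `repeat_dataset` helper used in do_training_mt is defined elsewhere in
# the repo; a minimal sketch of the assumed behaviour (re-iterate the loader
# forever so the training loop can draw an arbitrary number of batches):
def _repeat_dataset_sketch(loader):
    while True:
        yield from loader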
def optim_disc(self, itr, n_itr, samples):
    # TODO: refactor this method. Makes sense to split code that sets up
    # the replay buffer(s) from code that sets up each batch (and from the
    # code that computes losses and records metrics, etc.)

    # store new samples in replay buffer
    if self.pol_replay_buffer is None:
        self.pol_replay_buffer = DiscrimTrainBuffer(
            self.buffer_num_samples, samples)
    # keep ONLY the demo env samples
    sample_variant_ids = samples.env.observation.variant_id
    train_variant_mask = sample_variant_ids == 0
    # check that each batch index is "pure", in the sense that e.g. all
    # elements at index k are always for the same task ID
    assert (train_variant_mask[:1] == train_variant_mask).all(), \
        train_variant_mask
    filtered_samples = samples[:, train_variant_mask[0]]
    self.pol_replay_buffer.append_samples(filtered_samples)

    if self.xfer_adv_model is not None:
        # if we have an adversarial domain adaptation model for transfer
        # learning, then we also keep samples that *don't* come from the
        # train variant so we can use them for the transfer loss
        if self.xfer_replay_buffer is None:
            # second replay buffer for off-task samples
            self.xfer_replay_buffer = DiscrimTrainBuffer(
                self.buffer_num_samples, samples)
        xfer_variant_mask = ~train_variant_mask
        assert torch.any(xfer_variant_mask), \
            "xfer_adv_weight>0 supplied, but no xfer variants in batch?"
        assert (xfer_variant_mask[:1] == xfer_variant_mask).all()
        filtered_samples_xfer = samples[:, xfer_variant_mask[0]]
        self.xfer_replay_buffer.append_samples(filtered_samples_xfer)

    if self.final_layer_only_mode:
        print("Snapping final layer to its optimal value")
        _optimal_final_layer(
            model=self.model,
            expert_batch_iter=self.expert_batch_iter,
            pol_replay_buffer=self.pol_replay_buffer,
            batch_size=self.batch_size,
            n_eval_batches=int(
                np.ceil(self.final_layer_only_mode_n_samples /
                        self.batch_size)),
            aug_model=self.aug_model,
            device=self.dev)

    # switch to train mode before taking any steps
    self.model.train()
    info_dicts = []
    for _ in range(self.updates_per_itr):
        self.opt.zero_grad()

        expert_data = next(self.expert_batch_iter)
        expert_obs = expert_data['obs']
        expert_acts = expert_data['acts']
        # grep rlpyt source for "SamplesFromReplay" to see what fields
        # pol_replay_samples has
        sub_batch_size = self.batch_size // 2
        pol_replay_samples = self.pol_replay_buffer.sample_batch(
            sub_batch_size)
        pol_replay_samples = torchify_buffer(pol_replay_samples)
        novice_obs = pol_replay_samples.all_observation

        if self.xfer_adv_model:
            # add a batch of domain transfer samples
            xfer_replay_samples = self.xfer_replay_buffer.sample_batch(
                sub_batch_size)
            xfer_replay_samples = torchify_buffer(xfer_replay_samples)
            xfer_replay_obs = xfer_replay_samples.all_observation
            all_obs = tree_map(
                lambda *args: torch.cat(args, 0).to(self.dev),
                expert_obs, novice_obs, xfer_replay_obs)
            all_acts = torch.cat([expert_acts.to(torch.int64),
                                  pol_replay_samples.all_action,
                                  xfer_replay_samples.all_action],
                                 dim=0) \
                .to(self.dev)
        else:
            all_obs = tree_map(
                lambda *args: torch.cat(args, 0).to(self.dev),
                expert_obs, novice_obs)
            all_acts = torch.cat([expert_acts.to(torch.int64),
                                  pol_replay_samples.all_action],
                                 dim=0) \
                .to(self.dev)

        if self.aug_model is not None:
            # augmentations
            aug_frames = self.aug_model(all_obs.observation)
            all_obs = all_obs._replace(observation=aug_frames)

        make_ones = functools.partial(torch.ones, dtype=torch.float32,
                                      device=self.dev)
        make_zeros = functools.partial(torch.zeros, dtype=torch.float32,
                                       device=self.dev)
        is_real_label = torch.cat(
            [
                # expert samples
                make_ones(sub_batch_size),
                # novice samples
                make_zeros(sub_batch_size),
            ],
            dim=0)

        if self.xfer_adv_model:
            # apply domain transfer loss to the transfer samples, then
            # remove them
            logits_all, disc_feats_all = self.model(all_obs, all_acts,
                                                    return_feats=True)
            # cut the expert samples out of the discriminator transfer
            # objective
            disc_feats_xfer = disc_feats_all[sub_batch_size:]
            xfer_labels = torch.cat(
                [make_ones(sub_batch_size),
                 make_zeros(sub_batch_size)],
                dim=0)
            xfer_loss, xfer_acc = self.xfer_adv_model(
                disc_feats_xfer, xfer_labels)
            # cut the transfer env samples out of the logits
            logits = logits_all[:-sub_batch_size]
        else:
            if self.final_layer_only_mode:
                # cutting out gradients for the main model evaluation
                # hopefully makes this a bit less memory-intensive
                with torch.no_grad():
                    logits = self.model(all_obs, all_acts)
            else:
                logits = self.model(all_obs, all_acts)

        if self.use_wgan:
            # Now assign label +1 for novice, -1 for expert. This means we
            # are trying to push novice scores down, and expert scores up
            # (I don't know whether this matches the original paper).
            pm_labels = 1 - 2 * is_real_label
            wgan_loss = main_loss = torch.mean(pm_labels * logits)
        else:
            # The GAIL discriminator *objective* is E_fake[log D(s,a)] +
            # E_expert[log(1-D(s,a))]. You actually want to maximise this;
            # in reality the "loss" to be minimised is -E_fake[log D(s,a)]
            # - E_expert[log(1-D(s,a))].
            #
            # binary_cross_entropy_with_logits computes -labels *
            # log(sigmoid(logits)) - (1 - labels) * log(1-sigmoid(logits))
            # (per the PyTorch docs). In other words: -y * log D(s,a) -
            # (1 - y) * log(1 - D(s,a))
            #
            # Hence, GAIL is like logistic regression with label 1 for the
            # novice and 0 for the expert. This is kind of weird, because
            # you actually want to *minimise* the discriminator's output.
            # Indeed, in the actual implementation, they flip this & use 1
            # for the expert and 0 for the novice.
            #
            # In light of all the above, I'm using the OPPOSITE convention
            # to the paper, but the same convention as the implementation.
            # To wit:
            #
            # - 1 is the *expert* label, and high = more expert-like
            # - 0 is the *novice* label, and low = less expert-like
            xent_loss = main_loss = F.binary_cross_entropy_with_logits(
                logits, is_real_label, reduction='mean')

        loss = main_loss

        if self.gp_weight:
            eps = torch.rand((sub_batch_size, )).to(self.dev)
            preproc_obs, preproc_acts, _ = self.model.preproc_obs_acts(
                all_obs, all_acts)
            interp_obs = hacky_interpolate(preproc_obs, eps,
                                           sub_batch_size)
            interp_acts = hacky_interpolate(preproc_acts, eps,
                                            sub_batch_size)
            interp_task_ids = hacky_interpolate(all_obs.task_id, eps,
                                                sub_batch_size)

            # using the strategy from
            # https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py
            # (also we only do this w.r.t. images, not scalar inputs)
            interp_obs_var = torch.autograd.Variable(interp_obs,
                                                     requires_grad=True)
            interp_logits, _, _ = self.model.forward_no_preproc(
                interp_obs_var, interp_acts, interp_task_ids)
            start_grad = torch.ones((sub_batch_size, ), device=self.dev)
            # FIXME(sam): I suspect retain_graph and grad_outputs are not
            # actually required. Should dig deeper some time.
            grads_wrt_inputs, = torch.autograd.grad(
                outputs=interp_logits,
                inputs=[interp_obs_var],
                grad_outputs=start_grad,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)
            grads_wrt_inputs_flat = torch.flatten(grads_wrt_inputs,
                                                  start_dim=1)
            grad_norms = torch.sqrt(
                torch.sum(grads_wrt_inputs_flat**2, dim=1))
            grad_penalty = torch.mean((grad_norms - 1.0)**2)

            loss = loss + self.gp_weight * grad_penalty

        if self.xfer_adv_model:
            if self.xfer_disc_anneal:
                progress = min(1, max(0, itr / float(n_itr)))
                xfer_adv_weight = progress * self.xfer_adv_weight
            else:
                xfer_adv_weight = self.xfer_adv_weight
            loss = loss + xfer_adv_weight * xfer_loss

        if self.final_layer_only_mode:
            print("Skipping actual optimiser step")
        else:
            loss.backward()
            self.opt.step()

        if self.use_wgan or self.final_layer_only_mode:
            # center logits so that 0 = uncertain
            stat_logits = logits - logits.mean()
        else:
            # for a vanilla GAN, it's fine not to
            stat_logits = logits

        # for logging; we'll average these later
        info_dict = _compute_gail_stats(stat_logits, is_real_label,
                                        self.use_wgan)
        info_dict['discLoss'] = loss.item()
        if self.use_wgan:
            info_dict['discXentLoss'] = float('nan')
            info_dict['discWGANLoss'] = wgan_loss.item()
        else:
            info_dict['discXentLoss'] = xent_loss.item()
            info_dict['discWGANLoss'] = float('nan')
        if self.gp_weight:
            info_dict['discGPGradNorm'] = grad_norms.mean().item()
            info_dict['discGPLoss'] = grad_penalty.item()
        else:
            info_dict['discGPGradNorm'] = float('nan')
            info_dict['discGPLoss'] = float('nan')
        if self.xfer_adv_model:
            info_dict['xferLoss'] = xfer_loss.item()
            info_dict['xferAcc'] = xfer_acc.item()
        else:
            info_dict['xferLoss'] = 0.0
            info_dict['xferAcc'] = 0.0
        info_dicts.append(info_dict)

    # switch back to eval mode
    self.model.eval()

    opt_info = GAILInfo(**{
        k: np.mean([d[k] for d in info_dicts])
        for k in info_dicts[0].keys()
    })

    return opt_info
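# Illustrative check, not part of the original module: the label convention
# used in optim_disc (1 = expert, 0 = novice) makes the non-WGAN loss above
# plain logistic regression, i.e. -E_expert[log D] - E_novice[log(1 - D)].
def _gail_label_convention_example():
    logits = torch.tensor([2.0, -1.0])  # [expert sample, novice sample]
    labels = torch.tensor([1.0, 0.0])   # 1 = expert, 0 = novice
    bce = F.binary_cross_entropy_with_logits(logits, labels)
    d = torch.sigmoid(logits)
    manual = -(torch.log(d[0]) + torch.log(1 - d[1])) / 2
    assert torch.allclose(bce, manual)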
def _optimal_final_layer(model, expert_batch_iter, pol_replay_buffer,
                         batch_size, n_eval_batches, aug_model, device):
    """Snap the final layer weights to their optimal values under the
    apprenticeship learning loss. Implicitly constrained so that ||w||<=1,
    ||b||=0 (||b||=0 is fine for app. learning)."""
    assert isinstance(model.mt_logits, SingleTaskAffineLayer), \
        "right now this only works with the single-task affine layer"
    assert not model.mt_logits.use_sn, \
        "currently this is also incompatible with spectral norm"

    all_feats_expert = []
    all_feats_novice = []
    with torch.no_grad():
        # we need the model in eval mode for this
        old_training = model.training
        model.eval()

        for _ in range(n_eval_batches):
            expert_data = next(expert_batch_iter)
            expert_obs = expert_data['obs']
            expert_acts = expert_data['acts']
            sub_batch_size = batch_size // 2
            pol_replay_samples = pol_replay_buffer.sample_batch(
                sub_batch_size)
            pol_replay_samples = torchify_buffer(pol_replay_samples)
            novice_obs = pol_replay_samples.all_observation
            all_obs = tree_map(
                lambda *args: torch.cat(args, 0).to(device),
                expert_obs, novice_obs)
            all_acts = torch.cat(
                [expert_acts.to(torch.int64),
                 pol_replay_samples.all_action],
                dim=0).to(device)
            if aug_model is not None:
                # augmentations
                aug_frames = aug_model(all_obs.observation)
                all_obs = all_obs._replace(observation=aug_frames)
            _, disc_feats_unproc = model(all_obs, all_acts,
                                         return_feats=True)
            # FIXME(sam): this should be done *inside* the model!
            disc_feats_all = model.postproc(disc_feats_unproc)

            disc_feats_expert = disc_feats_all[:sub_batch_size]
            disc_feats_novice = disc_feats_all[sub_batch_size:]
            all_feats_expert.extend(disc_feats_expert.cpu().numpy())
            all_feats_novice.extend(disc_feats_novice.cpu().numpy())

        # put us back into the right mode now that we're done with eval
        model.train(old_training)

        # now set the weights appropriately
        expert_mean = np.mean(all_feats_expert, axis=0)
        novice_mean = np.mean(all_feats_novice, axis=0)
        weight_diff = expert_mean - novice_mean
        weight_diff_norm = np.linalg.norm(weight_diff)
        opt_weights = weight_diff / max(weight_diff_norm, 1e-5)

        # lin_layer.bias is of shape [out_features], while lin_layer.weight
        # is of shape [out_features, in_features].
        lin_layer = model.mt_logits.lin_layer
        # zero out the bias; it won't make a difference to the (expert mean
        # - novice mean) objective
        lin_layer.bias[:] = 0.0
        # the extra None is so that we have a [1, feature_dim]-shaped tensor
        lin_layer.weight[:] = lin_layer.weight.new_tensor(opt_weights[None])
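# Illustrative, not part of the original module: the closed form used in
# _optimal_final_layer. Maximising w . (mu_expert - mu_novice) over
# ||w|| <= 1 is achieved by the unit vector along the mean difference, which
# is exactly opt_weights above (with a small floor on the norm for
# numerical stability).
def _optimal_weight_example():
    mu_expert = np.array([1.0, 0.0, 2.0])
    mu_novice = np.array([0.0, 0.0, 1.0])
    diff = mu_expert - mu_novice
    w = diff / max(np.linalg.norm(diff), 1e-5)
    # the separation achieved by w equals ||diff||, the best possible value
    assert np.isclose(w @ diff, np.linalg.norm(diff))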
def _to_dev(self, item):
    # move every leaf tensor of a (possibly nested) structure onto self.dev
    return tree_map(lambda *args: torch.cat(args, 0).to(self.dev), item)
def evaluate(self, obs_tuple, act_tensor, update_stats=True):
    # put model into eval mode if necessary
    old_training = self.reward_model.training
    if old_training:
        self.reward_model.eval()

    with torch.no_grad():
        # flatten observations & actions
        obs_image = obs_tuple.observation
        old_dev = obs_image.device
        lead_dim, T, B, _ = infer_leading_dims(obs_image, self.obs_dims)
        # use tree_map so we are able to handle the namedtuple directly
        obs_flat = tree_map(
            lambda t: t.view((T * B, ) + t.shape[lead_dim:]), obs_tuple)
        act_flat = act_tensor.view((T * B, ) + act_tensor.shape[lead_dim:])

        # now evaluate one batch at a time
        reward_tensors = []
        for b_start in range(0, T * B, self.batch_size):
            obs_batch = obs_flat[b_start:b_start + self.batch_size]
            act_batch = act_flat[b_start:b_start + self.batch_size]
            dev_obs = tree_map(lambda t: t.to(self.dev), obs_batch)
            dev_acts = act_batch.to(self.dev)
            dev_reward = self.reward_model(dev_obs, dev_acts)
            reward_tensors.append(dev_reward.to(old_dev))

        # join together the batch results
        new_reward_flat = torch.cat(reward_tensors, 0)
        new_reward = restore_leading_dims(new_reward_flat, lead_dim, T, B)
        task_ids = obs_tuple.task_id
        assert new_reward.shape == task_ids.shape

    # put back into training mode if necessary
    if old_training:
        self.reward_model.train(old_training)

    # normalise if necessary
    if self.normalise:
        mus = []
        stds = []
        for task_id, averager in enumerate(self.rew_running_averages):
            if update_stats:
                id_sub = task_ids.view((-1, )) == task_id
                if not torch.any(id_sub):
                    continue
                rew_sub = new_reward.view((-1, ))[id_sub]
                averager.update(rew_sub)
            mus.append(averager.mean.item())
            stds.append(averager.std.item())
        mus = new_reward.new_tensor(mus)
        stds = new_reward.new_tensor(stds)
        denom = torch.max(stds.new_tensor(1e-3), stds / self.target_std)
        denom_sub = denom[task_ids]
        mu_sub = mus[task_ids]
        # only bother applying result if we've actually seen an update
        # before (otherwise reward will be insane)
        new_reward = (new_reward - mu_sub) / denom_sub

    return new_reward
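# Illustrative sketch, not the repo's implementation: `evaluate` assumes each
# entry of self.rew_running_averages exposes .update(tensor), .mean and .std.
# A minimal stand-in with those semantics (using the standard pooled-variance
# update) could look like this; the class name is hypothetical.
class _RunningMeanStdSketch:
    def __init__(self):
        self.count = 0
        self.mean = torch.tensor(0.0)
        self.var = torch.tensor(1.0)

    @property
    def std(self):
        return torch.sqrt(self.var)

    def update(self, values):
        values = values.reshape(-1).float()
        n = values.numel()
        if n == 0:
            return
        new_count = self.count + n
        delta = values.mean() - self.mean
        self.mean = self.mean + delta * n / new_count
        # parallel-variance (Chan et al.) combination of old and new batches
        self.var = (self.count * self.var
                    + n * values.var(unbiased=False)
                    + delta ** 2 * self.count * n / new_count) / new_count
        self.count = new_count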