Example #1
def testall(state_dict_or_model_path, env_name, seed, fps, write_latex,
            latex_alg_name, n_rollouts, run_id, write_csv, gpu_idx, cpu_list,
            batch_size, load_latest):
    """Run quantitative evaluation on all test variants of a given
    environment."""
    # TODO: is there some way of factoring this init code out? Maybe put into
    # Click base command so that it gets run for `train`, `testall`, etc.
    set_seeds(seed)
    import magical
    magical.register_envs()

    # for parallel GPU/CPU sampling
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
    print(f"Using device {dev}, seed {seed}")
    if cpu_list is None:
        cpu_list = sample_cpu_list()
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)

    if load_latest:
        state_dict_or_model_path = get_latest_path(state_dict_or_model_path)

    if run_id is None:
        run_id = state_dict_or_model_path

    eval_protocol = MTBCEvalProtocol(
        demo_env_name=env_name,
        state_dict_or_model_path=state_dict_or_model_path,
        seed=seed,
        # det_pol=det_pol,
        run_id=run_id,
        gpu_idx=gpu_idx,
        affinity=affinity,
        n_rollouts=n_rollouts,
        batch_size=batch_size,
    )

    # next bit copied from testall() in bc.py
    frame = eval_protocol.do_eval(verbose=True)
    if latex_alg_name is None:
        latex_alg_name = run_id
    frame['latex_alg_name'] = latex_alg_name

    if write_latex:
        latex_str = latexify_results(frame, id_column='latex_alg_name')
        dir_path = os.path.dirname(write_latex)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        with open(write_latex, 'w') as fp:
            fp.write(latex_str)

    if write_csv:
        dir_path = os.path.dirname(write_csv)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        frame.to_csv(write_csv)

    return frame
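The seed/device/affinity boilerplate at the top of testall() is repeated almost verbatim in the training entry points further down (Examples #8 and #11), which is what the TODO at the top of the function asks about. Below is a minimal sketch of how that init code could be factored out, assuming `mp` is the standard multiprocessing module and that `set_seeds` and `sample_cpu_list` are the same project utilities used above; `common_init` itself is a hypothetical name.

import multiprocessing as mp

import torch

def common_init(seed, gpu_idx, cpu_list):
    """Hypothetical shared setup: seed everything, pick a device, and build an
    affinity dict, mirroring the init code in testall()/train()/main()."""
    set_seeds(seed)  # project utility, assumed importable
    # 'spawn' is needed so GL-based envs work in subprocesses
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    dev = torch.device(f"cuda:{gpu_idx}" if use_gpu else "cpu")
    print(f"Using device {dev}, seed {seed}")
    if cpu_list is None:
        cpu_list = sample_cpu_list()  # project utility, assumed importable
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)
    return dev, affinity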
Example #2
def main(demo, save_dir, preproc):
    magical.register_envs()

    demo_dict, = load_demos([demo])
    # keys of the demo are env_name, trajectory, score
    orig_env_name = demo_dict['env_name']
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print(f"Will save all demos to '{save_dir}'")
    for preproc_name in preproc:
        print(f"Working on preprocessor '{preproc_name}'")
        preprocessed, = preprocess_demos_with_wrapper(
            [demo_dict['trajectory']],
            orig_env_name,
            preproc_name=preproc_name)
        for frame_idx, frame in enumerate(preprocessed.obs):
            frame_fname = 'frame-' + preproc_name.lower() \
                + f'-{frame_idx:03}.png'
            frame_path = os.path.join(save_dir, frame_fname)
            assert frame.shape[-1] % 3 == 0
            frame_stacked = frame.reshape(frame.shape[:-1] + (-1, 3))
            frame_trans = frame_stacked.transpose((2, 0, 1, 3))
            frame_wide = np.concatenate(frame_trans, axis=1)
            print(f"  Writing frame to '{frame_path}'")
            # make much larger so we can see pixel boundaries
            frame_wide = np.repeat(np.repeat(frame_wide, 8, axis=0), 8, axis=1)
            imageio.imsave(frame_path, frame_wide)
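The reshape/transpose/concatenate sequence above turns a channel-stacked observation of shape (H, W, 3*k) into k RGB frames laid side by side. A self-contained numpy sketch of the same transformation on a dummy array (the shapes here are illustrative):

import numpy as np

# dummy "frame stack": 4 RGB frames of size 96x96 stacked along the channel axis
frame = np.random.randint(0, 256, size=(96, 96, 12), dtype=np.uint8)
assert frame.shape[-1] % 3 == 0

# (H, W, 3k) -> (H, W, k, 3): split the channel axis back into k RGB frames
frame_stacked = frame.reshape(frame.shape[:-1] + (-1, 3))
# (H, W, k, 3) -> (k, H, W, 3): move the frame index to the front
frame_trans = frame_stacked.transpose((2, 0, 1, 3))
# concatenate the k frames along the width axis -> (H, k*W, 3)
frame_wide = np.concatenate(frame_trans, axis=1)
print(frame_wide.shape)  # (96, 384, 3)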
Example #3
def process_directories(model_dirs, out_dir, gpu_idx=0):
    """Load the latest snapshot in each of the given model directories, and
    render a video for each test variant corresponding to the train variants in
    the model. Note that the video will show all models in sequence."""
    set_seeds(DEFAULT_SEED)
    register_envs()

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if gpu_idx is None:
        dev = torch.device('cpu')
    else:
        dev = torch.device(f'cuda:{gpu_idx}')

    # we're going to keep a video writer for each env
    writers_by_env = {}

    for model_idx, model_dir in enumerate(model_dirs):
        print(f"Loading model from '{model_dir}' "
              f"(model {model_idx+1}/{len(model_dirs)})")
        latest_path = get_latest_path(os.path.join(model_dir,
                                                   "itr_LATEST.pkl"))
        model = load_state_dict_or_model(latest_path)
        model = model.to(dev)
        test_envs = get_test_envs(model.env_ids_and_names)

        for test_env in test_envs:
            if test_env not in writers_by_env:
                # will be something like, e.g., "movetocorner-testjitter.mp4"
                vid_name = generate_vid_name(test_env)
                out_path = os.path.join(out_dir, vid_name)
                print(f"  Writing video for '{test_env}' to '{out_path}'")
                writer = vidio.FFmpegWriter(out_path,
                                            outputdict={
                                                '-r': str(DEFAULT_FPS),
                                                '-vcodec': 'libx264',
                                                '-pix_fmt': 'yuv420p',
                                            })
                writers_by_env[test_env] = writer
            else:
                print(f"  Writing video for '{test_env}'")
                writer = writers_by_env.get(test_env)
            for frame in generate_frames(model=model,
                                         env_name=test_env,
                                         dev=dev,
                                         seed=DEFAULT_SEED + model_idx,
                                         ntraj=NTRAJ_PER_MODEL,
                                         fps=DEFAULT_FPS):
                writer.writeFrame(frame)

    for writer in writers_by_env.values():
        writer.close()
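A minimal sketch of the video-writing pattern above in isolation, assuming `vidio` is an alias for skvideo.io (whose FFmpegWriter takes an outputdict of ffmpeg options); the filename and frames here are dummies:

import numpy as np
import skvideo.io as vidio  # assumed to be the `vidio` used above

writer = vidio.FFmpegWriter('demo.mp4',
                            outputdict={
                                '-r': '30',             # frame rate
                                '-vcodec': 'libx264',
                                '-pix_fmt': 'yuv420p',  # widely playable pixel format
                            })
try:
    for _ in range(60):
        # FFmpegWriter expects HxWx3 uint8 frames
        frame = np.random.randint(0, 256, size=(96, 96, 3), dtype=np.uint8)
        writer.writeFrame(frame)
finally:
    writer.close()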
Example #4
def _get_env_meta_target(env_names, rv_dict):
    register_envs()  # in case this proc was spawned
    metas = []
    for env_name in env_names:
        # construct a bunch of envs in turn to get info about their observation
        # spaces, action spaces, etc.
        env = gym.make(env_name)
        spec = FilteredSpec(*(getattr(env.spec, field)
                              for field in FilteredSpec._fields))
        meta = EnvMeta(observation_space=env.observation_space,
                       action_space=env.action_space,
                       spec=spec)
        metas.append(meta)
        env.close()
    rv_dict['result'] = tuple(metas)
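The rv_dict parameter suggests this target runs in a spawned child process and reports back through a shared dict. A sketch of a possible caller, using only the standard-library multiprocessing module (the caller itself is not from the original code):

import multiprocessing as mp

def get_env_metas(env_names):
    """Hypothetical caller: run _get_env_meta_target in a spawned process and
    collect its result through a Manager dict."""
    ctx = mp.get_context('spawn')
    with ctx.Manager() as manager:
        rv_dict = manager.dict()
        proc = ctx.Process(target=_get_env_meta_target,
                           args=(env_names, rv_dict))
        proc.start()
        proc.join()
        return rv_dict['result']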
Example #5
def test(state_dict_or_model_path, env_name, det_pol, seed, fps, transfer_to):
    """Repeatedly roll out a policy on a given environment. Mostly useful for
    visual debugging; see `testall` for quantitative evaluation."""
    set_seeds(seed)

    import magical
    magical.register_envs()

    # build env
    if transfer_to:
        env = gym.make(transfer_to)
    else:
        env = gym.make(env_name)

    model = load_state_dict_or_model(state_dict_or_model_path)
    ft_wrapper = wrap_model_for_fixed_task(model, env_name)

    spf = 1.0 / (env.fps if fps is None else fps)
    act_range = np.arange(env.action_space.n)
    obs = env.reset()
    try:
        while env.viewer.isopen:
            # for limiting FPS
            frame_start = time.time()
            # query the wrapped model for an action distribution over discrete
            # actions (its other outputs are ignored)
            torch_obs = torch.from_numpy(obs)
            with torch.no_grad():
                (pi_torch, ), _ = ft_wrapper(torch_obs[None], None, None)
                pi = pi_torch.cpu().numpy()
            if det_pol:
                action = np.argmax(pi)
            else:
                # renormalise: np.random.choice rejects probability vectors
                # that don't sum (almost exactly) to 1
                pi = pi / pi.sum()
                action = np.random.choice(act_range, p=pi)
            obs, rew, done, info = env.step(action)
            obs = np.asarray(obs)
            env.render(mode='human')
            if done:
                print(f"Done, score {info['eval_score']:.4g}/1.0")
                obs = env.reset()
            elapsed = time.time() - frame_start
            if elapsed < spf:
                time.sleep(spf - elapsed)
    finally:
        env.viewer.close()
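A self-contained sketch of the action-selection branch above: given a probability vector from the policy, either take the argmax (deterministic) or renormalise and sample from it (stochastic):

import numpy as np

def choose_action(pi, det_pol=False):
    """pi: 1-D array of (possibly slightly unnormalised) action probabilities."""
    if det_pol:
        return int(np.argmax(pi))
    # renormalise because np.random.choice rejects vectors that don't sum to 1
    pi = pi / pi.sum()
    return int(np.random.choice(len(pi), p=pi))

probs = np.array([0.1, 0.2, 0.69999])  # slightly off from summing to 1
print(choose_action(probs), choose_action(probs, det_pol=True))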
Example #6
def main(ntraj, env_name, seed):
    """Simple script to benchmark the performance of a particular environment:
    roll out `ntraj` trajectories under cProfile and write the profile to
    disk."""
    magical.register_envs()
    env = gym.make(env_name)
    env.seed(seed)
    env.action_space.seed(seed)

    dtime = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    out_filename = f'profile-{env_name.lower()}-{dtime}.cprofile'
    print(f"Will write profile to '{out_filename}'")

    cProfile.runctx('do_eval(env, ntraj)',
                    globals(),
                    locals(),
                    filename=out_filename)

    print("Done")
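The .cprofile dump written above can be inspected afterwards with the standard-library pstats module; a minimal sketch (the filename is just an example of the pattern used above):

import pstats

# load the dump written by cProfile.runctx and show the 20 most expensive
# calls by cumulative time
stats = pstats.Stats('profile-movetocorner-demo-v0-2024-01-01T00:00:00.cprofile')
stats.strip_dirs().sort_stats('cumulative').print_stats(20)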
Example #7
def main(model_dirs, out_root, nprocs, job_ngpus):
    """Load the latest snapshot in each of the given model directories (if
    any), and render a video for each test variant corresponding to the train
    variants in the model, with one Ray task per algorithm group."""
    register_envs()

    ray.init(num_cpus=nprocs)
    remote_process_dirs = ray.remote(num_gpus=job_ngpus)(process_directories)

    grouped_dirs = collections.OrderedDict()
    for model_dir in sorted(set(model_dirs)):
        base_dir = os.path.basename(model_dir.strip('/'))
        parsed_alg = parse_run_dir(base_dir)
        grouped_dirs.setdefault(parsed_alg, []).append(model_dir)

    handles = []
    for group_alg, group_dirs in grouped_dirs.items():
        out_dir = os.path.join(out_root, group_alg)
        print(f"Writing videos for group '{group_alg}' ({len(group_dirs)} "
              f"directories) to '{out_dir}'")
        remote_handle = remote_process_dirs.remote(group_dirs, out_dir)
        handles.append(remote_handle)
    ray.get(handles)
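The Ray pattern above in isolation: ray.remote is used as a plain function wrapper (rather than a decorator) so the per-task GPU requirement can be chosen at runtime, and the resulting handles are gathered with ray.get. A minimal self-contained sketch with a stand-in worker function:

import ray

def render_group(model_dirs, out_dir):
    # stand-in for process_directories(); just echoes its inputs
    return out_dir, len(model_dirs)

ray.init(num_cpus=2)
# wrap an ordinary function as a Ray task (num_gpus=0 here so the sketch runs
# on CPU-only machines; the code above passes job_ngpus instead)
remote_render = ray.remote(num_gpus=0)(render_group)

handles = [remote_render.remote(['run-a', 'run-b'], 'videos/algA'),
           remote_render.remote(['run-c'], 'videos/algB')]
print(ray.get(handles))  # blocks until both tasks finish
ray.shutdown()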
Example #8
def train(demos, add_preproc, seed, batch_size, total_n_batches,
          eval_every_n_batches, out_dir, run_name, gpu_idx, cpu_list,
          eval_n_traj, snapshot_gap, omit_noop, net_width_mul, net_use_bn,
          net_dropout, net_coord_conv, net_attention, net_task_spec_layers,
          load_policy, aug_mode, min_bc):

    # TODO: abstract setup code. Seeds & GPUs should go in one function. Env
    # setup should go in another function (or maybe the same function). Dataset
    # loading should be simplified by having a single class that can provide
    # whatever form of data the current IL method needs, without having to do
    # unnecessary copies in memory. Maybe also just use Sacred, because YOLO.

    with contextlib.ExitStack() as exit_stack:
        # set up seeds & devices
        set_seeds(seed)
        mp.set_start_method('spawn')
        use_gpu = gpu_idx is not None and torch.cuda.is_available()
        dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
        print(f"Using device {dev}, seed {seed}")
        if cpu_list is None:
            cpu_list = sample_cpu_list()
        affinity = dict(
            cuda_idx=gpu_idx if use_gpu else None,
            workers_cpus=cpu_list,
        )

        # register original envs
        import magical
        magical.register_envs()

        # TODO: split out part of the dataset for validation.
        demos_metas_dict = get_demos_meta(demo_paths=demos,
                                          omit_noop=omit_noop,
                                          transfer_variants=[],
                                          preproc_name=add_preproc)
        dataset_mt = demos_metas_dict['dataset_mt']
        loader_mt = make_loader_mt(dataset_mt, batch_size)
        variant_groups = demos_metas_dict['variant_groups']
        env_metas = demos_metas_dict['env_metas']
        num_demo_sources = demos_metas_dict['num_demo_sources']
        task_ids_and_demo_env_names = demos_metas_dict[
            'task_ids_and_demo_env_names']
        sampler_batch_B = batch_size
        # batch_T doesn't matter much here; the sampler is only used for
        # evaluation rollouts
        sampler_batch_T = 5
        sampler, sampler_batch_B = make_mux_sampler(
            variant_groups=variant_groups,
            num_demo_sources=num_demo_sources,
            env_metas=env_metas,
            use_gpu=use_gpu,
            batch_B=sampler_batch_B,
            batch_T=sampler_batch_T,
            # TODO: instead of doing this, try sampling in proportion to length
            # of horizon; that should get more samples from harder envs
            task_var_weights=None)
        if load_policy is not None:
            try:
                pol_path = get_latest_path(load_policy)
            except ValueError:
                pol_path = load_policy
            policy_ctor = functools.partial(
                adapt_pol_loader,
                pol_path=pol_path,
                task_ids_and_demo_env_names=task_ids_and_demo_env_names)
            policy_kwargs = {}
        else:
            policy_kwargs = {
                'env_ids_and_names': task_ids_and_demo_env_names,
                'width': net_width_mul,
                'use_bn': net_use_bn,
                'dropout': net_dropout,
                'coord_conv': net_coord_conv,
                'attention': net_attention,
                'n_task_spec_layers': net_task_spec_layers,
                **get_policy_spec_magical(env_metas),
            }
            policy_ctor = MultiHeadPolicyNet
        agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))
        sampler.initialize(agent=agent,
                           seed=np.random.randint(1 << 31),
                           affinity=affinity)
        exit_stack.callback(lambda: sampler.shutdown())

        model_mt = policy_ctor(**policy_kwargs).to(dev)
        if min_bc:
            num_tasks = len(task_ids_and_demo_env_names)
            weight_mod = MinBCWeightingModule(num_tasks, num_demo_sources) \
                .to(dev)
            all_params = it.chain(model_mt.parameters(),
                                  weight_mod.parameters())
        else:
            weight_mod = None
            all_params = model_mt.parameters()
        # Adam mostly works fine, but in very loose informal tests SGD seemed
        # to have fewer weird failures where the mean loss would jump up by a
        # factor of ~2x for a period. (I don't think that was solely due to a
        # high LR; probably an architectural issue.) The Adam alternative was:
        #   opt_mt = torch.optim.Adam(model_mt.parameters(), lr=3e-4)
        opt_mt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

        try:
            aug_opts = MILBenchAugmentations.PRESETS[aug_mode]
        except KeyError:
            raise ValueError(f"unsupported mode '{aug_mode}'")
        if aug_opts:
            print("Augmentations:", ", ".join(aug_opts))
            aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
                .to(dev)
        else:
            print("No augmentations")
            aug_model = None

        n_uniq_envs = len(task_ids_and_demo_env_names)
        log_params = {
            'n_uniq_envs': n_uniq_envs,
            'n_demos': len(demos),
            'net_use_bn': net_use_bn,
            'net_width_mul': net_width_mul,
            'net_dropout': net_dropout,
            'net_coord_conv': net_coord_conv,
            'net_attention': net_attention,
            'aug_mode': aug_mode,
            'seed': seed,
            'omit_noop': omit_noop,
            'batch_size': batch_size,
            'eval_n_traj': eval_n_traj,
            'eval_every_n_batches': eval_every_n_batches,
            'total_n_batches': total_n_batches,
            'snapshot_gap': snapshot_gap,
            'add_preproc': add_preproc,
            'net_task_spec_layers': net_task_spec_layers,
        }
        with make_logger_ctx(out_dir,
                             "mtbc",
                             f"mt{n_uniq_envs}",
                             run_name,
                             snapshot_gap=snapshot_gap,
                             log_params=log_params):
            # initial save
            torch.save(
                model_mt,
                os.path.join(logger.get_snapshot_dir(), 'full_model.pt'))

            # train for a while
            assert eval_every_n_batches > 0
            n_batches_done = 0
            n_rounds = int(np.ceil(total_n_batches / eval_every_n_batches))
            rnd = 1
            while n_batches_done < total_n_batches:
                batches_left_now = min(total_n_batches - n_batches_done,
                                       eval_every_n_batches)
                print(f"Done {n_batches_done}/{total_n_batches} "
                      f"({n_batches_done/total_n_batches*100:.2f}%, "
                      f"{rnd}/{n_rounds} rounds) batches; doing another "
                      f"{batches_left_now}")
                model_mt.train()
                loss_ewma, losses, per_task_losses = do_training_mt(
                    loader=loader_mt,
                    model=model_mt,
                    opt=opt_mt,
                    dev=dev,
                    aug_model=aug_model,
                    min_bc_module=weight_mod,
                    n_batches=batches_left_now)

                # TODO: record accuracy on a random subset of the train and
                # validation sets (both in eval mode, not train mode)

                print(f"Evaluating {eval_n_traj} trajectories on "
                      f"{variant_groups.num_tasks} tasks")
                record_misc_calls = []
                model_mt.eval()

                copy_model_into_agent_eval(model_mt, sampler.agent)
                scores_by_tv = eval_model(
                    sampler,
                    # shouldn't be any exploration
                    itr=0,
                    n_traj=eval_n_traj)
                for (task_id, variant_id), scores in scores_by_tv.items():
                    tv_id = (task_id, variant_id)
                    env_name = variant_groups.env_name_by_task_variant[tv_id]
                    tag = make_env_tag(strip_mb_preproc_name(env_name))
                    logger.record_tabular_misc_stat("Score%s" % tag, scores)
                    env_losses = per_task_losses.get(tv_id, [])
                    record_misc_calls.append((f"Loss{tag}", env_losses))

                # we record score AFTER loss so that losses are all in one
                # place, and scores are all in another
                for args in record_misc_calls:
                    logger.record_tabular_misc_stat(*args)

                # finish logging for this epoch
                logger.record_tabular("Round", rnd)
                logger.record_tabular("LossEWMA", loss_ewma)
                logger.record_tabular_misc_stat("Loss", losses)
                logger.dump_tabular()
                logger.save_itr_params(
                    rnd, {
                        'model_state': model_mt.state_dict(),
                        'opt_state': opt_mt.state_dict(),
                    })

                # advance ctrs
                rnd += 1
                n_batches_done += batches_left_now
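The optimiser setup in train() chains parameters from two modules (the policy and the MinBC weighting module) into a single SGD instance. A self-contained sketch of that pattern with dummy modules (the module definitions here are illustrative, not from the original code):

import itertools as it

import torch
import torch.nn as nn

policy = nn.Linear(8, 4)       # stands in for model_mt
weight_mod = nn.Linear(4, 1)   # stands in for MinBCWeightingModule

# one optimiser stepping both modules' parameters
all_params = it.chain(policy.parameters(), weight_mod.parameters())
opt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

loss = policy(torch.randn(2, 8)).sum() + weight_mod(torch.randn(2, 4)).sum()
opt.zero_grad()
loss.backward()
opt.step()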
Example #9
def __init__(self, env_name, **kwargs):
    register_envs()
    env = gym.make(env_name)
    super().__init__(env, **kwargs)
Example #10
"""Test rollouts in every environment."""
import gym
import pytest

import magical

N_ROLLOUTS = 2

magical.register_envs()


def test_registered_envs():
    # make sure we registered at least some environments
    assert len(magical.ALL_REGISTERED_ENVS) > 8


@pytest.mark.parametrize('env_name', magical.ALL_REGISTERED_ENVS)
def test_rollouts(env_name):
    """Simple smoke test to make sure environments can roll out trajectories of
    the right length."""
    env = gym.make(env_name)
    try:
        env.seed(7)
        env.action_space.seed(42)
        env.reset()
        for _ in range(N_ROLLOUTS):
            done = False
            traj_len = 0
            while not done:
                action = env.action_space.sample()
                obs, rew, done, info = env.step(action)
                traj_len += 1
    finally:
        env.close()
Example #11
def main(
        demos,
        add_preproc,
        seed,
        sampler_batch_B,
        sampler_batch_T,
        disc_batch_size,
        out_dir,
        run_name,
        gpu_idx,
        disc_up_per_iter,
        total_n_steps,
        log_interval_steps,
        cpu_list,
        snapshot_gap,
        load_policy,
        bc_loss,
        omit_noop,
        disc_replay_mult,
        disc_aug,
        ppo_aug,
        disc_use_bn,
        disc_net_attn,
        disc_use_sn,
        disc_gp_weight,
        disc_al,
        disc_al_dim,
        disc_al_nsamples,
        disc_ae_pretrain_iters,
        wgan,
        transfer_variants,
        transfer_disc_loss_weight,
        transfer_pol_loss_weight,
        transfer_disc_anneal,
        transfer_pol_batch_weight,
        danger_debug_reward_weight,
        danger_override_env_name,
        # new sweep hyperparams:
        disc_lr,
        disc_use_act,
        disc_all_frames,
        ppo_lr,
        ppo_gamma,
        ppo_lambda,
        ppo_ent,
        ppo_adv_clip,
        ppo_norm_adv,
        ppo_use_bn,
        ppo_minibatches,
        ppo_epochs):
    # set up seeds & devices
    # TODO: also seed child envs, when rlpyt supports it
    set_seeds(seed)
    # 'spawn' is necessary to use GL envs in subprocesses. For whatever reason
    # they don't play nice after a fork. (But what about set_seeds() in
    # subprocesses? May need to hack CpuSampler and GpuSampler.)
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
    if cpu_list is None:
        cpu_list = sample_cpu_list()
    # FIXME: I suspect current solution will set torch_num_threads suboptimally
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)
    print(f"Using device {dev}, seed {seed}, affinity {affinity}")

    # register original envs
    import magical
    magical.register_envs()

    if danger_override_env_name:
        raise NotImplementedError(
            "haven't re-implemented env name override for multi-task GAIL")

    demos_metas_dict = get_demos_meta(demo_paths=demos,
                                      omit_noop=omit_noop,
                                      transfer_variants=transfer_variants,
                                      preproc_name=add_preproc)
    dataset_mt = demos_metas_dict['dataset_mt']
    variant_groups = demos_metas_dict['variant_groups']
    env_metas = demos_metas_dict['env_metas']
    task_ids_and_demo_env_names = demos_metas_dict[
        'task_ids_and_demo_env_names']
    task_var_weights = {
        (task, variant): 1.0 if variant == 0 else transfer_pol_batch_weight
        for task, variant in variant_groups.env_name_by_task_variant
    }
    sampler, sampler_batch_B = make_mux_sampler(
        variant_groups=variant_groups,
        task_var_weights=task_var_weights,
        env_metas=env_metas,
        use_gpu=use_gpu,
        num_demo_sources=0,  # not important for now
        batch_B=sampler_batch_B,
        batch_T=sampler_batch_T)

    policy_kwargs = {
        'use_bn': ppo_use_bn,
        'env_ids_and_names': task_ids_and_demo_env_names,
        **get_policy_spec_magical(env_metas),
    }
    policy_ctor = MultiHeadPolicyNet
    ppo_agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))

    print("Setting up discriminator/reward model")
    disc_fc_dim = 256
    disc_final_feats_dim = disc_al_dim if disc_al else disc_fc_dim
    discriminator_mt = MILBenchDiscriminatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        in_chans=policy_kwargs['in_chans'],
        act_dim=policy_kwargs['n_actions'],
        use_all_chans=disc_all_frames,
        use_actions=disc_use_act,
        # can supply any argument that goes to MILBenchFeatureNetwork (e.g.
        # dropout, use_bn, width, etc.)
        attention=disc_net_attn,
        use_bn=disc_use_bn,
        use_sn=disc_use_sn,
        fc_dim=disc_fc_dim,
        final_feats_dim=disc_final_feats_dim,
    ).to(dev)

    if (not transfer_variants
            and (transfer_disc_loss_weight or transfer_pol_loss_weight)):
        print("No transfer variants supplied; setting transfer disc/policy "
              "loss weights to zero")
        transfer_disc_loss_weight = 0.0
        transfer_pol_loss_weight = 0.0
    if transfer_pol_loss_weight > 0:
        assert transfer_disc_loss_weight > 0
    if transfer_variants and transfer_disc_loss_weight:
        xfer_adv_module = BinaryDomainLossModule(
            discriminator_mt.ret_feats_dim).to(dev)
    else:
        xfer_adv_module = None

    reward_model_mt = RewardModel(
        discriminator_mt,
        xfer_adv_module,
        transfer_pol_loss_weight,
        # In apprenticeship learning we can just pass
        # the model outputs straight through, just
        # like in WGAN.
        use_wgan=wgan or disc_al).to(dev)
    reward_evaluator_mt = RewardEvaluatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        reward_model=reward_model_mt,
        obs_dims=3,
        batch_size=disc_batch_size,
        normalise=True,
        # I think I had rewards in [0,0.01] in
        # the PPO run that I got to run with a
        # manually-defined reward.
        target_std=0.01)

    ppo_hyperparams = dict(
        learning_rate=ppo_lr,
        discount=ppo_gamma,
        entropy_loss_coeff=ppo_ent,  # was working at 0.003 and 0.001
        gae_lambda=ppo_lambda,
        ratio_clip=ppo_adv_clip,
        minibatches=ppo_minibatches,
        epochs=ppo_epochs,
        value_loss_coeff=1.0,
        clip_grad_norm=1.0,
        normalize_advantage=ppo_norm_adv,
    )
    if bc_loss:
        # TODO: make this batch size configurable
        ppo_loader_mt = make_loader_mt(
            dataset_mt, max(16, min(64, sampler_batch_T * sampler_batch_B)))
    else:
        ppo_loader_mt = None

    # FIXME: abstract code for constructing augmentation model from presets
    try:
        ppo_aug_opts = MILBenchAugmentations.PRESETS[ppo_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{ppo_aug}'")
    if ppo_aug_opts:
        print("Policy augmentations:", ", ".join(ppo_aug_opts))
        ppo_aug_model = MILBenchAugmentations(
            **{k: True
               for k in ppo_aug_opts}).to(dev)
    else:
        print("No policy augmentations")
        ppo_aug_model = None

    ppo_algo = BCCustomRewardPPO(bc_loss_coeff=bc_loss,
                                 expert_traj_loader=ppo_loader_mt,
                                 true_reward_weight=danger_debug_reward_weight,
                                 aug_model=ppo_aug_model,
                                 **ppo_hyperparams)
    ppo_algo.set_reward_evaluator(reward_evaluator_mt)

    print("Setting up optimiser")
    try:
        aug_opts = MILBenchAugmentations.PRESETS[disc_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{disc_aug}'")
    if aug_opts:
        print("Discriminator augmentations:", ", ".join(aug_opts))
        aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
            .to(dev)
    else:
        print("No discriminator augmentations")
        aug_model = None
    gail_optim = GAILOptimiser(
        dataset_mt=dataset_mt,
        discrim_model=discriminator_mt,
        buffer_num_samples=max(
            disc_batch_size,
            disc_replay_mult * sampler_batch_T * sampler_batch_B),
        batch_size=disc_batch_size,
        updates_per_itr=disc_up_per_iter,
        gp_weight=disc_gp_weight,
        dev=dev,
        aug_model=aug_model,
        lr=disc_lr,
        xfer_adv_weight=transfer_disc_loss_weight,
        xfer_adv_anneal=transfer_disc_anneal,
        xfer_adv_module=xfer_adv_module,
        final_layer_only_mode=disc_al,
        final_layer_only_mode_n_samples=disc_al_nsamples,
        use_wgan=wgan)

    if disc_ae_pretrain_iters:
        # FIXME(sam): pass n_acts, obs_chans, lr to AETrainer
        ae_trainer = AETrainer(discriminator=discriminator_mt,
                               disc_out_size=disc_final_feats_dim,
                               data_batch_iter=gail_optim.expert_batch_iter,
                               dev=dev)

    print("Setting up RL algorithm")
    # signature for arg: reward_model(obs_tensor, act_tensor) -> rewards
    runner = GAILMinibatchRl(
        seed=seed,
        gail_optim=gail_optim,
        variant_groups=variant_groups,
        algo=ppo_algo,
        agent=ppo_agent,
        sampler=sampler,
        # n_steps controls total number of environment steps we take
        n_steps=total_n_steps,
        # log_interval_steps controls how many environment steps we take
        # between making log outputs (doing N environment steps takes roughly
        # the same amount of time no matter what batch_B, batch_T, etc. are, so
        # this gives us a fairly constant interval between log outputs)
        log_interval_steps=log_interval_steps,
        affinity=affinity)

    # TODO: factor out this callback
    def init_policy_cb(runner):
        """Callback which gets called once after Runner startup to save an
        initial policy model, and optionally load saved parameters."""
        # get state of the newly-initialised model
        wrapped_model = runner.algo.agent.model
        assert wrapped_model is not None, "has ppo_agent been initialised?"
        unwrapped_model = wrapped_model.model

        if load_policy:
            print(f"Loading policy from '{load_policy}'")
            saved_model = load_state_dict_or_model(load_policy)
            saved_dict = saved_model.state_dict()
            unwrapped_model.load_state_dict(saved_dict)

        real_state = unwrapped_model.state_dict()

        # make a clone model so we can pickle it, and copy across weights
        policy_copy_mt = policy_ctor(**policy_kwargs).to('cpu')
        policy_copy_mt.load_state_dict(real_state)

        # save it here
        init_pol_snapshot_path = os.path.join(logger.get_snapshot_dir(),
                                              'full_model.pt')
        torch.save(policy_copy_mt, init_pol_snapshot_path)

    print("Training!")
    n_uniq_envs = variant_groups.num_tasks
    log_params = {
        'add_preproc': add_preproc,
        'seed': seed,
        'sampler_batch_T': sampler_batch_T,
        'sampler_batch_B': sampler_batch_B,
        'disc_batch_size': disc_batch_size,
        'disc_up_per_iter': disc_up_per_iter,
        'total_n_steps': total_n_steps,
        'bc_loss': bc_loss,
        'omit_noop': omit_noop,
        'disc_aug': disc_aug,
        'danger_debug_reward_weight': danger_debug_reward_weight,
        'disc_lr': disc_lr,
        'disc_use_act': disc_use_act,
        'disc_all_frames': disc_all_frames,
        'disc_net_attn': disc_net_attn,
        'disc_use_bn': disc_use_bn,
        'ppo_lr': ppo_lr,
        'ppo_gamma': ppo_gamma,
        'ppo_lambda': ppo_lambda,
        'ppo_ent': ppo_ent,
        'ppo_adv_clip': ppo_adv_clip,
        'ppo_norm_adv': ppo_norm_adv,
        'transfer_variants': transfer_variants,
        'transfer_pol_batch_weight': transfer_pol_batch_weight,
        'transfer_pol_loss_weight': transfer_pol_loss_weight,
        'transfer_disc_loss_weight': transfer_disc_loss_weight,
        'transfer_disc_anneal': transfer_disc_anneal,
        'ndemos': len(demos),
        'n_uniq_envs': n_uniq_envs,
    }
    with make_logger_ctx(out_dir,
                         "mtgail",
                         f"mt{n_uniq_envs}",
                         run_name,
                         snapshot_gap=snapshot_gap,
                         log_params=log_params):
        torch.save(
            discriminator_mt,
            os.path.join(logger.get_snapshot_dir(), 'full_discrim_model.pt'))

        if disc_ae_pretrain_iters:
            # FIXME(sam): come up with a better solution for creating these
            # montages (can I do it regularly? Should I put them somewhere
            # other than the snapshot dir?).
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-before.png'))
            ae_trainer.do_full_training(disc_ae_pretrain_iters)
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-after.png'))

        # note that periodic snapshots get saved by GAILMinibatchRl, thanks to
        # the overridden get_itr_snapshot() method
        runner.train(cb_startup=init_policy_cb)
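The preset-lookup-then-construct pattern for MILBenchAugmentations appears three times across these examples (policy and discriminator augmentations here, BC augmentations in Example #8), and the FIXME above asks for it to be abstracted. A sketch of such a helper under the same assumptions about MILBenchAugmentations.PRESETS; make_aug_model is a hypothetical name:

def make_aug_model(aug_mode, dev, label="augmentations"):
    """Hypothetical helper: turn a preset name into a MILBenchAugmentations
    module on `dev`, or None when the preset enables nothing."""
    try:
        opts = MILBenchAugmentations.PRESETS[aug_mode]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{aug_mode}'")
    if not opts:
        print(f"No {label}")
        return None
    print(f"{label.capitalize()}:", ", ".join(opts))
    return MILBenchAugmentations(**{k: True for k in opts}).to(dev)

# e.g. aug_model = make_aug_model(disc_aug, dev, label="discriminator augmentations")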