Пример #1
0
    def obtain_scores(self, env_name):
        """Evaluate the saved policy on ``env_name`` and return its scores.

        A sampler (``GpuSampler`` when ``self.gpu_idx`` is set, otherwise
        ``CpuSampler``) is built around ``MILBenchGymEnv`` with ``batch_T``
        equal to the environment's ``max_episode_steps``. The policy is
        loaded through ``saved_model_loader_ft`` from
        ``self.state_dict_or_model_path``, and the sampler is handed to
        ``eval_model_st`` together with ``self.n_rollouts``.

        The sampler is always shut down once it has been initialised, even
        if moving the agent to its device or the evaluation itself fails.

        Args:
            env_name: Gym ID of the environment to evaluate on.

        Returns:
            Whatever ``eval_model_st`` returns.
        """
        print(f"Testing on {env_name}")

        use_gpu = self.gpu_idx is not None
        SamplerCls = GpuSampler if use_gpu else CpuSampler

        env_ctor = MILBenchGymEnv
        env_ctor_kwargs = dict(env_name=env_name)

        # Throwaway env, created only to read the episode step limit. Close
        # it even if the spec lookup raises.
        env = gym.make(env_name)
        try:
            max_steps = env.spec.max_episode_steps
        finally:
            env.close()
        del env

        env_sampler = SamplerCls(
            env_ctor,
            env_ctor_kwargs,
            batch_T=max_steps,
            # don't decorrelate: random warm-up steps would distort the
            # scores
            max_decorrelation_steps=0,
            batch_B=min(self.n_rollouts, self.batch_size))
        env_agent = CategoricalPgAgent(
            ModelCls=saved_model_loader_ft,
            model_kwargs=dict(
                state_dict_or_model_path=self.state_dict_or_model_path,
                env_name=self.demo_env_name))
        env_sampler.initialize(env_agent,
                               seed=self.seed,
                               affinity=self.affinity)
        # Everything after initialize() sits inside the try block so that a
        # failure in device placement cannot leave the sampler running.
        try:
            dev = torch.device(f"cuda:{self.gpu_idx}" if use_gpu else "cpu")
            env_agent.to_device(dev.index if use_gpu else None)
            scores = eval_model_st(env_sampler, 0, self.n_rollouts)
        finally:
            env_sampler.shutdown()

        return scores
Пример #2
0
def simulateAgentFile(agentFile, render=False):
    """Load an rlpyt agent's weights from ``agentFile`` and simulate it.

    The state dict is loaded onto the CPU, a ``CategoricalPgAgent`` wrapping
    ``AcrobotNet`` is initialised with the action/observation spaces of
    ``Acrobot-v1``, the weights are restored, and the agent is passed to
    ``simulateAgent``.

    Args:
        agentFile: Path (or file-like object) accepted by ``torch.load``.
        render: Forwarded unchanged to ``simulateAgent``.
    """
    # NOTE(review): torch.load unpickles its input; only use it on trusted
    # checkpoint files.
    state_dict = torch.load(
        agentFile,
        map_location=torch.device('cpu'))
    agent = CategoricalPgAgent(AcrobotNet)
    EnvSpace = namedtuple('EnvSpace', ['action', 'observation'])
    # The env is needed only for its spaces; close it instead of leaking it.
    env = gym.make('Acrobot-v1')
    try:
        agent.initialize(EnvSpace(env.action_space, env.observation_space))
    finally:
        env.close()
    agent.load_state_dict(state_dict)
    simulateAgent(agent, render)
Пример #3
0
def findOptimalAgent(reward, run_ID=0):
    """
    Train an agent with rlpyt's A2C on the MDP described by the Config
    module ``C``, using ``reward`` as a custom reward function.

    The training environments receive ``reward``; the evaluation
    environments are built from ``C.ENV`` alone. The trained
    ``CategoricalPgAgent`` is returned.
    """
    worker_cpus = [cpu for cpu in range(C.N_PARALLEL)]
    affinity = {'cuda_idx': C.CUDA_IDX, 'workers_cpus': worker_cpus}

    train_env_kwargs = {'id': C.ENV, 'reward': reward}
    eval_env_kwargs = {'id': C.ENV}
    sampler = SerialSampler(
        EnvCls=rlpyt_make,
        env_kwargs=train_env_kwargs,
        batch_T=C.BATCH_T,
        batch_B=C.BATCH_B,
        max_decorrelation_steps=400,
        eval_env_kwargs=eval_env_kwargs,
        eval_n_envs=5,
        eval_max_steps=2500,
    )

    algo = A2C(
        discount=C.DISCOUNT,
        learning_rate=C.LR,
        value_loss_coeff=C.VALUE_LOSS_COEFF,
        entropy_loss_coeff=C.ENTROPY_LOSS_COEFF,
    )
    agent = CategoricalPgAgent(AcrobotNet)
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         n_steps=C.N_STEPS,
                         log_interval_steps=C.LOG_STEP,
                         affinity=affinity)

    # The run name doubles as the log directory.
    run_name = "a2c_" + C.ENV.lower()
    logging_ctx = logger_context(run_name,
                                 run_ID,
                                 run_name,
                                 snapshot_mode='last',
                                 override_prefix=True)
    with logging_ctx:
        runner.train()
    return agent
Пример #4
0
def build_and_train(game="pong", run_ID=0, cuda_idx=None):
    """Train a PPO agent on ``CanvasEnv`` with a serial sampler.

    Args:
        game: Only used as a suffix of the logged run name.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: CUDA device index placed in the runner's affinity
            (``None`` leaves it unset).
    """
    sampler_kwargs = dict(
        EnvCls=CanvasEnv,
        env_kwargs=dict(),
        batch_T=1,  # One time-step per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    sampler = SerialSampler(**sampler_kwargs)

    algo = PPO()

    model_kwargs = dict(
        image_shape=(1, CANVAS_WIDTH, CANVAS_WIDTH),
        output_size=N_ACTIONS,
    )
    agent = CategoricalPgAgent(ModelCls=MyModel,
                               model_kwargs=model_kwargs,
                               initial_model_state_dict=None)

    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         n_steps=50e6,
                         log_interval_steps=1e3,
                         affinity=dict(cuda_idx=cuda_idx))

    # NOTE(review): the run name keeps its historical "dqn_" prefix even
    # though the algorithm here is PPO.
    run_name = "dqn_" + game
    log_config = dict()
    with logger_context("example_1", run_ID, run_name, log_config,
                        snapshot_mode="last"):
        runner.train()
Пример #5
0
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Build a config-driven PPO experiment and train it.

    The base config ``configs[config_key]`` is updated with the variant
    loaded from ``log_dir``. Its ``"env"``, ``"eval_env"``, ``"sampler"``,
    ``"optim"``, ``"algo"``, ``"model"``, ``"agent"`` and ``"runner"``
    entries are forwarded to the matching components, and a truthy
    ``"checkpoint"`` entry is loaded as the initial model state dict.

    Args:
        slot_affinity_code: Code decoded by ``affinity_from_code``.
        log_dir: Directory holding the variant; also the logging directory.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key of the base configuration in ``configs``.
    """
    affinity = affinity_from_code(slot_affinity_code)
    config = update_config(configs[config_key], load_variant(log_dir))

    sampler = GpuSampler(EnvCls=gym.make,
                         env_kwargs=config["env"],
                         CollectorCls=GpuResetCollector,
                         eval_env_kwargs=config["eval_env"],
                         **config["sampler"])

    # Start from a checkpoint only when one is configured.
    model_state_dict = (torch.load(config["checkpoint"])
                        if config["checkpoint"] else None)

    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = CategoricalPgAgent(ModelCls=BaselinePolicy,
                               model_kwargs=config["model"],
                               initial_model_state_dict=model_state_dict,
                               **config["agent"])
    runner = MinibatchRlEval(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             affinity=affinity,
                             **config["runner"])

    # The environment ID serves as the run name.
    run_name = config["env"]["id"]
    logging_ctx = logger_context(log_dir, run_ID, run_name, config,
                                 snapshot_mode='last')
    with logging_ctx:
        runner.train()
Пример #6
0
def train(demos, add_preproc, seed, batch_size, total_n_batches,
          eval_every_n_batches, out_dir, run_name, gpu_idx, cpu_list,
          eval_n_traj, snapshot_gap, omit_noop, net_width_mul, net_use_bn,
          net_dropout, net_coord_conv, net_attention, net_task_spec_layers,
          load_policy, aug_mode, min_bc):
    """Train a multi-task policy on demonstrations, evaluating periodically.

    Training runs in rounds of at most ``eval_every_n_batches`` batches
    (via ``do_training_mt``) until ``total_n_batches`` have been processed.
    After each round the model is copied into the sampler's agent,
    ``eval_n_traj`` trajectories are evaluated with ``eval_model``,
    per-environment scores and losses are logged, and the model and
    optimiser state dicts are saved through the logger. The full initial
    model is also saved as ``full_model.pt`` in the snapshot directory.

    Args:
        demos: Demonstration paths, passed to ``get_demos_meta``.
        add_preproc: Preprocessor name, passed to ``get_demos_meta``.
        seed: Seed passed to ``set_seeds``.
        batch_size: Batch size for the demo loader; also the requested
            sampler ``batch_B``.
        total_n_batches: Total number of training batches.
        eval_every_n_batches: Batches per round; must be positive.
        out_dir, run_name, snapshot_gap: Passed to ``make_logger_ctx``.
        gpu_idx: CUDA device index, or ``None`` for CPU. The CPU is also
            used when CUDA is unavailable.
        cpu_list: Worker CPUs; ``sample_cpu_list()`` is used when ``None``.
        eval_n_traj: Number of trajectories per evaluation.
        omit_noop: Passed to ``get_demos_meta``.
        net_width_mul, net_use_bn, net_dropout, net_coord_conv,
        net_attention, net_task_spec_layers: ``MultiHeadPolicyNet``
            options, ignored when ``load_policy`` is given.
        load_policy: Optional policy path (or a location that
            ``get_latest_path`` can resolve) to start from.
        aug_mode: Key into ``MILBenchAugmentations.PRESETS``.
        min_bc: If truthy, also train a ``MinBCWeightingModule``.

    Raises:
        AssertionError: If ``eval_every_n_batches`` is not positive.
        ValueError: If ``aug_mode`` is not a known augmentation preset.
    """
    # Validate before any expensive setup. This check used to sit after the
    # n_rounds division below, so a zero value surfaced as a
    # ZeroDivisionError once the sampler and logger were already running.
    assert eval_every_n_batches > 0

    # TODO: abstract setup code. Seeds & GPUs should go in one function. Env
    # setup should go in another function (or maybe the same function). Dataset
    # loading should be simplified by having a single class that can provide
    # whatever form of data the current IL method needs, without having to do
    # unnecessary copies in memory. Maybe also just use Sacred.

    with contextlib.ExitStack() as exit_stack:
        # set up seeds & devices
        set_seeds(seed)
        mp.set_start_method('spawn')
        use_gpu = gpu_idx is not None and torch.cuda.is_available()
        dev = torch.device(f"cuda:{gpu_idx}" if use_gpu else "cpu")
        print(f"Using device {dev}, seed {seed}")
        if cpu_list is None:
            cpu_list = sample_cpu_list()
        affinity = dict(
            cuda_idx=gpu_idx if use_gpu else None,
            workers_cpus=cpu_list,
        )

        # register original envs
        import magical
        magical.register_envs()

        # TODO: split out part of the dataset for validation.
        demos_metas_dict = get_demos_meta(demo_paths=demos,
                                          omit_noop=omit_noop,
                                          transfer_variants=[],
                                          preproc_name=add_preproc)
        dataset_mt = demos_metas_dict['dataset_mt']
        loader_mt = make_loader_mt(dataset_mt, batch_size)
        variant_groups = demos_metas_dict['variant_groups']
        env_metas = demos_metas_dict['env_metas']
        num_demo_sources = demos_metas_dict['num_demo_sources']
        task_ids_and_demo_env_names = demos_metas_dict[
            'task_ids_and_demo_env_names']
        sampler_batch_B = batch_size
        # this doesn't really matter
        sampler_batch_T = 5
        sampler, sampler_batch_B = make_mux_sampler(
            variant_groups=variant_groups,
            num_demo_sources=num_demo_sources,
            env_metas=env_metas,
            use_gpu=use_gpu,
            batch_B=sampler_batch_B,
            batch_T=sampler_batch_T,
            # TODO: instead of doing this, try sampling in proportion to length
            # of horizon; that should get more samples from harder envs
            task_var_weights=None)
        if load_policy is not None:
            # load_policy may name a location that get_latest_path can
            # resolve; if that raises ValueError, use the path as given.
            try:
                pol_path = get_latest_path(load_policy)
            except ValueError:
                pol_path = load_policy
            policy_ctor = functools.partial(
                adapt_pol_loader,
                pol_path=pol_path,
                task_ids_and_demo_env_names=task_ids_and_demo_env_names)
            policy_kwargs = {}
        else:
            policy_kwargs = {
                'env_ids_and_names': task_ids_and_demo_env_names,
                'width': net_width_mul,
                'use_bn': net_use_bn,
                'dropout': net_dropout,
                'coord_conv': net_coord_conv,
                'attention': net_attention,
                'n_task_spec_layers': net_task_spec_layers,
                **get_policy_spec_magical(env_metas),
            }
            policy_ctor = MultiHeadPolicyNet
        agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))
        sampler.initialize(agent=agent,
                           seed=np.random.randint(1 << 31),
                           affinity=affinity)
        # From here on, the sampler is shut down however this block exits.
        exit_stack.callback(sampler.shutdown)

        # This is the model that gets trained; its weights are copied into
        # the sampler's agent before each evaluation.
        model_mt = policy_ctor(**policy_kwargs).to(dev)
        if min_bc:
            num_tasks = len(task_ids_and_demo_env_names)
            weight_mod = MinBCWeightingModule(num_tasks, num_demo_sources) \
                .to(dev)
            all_params = it.chain(model_mt.parameters(),
                                  weight_mod.parameters())
        else:
            weight_mod = None
            all_params = model_mt.parameters()
        # Adam mostly works fine, but in very loose informal tests it seems
        # like SGD had fewer weird failures where mean loss would jump up by a
        # factor of 2x for a period (?). (I don't think that was solely due to
        # high LR; probably an architectural issue.) The Adam alternative was:
        #   opt_mt = torch.optim.Adam(model_mt.parameters(), lr=3e-4)
        opt_mt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

        try:
            aug_opts = MILBenchAugmentations.PRESETS[aug_mode]
        except KeyError as err:
            raise ValueError(f"unsupported mode '{aug_mode}'") from err
        if aug_opts:
            print("Augmentations:", ", ".join(aug_opts))
            aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
                .to(dev)
        else:
            print("No augmentations")
            aug_model = None

        n_uniq_envs = len(task_ids_and_demo_env_names)
        log_params = {
            'n_uniq_envs': n_uniq_envs,
            'n_demos': len(demos),
            'net_use_bn': net_use_bn,
            'net_width_mul': net_width_mul,
            'net_dropout': net_dropout,
            'net_coord_conv': net_coord_conv,
            'net_attention': net_attention,
            'aug_mode': aug_mode,
            'seed': seed,
            'omit_noop': omit_noop,
            'batch_size': batch_size,
            'eval_n_traj': eval_n_traj,
            'eval_every_n_batches': eval_every_n_batches,
            'total_n_batches': total_n_batches,
            'snapshot_gap': snapshot_gap,
            'add_preproc': add_preproc,
            'net_task_spec_layers': net_task_spec_layers,
        }
        with make_logger_ctx(out_dir,
                             "mtbc",
                             f"mt{n_uniq_envs}",
                             run_name,
                             snapshot_gap=snapshot_gap,
                             log_params=log_params):
            # initial save
            torch.save(
                model_mt,
                os.path.join(logger.get_snapshot_dir(), 'full_model.pt'))

            # train for a while
            n_batches_done = 0
            n_rounds = int(np.ceil(total_n_batches / eval_every_n_batches))
            rnd = 1
            while n_batches_done < total_n_batches:
                # The final round may be shorter than the others.
                batches_left_now = min(total_n_batches - n_batches_done,
                                       eval_every_n_batches)
                print(f"Done {n_batches_done}/{total_n_batches} "
                      f"({n_batches_done/total_n_batches*100:.2f}%, "
                      f"{rnd}/{n_rounds} rounds) batches; doing another "
                      f"{batches_left_now}")
                model_mt.train()
                loss_ewma, losses, per_task_losses = do_training_mt(
                    loader=loader_mt,
                    model=model_mt,
                    opt=opt_mt,
                    dev=dev,
                    aug_model=aug_model,
                    min_bc_module=weight_mod,
                    n_batches=batches_left_now)

                # TODO: record accuracy on a random subset of the train and
                # validation sets (both in eval mode, not train mode)

                print(f"Evaluating {eval_n_traj} trajectories on "
                      f"{variant_groups.num_tasks} tasks")
                record_misc_calls = []
                model_mt.eval()

                copy_model_into_agent_eval(model_mt, sampler.agent)
                scores_by_tv = eval_model(
                    sampler,
                    # shouldn't be any exploration
                    itr=0,
                    n_traj=eval_n_traj)
                for (task_id, variant_id), scores in scores_by_tv.items():
                    tv_id = (task_id, variant_id)
                    env_name = variant_groups.env_name_by_task_variant[tv_id]
                    tag = make_env_tag(strip_mb_preproc_name(env_name))
                    logger.record_tabular_misc_stat("Score%s" % tag, scores)
                    env_losses = per_task_losses.get(tv_id, [])
                    record_misc_calls.append((f"Loss{tag}", env_losses))

                # Scores are recorded immediately in the loop above and the
                # per-env losses are deferred until here, so that losses are
                # all in one place and scores are all in another.
                for args in record_misc_calls:
                    logger.record_tabular_misc_stat(*args)

                # finish logging for this epoch
                logger.record_tabular("Round", rnd)
                logger.record_tabular("LossEWMA", loss_ewma)
                logger.record_tabular_misc_stat("Loss", losses)
                logger.dump_tabular()
                logger.save_itr_params(
                    rnd, {
                        'model_state': model_mt.state_dict(),
                        'opt_state': opt_mt.state_dict(),
                    })

                # advance ctrs
                rnd += 1
                n_batches_done += batches_left_now
Пример #7
0
def main(
        demos,
        add_preproc,
        seed,
        sampler_batch_B,
        sampler_batch_T,
        disc_batch_size,
        out_dir,
        run_name,
        gpu_idx,
        disc_up_per_iter,
        total_n_steps,
        log_interval_steps,
        cpu_list,
        snapshot_gap,
        load_policy,
        bc_loss,
        omit_noop,
        disc_replay_mult,
        disc_aug,
        ppo_aug,
        disc_use_bn,
        disc_net_attn,
        disc_use_sn,
        disc_gp_weight,
        disc_al,
        disc_al_dim,
        disc_al_nsamples,
        disc_ae_pretrain_iters,
        wgan,
        transfer_variants,
        transfer_disc_loss_weight,
        transfer_pol_loss_weight,
        transfer_disc_anneal,
        transfer_pol_batch_weight,
        danger_debug_reward_weight,
        danger_override_env_name,
        # new sweep hyperparams:
        disc_lr,
        disc_use_act,
        disc_all_frames,
        ppo_lr,
        ppo_gamma,
        ppo_lambda,
        ppo_ent,
        ppo_adv_clip,
        ppo_norm_adv,
        ppo_use_bn,
        ppo_minibatches,
        ppo_epochs):
    """Set up and run multi-task GAIL training with a PPO policy.

    In order, this function: seeds RNGs and selects the device; loads
    demonstration metadata; builds the multiplexed sampler and the PPO
    agent; builds the discriminator, the optional transfer (domain-loss)
    module, and the reward model/evaluator; configures ``BCCustomRewardPPO``
    and ``GAILOptimiser``; optionally creates an ``AETrainer`` for
    discriminator pretraining; and finally runs ``GAILMinibatchRl.train``
    inside a logger context.

    Artifacts written to the logger's snapshot directory by this function:
    ``full_discrim_model.pt`` (the initial discriminator), ``full_model.pt``
    (the initial policy, saved by the startup callback), and, when
    ``disc_ae_pretrain_iters`` is truthy, ``ae-before.png`` and
    ``ae-after.png``.

    Argument groups (see the body for the exact consumer of each value):
        demos, add_preproc, omit_noop, transfer_variants: forwarded to
            ``get_demos_meta``.
        seed, gpu_idx, cpu_list: seeding, device and CPU affinity. The CPU
            is used when ``gpu_idx`` is ``None`` or CUDA is unavailable;
            ``sample_cpu_list()`` supplies ``cpu_list`` when it is ``None``.
        sampler_batch_B, sampler_batch_T: requested sampler batch shape
            (``make_mux_sampler`` returns the ``batch_B`` actually used).
        disc_*, wgan: discriminator / ``GAILOptimiser`` options.
        ppo_*, bc_loss: PPO hyperparameters and the BC loss coefficient.
        transfer_*: weights for the transfer-variant losses and sampling.
        out_dir, run_name, snapshot_gap, total_n_steps, log_interval_steps:
            logging and run-length settings.
        load_policy: optional saved policy loaded into the model at startup.
        danger_debug_reward_weight: passed to the PPO algo as
            ``true_reward_weight``.
        danger_override_env_name: not supported; must be falsy.

    Raises:
        NotImplementedError: If ``danger_override_env_name`` is truthy.
        ValueError: If ``ppo_aug`` or ``disc_aug`` is not a key of
            ``MILBenchAugmentations.PRESETS``.
        AssertionError: If ``transfer_pol_loss_weight`` is positive while
            ``transfer_disc_loss_weight`` is not.
    """
    # set up seeds & devices
    # TODO: also seed child envs, when rlpyt supports it
    set_seeds(seed)
    # 'spawn' is necessary to use GL envs in subprocesses. For whatever reason
    # they don't play nice after a fork. (But what about set_seeds() in
    # subprocesses? May need to hack CpuSampler and GpuSampler.)
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    # indexing the two-element list with the bool selects the device string
    dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
    if cpu_list is None:
        cpu_list = sample_cpu_list()
    # FIXME: I suspect current solution will set torch_num_threads suboptimally
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)
    print(f"Using device {dev}, seed {seed}, affinity {affinity}")

    # register original envs
    import magical
    magical.register_envs()

    if danger_override_env_name:
        raise NotImplementedError(
            "haven't re-implemeneted env name override for multi-task GAIL")

    demos_metas_dict = get_demos_meta(demo_paths=demos,
                                      omit_noop=omit_noop,
                                      transfer_variants=transfer_variants,
                                      preproc_name=add_preproc)
    dataset_mt = demos_metas_dict['dataset_mt']
    variant_groups = demos_metas_dict['variant_groups']
    env_metas = demos_metas_dict['env_metas']
    task_ids_and_demo_env_names = demos_metas_dict[
        'task_ids_and_demo_env_names']
    # variant 0 gets weight 1.0; every other variant gets
    # transfer_pol_batch_weight (presumably variant 0 is the demo variant;
    # verify against get_demos_meta)
    task_var_weights = {
        (task, variant): 1.0 if variant == 0 else transfer_pol_batch_weight
        for task, variant in variant_groups.env_name_by_task_variant
    }
    # make_mux_sampler also returns the batch_B it settled on, which
    # replaces the requested value from here on
    sampler, sampler_batch_B = make_mux_sampler(
        variant_groups=variant_groups,
        task_var_weights=task_var_weights,
        env_metas=env_metas,
        use_gpu=use_gpu,
        num_demo_sources=0,  # not important for now
        batch_B=sampler_batch_B,
        batch_T=sampler_batch_T)

    policy_kwargs = {
        'use_bn': ppo_use_bn,
        'env_ids_and_names': task_ids_and_demo_env_names,
        **get_policy_spec_magical(env_metas),
    }
    policy_ctor = MultiHeadPolicyNet
    ppo_agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))

    print("Setting up discriminator/reward model")
    disc_fc_dim = 256
    # with disc_al enabled the final feature size comes from disc_al_dim,
    # otherwise it matches the FC width
    disc_final_feats_dim = disc_al_dim if disc_al else disc_fc_dim
    discriminator_mt = MILBenchDiscriminatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        in_chans=policy_kwargs['in_chans'],
        act_dim=policy_kwargs['n_actions'],
        use_all_chans=disc_all_frames,
        use_actions=disc_use_act,
        # can supply any argument that goes to MILBenchFeatureNetwork (e.g.
        # dropout, use_bn, width, etc.)
        attention=disc_net_attn,
        use_bn=disc_use_bn,
        use_sn=disc_use_sn,
        fc_dim=disc_fc_dim,
        final_feats_dim=disc_final_feats_dim,
    ).to(dev)

    # transfer losses are meaningless without transfer variants, so zero
    # them out rather than failing
    if (not transfer_variants
            and (transfer_disc_loss_weight or transfer_pol_loss_weight)):
        print("No xfer variants supplied, setting xfer disc loss term to zero")
        transfer_disc_loss_weight = 0.0
        transfer_pol_loss_weight = 0.0
    if transfer_pol_loss_weight > 0:
        assert transfer_disc_loss_weight > 0
    if transfer_variants and transfer_disc_loss_weight:
        xfer_adv_module = BinaryDomainLossModule(
            discriminator_mt.ret_feats_dim).to(dev)
    else:
        xfer_adv_module = None

    reward_model_mt = RewardModel(
        discriminator_mt,
        xfer_adv_module,
        transfer_pol_loss_weight,
        # In apprenticeship learning we can just pass
        # the model outputs straight through, just
        # like in WGAN.
        use_wgan=wgan or disc_al).to(dev)
    reward_evaluator_mt = RewardEvaluatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        reward_model=reward_model_mt,
        obs_dims=3,
        batch_size=disc_batch_size,
        normalise=True,
        # I think I had rewards in [0,0.01] in
        # the PPO run that I got to run with a
        # manually-defined reward.
        target_std=0.01)

    ppo_hyperparams = dict(
        learning_rate=ppo_lr,
        discount=ppo_gamma,
        entropy_loss_coeff=ppo_ent,  # was working at 0.003 and 0.001
        gae_lambda=ppo_lambda,
        ratio_clip=ppo_adv_clip,
        minibatches=ppo_minibatches,
        epochs=ppo_epochs,
        value_loss_coeff=1.0,
        clip_grad_norm=1.0,
        normalize_advantage=ppo_norm_adv,
    )
    if bc_loss:
        # TODO: make this batch size configurable
        # (currently the sampler batch size clamped to the range [16, 64])
        ppo_loader_mt = make_loader_mt(
            dataset_mt, max(16, min(64, sampler_batch_T * sampler_batch_B)))
    else:
        ppo_loader_mt = None

    # FIXME: abstract code for constructing augmentation model from presets
    try:
        ppo_aug_opts = MILBenchAugmentations.PRESETS[ppo_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{ppo_aug}'")
    if ppo_aug_opts:
        print("Policy augmentations:", ", ".join(ppo_aug_opts))
        ppo_aug_model = MILBenchAugmentations(
            **{k: True
               for k in ppo_aug_opts}).to(dev)
    else:
        print("No policy augmentations")
        ppo_aug_model = None

    ppo_algo = BCCustomRewardPPO(bc_loss_coeff=bc_loss,
                                 expert_traj_loader=ppo_loader_mt,
                                 true_reward_weight=danger_debug_reward_weight,
                                 aug_model=ppo_aug_model,
                                 **ppo_hyperparams)
    ppo_algo.set_reward_evaluator(reward_evaluator_mt)

    print("Setting up optimiser")
    # same preset lookup as for the policy above, but for the discriminator
    try:
        aug_opts = MILBenchAugmentations.PRESETS[disc_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{disc_aug}'")
    if aug_opts:
        print("Discriminator augmentations:", ", ".join(aug_opts))
        aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
            .to(dev)
    else:
        print("No discriminator augmentations")
        aug_model = None
    gail_optim = GAILOptimiser(
        dataset_mt=dataset_mt,
        discrim_model=discriminator_mt,
        # buffer holds disc_replay_mult sampler batches, but never fewer
        # samples than one discriminator batch
        buffer_num_samples=max(
            disc_batch_size,
            disc_replay_mult * sampler_batch_T * sampler_batch_B),
        batch_size=disc_batch_size,
        updates_per_itr=disc_up_per_iter,
        gp_weight=disc_gp_weight,
        dev=dev,
        aug_model=aug_model,
        lr=disc_lr,
        xfer_adv_weight=transfer_disc_loss_weight,
        xfer_adv_anneal=transfer_disc_anneal,
        xfer_adv_module=xfer_adv_module,
        final_layer_only_mode=disc_al,
        final_layer_only_mode_n_samples=disc_al_nsamples,
        use_wgan=wgan)

    # ae_trainer is only defined (and only used further down) when
    # autoencoder pretraining was requested
    if disc_ae_pretrain_iters:
        # FIXME(sam): pass n_acts, obs_chans, lr to AETrainer
        ae_trainer = AETrainer(discriminator=discriminator_mt,
                               disc_out_size=disc_final_feats_dim,
                               data_batch_iter=gail_optim.expert_batch_iter,
                               dev=dev)

    print("Setting up RL algorithm")
    # signature for arg: reward_model(obs_tensor, act_tensor) -> rewards
    runner = GAILMinibatchRl(
        seed=seed,
        gail_optim=gail_optim,
        variant_groups=variant_groups,
        algo=ppo_algo,
        agent=ppo_agent,
        sampler=sampler,
        # n_steps controls total number of environment steps we take
        n_steps=total_n_steps,
        # log_interval_steps controls how many environment steps we take
        # between making log outputs (doing N environment steps takes roughly
        # the same amount of time no matter what batch_B, batch_T, etc. are, so
        # this gives us a fairly constant interval between log outputs)
        log_interval_steps=log_interval_steps,
        affinity=affinity)

    # TODO: factor out this callback
    def init_policy_cb(runner):
        """Callback which gets called once after Runner startup to save an
        initial policy model, and optionally load saved parameters."""
        # get state of newly-initialised model
        wrapped_model = runner.algo.agent.model
        assert wrapped_model is not None, "has ppo_agent been initalised?"
        unwrapped_model = wrapped_model.model

        if load_policy:
            print(f"Loading policy from '{load_policy}'")
            saved_model = load_state_dict_or_model(load_policy)
            saved_dict = saved_model.state_dict()
            unwrapped_model.load_state_dict(saved_dict)

        real_state = unwrapped_model.state_dict()

        # make a clone model so we can pickle it, and copy across weights
        policy_copy_mt = policy_ctor(**policy_kwargs).to('cpu')
        policy_copy_mt.load_state_dict(real_state)

        # save it here
        init_pol_snapshot_path = os.path.join(logger.get_snapshot_dir(),
                                              'full_model.pt')
        torch.save(policy_copy_mt, init_pol_snapshot_path)

    print("Training!")
    n_uniq_envs = variant_groups.num_tasks
    # hyperparameters handed to the logger context alongside the run
    log_params = {
        'add_preproc': add_preproc,
        'seed': seed,
        'sampler_batch_T': sampler_batch_T,
        'sampler_batch_B': sampler_batch_B,
        'disc_batch_size': disc_batch_size,
        'disc_up_per_iter': disc_up_per_iter,
        'total_n_steps': total_n_steps,
        'bc_loss': bc_loss,
        'omit_noop': omit_noop,
        'disc_aug': disc_aug,
        'danger_debug_reward_weight': danger_debug_reward_weight,
        'disc_lr': disc_lr,
        'disc_use_act': disc_use_act,
        'disc_all_frames': disc_all_frames,
        'disc_net_attn': disc_net_attn,
        'disc_use_bn': disc_use_bn,
        'ppo_lr': ppo_lr,
        'ppo_gamma': ppo_gamma,
        'ppo_lambda': ppo_lambda,
        'ppo_ent': ppo_ent,
        'ppo_adv_clip': ppo_adv_clip,
        'ppo_norm_adv': ppo_norm_adv,
        'transfer_variants': transfer_variants,
        'transfer_pol_batch_weight': transfer_pol_batch_weight,
        'transfer_pol_loss_weight': transfer_pol_loss_weight,
        'transfer_disc_loss_weight': transfer_disc_loss_weight,
        'transfer_disc_anneal': transfer_disc_anneal,
        'ndemos': len(demos),
        'n_uniq_envs': n_uniq_envs,
    }
    with make_logger_ctx(out_dir,
                         "mtgail",
                         f"mt{n_uniq_envs}",
                         run_name,
                         snapshot_gap=snapshot_gap,
                         log_params=log_params):
        # save the freshly-constructed discriminator before any training
        torch.save(
            discriminator_mt,
            os.path.join(logger.get_snapshot_dir(), 'full_discrim_model.pt'))

        if disc_ae_pretrain_iters:
            # FIXME(sam): come up with a better solution for creating these
            # montages (can I do it regularly? Should I put them somewhere
            # other than the snapshot dir?).
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-before.png'))
            ae_trainer.do_full_training(disc_ae_pretrain_iters)
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-after.png'))

        # note that periodic snapshots get saved by GAILMiniBatchRl, thanks to
        # the overridden get_itr_snapshot() method
        runner.train(cb_startup=init_policy_cb)