def testall(state_dict_or_model_path, env_name, seed, fps, write_latex,
            latex_alg_name, n_rollouts, run_id, write_csv, gpu_idx, cpu_list,
            batch_size, load_latest):
    """Run quantitative evaluation on all test variants of a given
    environment."""
    # TODO: is there some way of factoring this init code out? Maybe put into
    # Click base command so that it gets run for `train`, `testall`, etc.
    set_seeds(seed)
    import magical
    magical.register_envs()

    # for parallel GPU/CPU sampling
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
    print(f"Using device {dev}, seed {seed}")
    if cpu_list is None:
        cpu_list = sample_cpu_list()
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)

    if load_latest:
        state_dict_or_model_path = get_latest_path(state_dict_or_model_path)

    if run_id is None:
        run_id = state_dict_or_model_path

    eval_protocol = MTBCEvalProtocol(
        demo_env_name=env_name,
        state_dict_or_model_path=state_dict_or_model_path,
        seed=seed,
        # det_pol=det_pol,
        run_id=run_id,
        gpu_idx=gpu_idx,
        affinity=affinity,
        n_rollouts=n_rollouts,
        batch_size=batch_size,
    )

    # next bit copied from testall() in bc.py
    frame = eval_protocol.do_eval(verbose=True)
    if latex_alg_name is None:
        latex_alg_name = run_id
    frame['latex_alg_name'] = latex_alg_name

    if write_latex:
        latex_str = latexify_results(frame, id_column='latex_alg_name')
        dir_path = os.path.dirname(write_latex)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        with open(write_latex, 'w') as fp:
            fp.write(latex_str)

    if write_csv:
        dir_path = os.path.dirname(write_csv)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        frame.to_csv(write_csv)

    return frame
def main(demo, save_dir, preproc):
    magical.register_envs()
    demo_dict, = load_demos([demo])
    # keys of the demo are env_name, trajectory, score
    orig_env_name = demo_dict['env_name']
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        print(f"Will save all demos to '{save_dir}'")
    for preproc_name in preproc:
        preprocessed, = preprocess_demos_with_wrapper(
            [demo_dict['trajectory']],
            orig_env_name,
            preproc_name=preproc_name)
        print(f"Working on preprocessor '{preproc_name}'")
        for frame_idx, frame in enumerate(preprocessed.obs):
            frame_fname = 'frame-' + preproc_name.lower() \
                + f'-{frame_idx:03}.png'
            frame_path = os.path.join(save_dir, frame_fname)
            assert frame.shape[-1] % 3 == 0
            frame_stacked = frame.reshape(frame.shape[:-1] + (-1, 3))
            frame_trans = frame_stacked.transpose((2, 0, 1, 3))
            frame_wide = np.concatenate(frame_trans, axis=1)
            print(f"  Writing frame to '{frame_path}'")
            # make much larger so we can see pixel boundaries
            frame_wide = np.repeat(np.repeat(frame_wide, 8, axis=0), 8, axis=1)
            imageio.imsave(frame_path, frame_wide)
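
# Illustrative sketch (not part of the codebase): the reshape/transpose/
# concatenate sequence above unpacks a channel-stacked observation of shape
# (H, W, 3*k) into k RGB frames laid side by side. The helper name and dummy
# shapes below are made up for illustration only.
def _example_unstack_frames():
    import numpy as np  # local import to keep the sketch self-contained
    frame = np.zeros((96, 96, 12), dtype=np.uint8)  # 4 stacked RGB frames
    assert frame.shape[-1] % 3 == 0
    stacked = frame.reshape(frame.shape[:-1] + (-1, 3))  # (96, 96, 4, 3)
    by_frame = stacked.transpose((2, 0, 1, 3))           # (4, 96, 96, 3)
    wide = np.concatenate(by_frame, axis=1)              # (96, 384, 3)
    return wide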
def process_directories(model_dirs, out_dir, gpu_idx=0):
    """Load the latest snapshot in each of the given model directories, and
    render a video for each test variant corresponding to the train variants
    in the model. Note that the video will show all models in sequence."""
    set_seeds(DEFAULT_SEED)
    register_envs()
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if gpu_idx is None:
        dev = torch.device('cpu')
    else:
        dev = torch.device(f'cuda:{gpu_idx}')
    # we're going to keep a video writer for each env
    writers_by_env = {}
    for model_idx, model_dir in enumerate(model_dirs):
        print(f"Loading model from '{model_dir}' "
              f"(model {model_idx+1}/{len(model_dirs)})")
        latest_path = get_latest_path(os.path.join(model_dir,
                                                   "itr_LATEST.pkl"))
        model = load_state_dict_or_model(latest_path)
        model = model.to(dev)
        test_envs = get_test_envs(model.env_ids_and_names)
        for test_env in test_envs:
            if test_env not in writers_by_env:
                # file name will be something like, e.g.,
                # "movetocorner-testjitter.mp4"
                vid_name = generate_vid_name(test_env)
                out_path = os.path.join(out_dir, vid_name)
                print(f"  Writing video for '{test_env}' to '{out_path}'")
                writer = vidio.FFmpegWriter(out_path,
                                            outputdict={
                                                '-r': str(DEFAULT_FPS),
                                                '-vcodec': 'libx264',
                                                '-pix_fmt': 'yuv420p',
                                            })
                writers_by_env[test_env] = writer
            else:
                print(f"  Writing video for '{test_env}'")
                writer = writers_by_env.get(test_env)
            for frame in generate_frames(model=model,
                                         env_name=test_env,
                                         dev=dev,
                                         seed=DEFAULT_SEED + model_idx,
                                         ntraj=NTRAJ_PER_MODEL,
                                         fps=DEFAULT_FPS):
                writer.writeFrame(frame)
    for writer in writers_by_env.values():
        writer.close()
def _get_env_meta_target(env_names, rv_dict):
    register_envs()  # in case this proc was spawned
    metas = []
    for env_name in env_names:
        # construct a bunch of envs in turn to get info about their
        # observation spaces, action spaces, etc.
        env = gym.make(env_name)
        spec = FilteredSpec(*(getattr(env.spec, field)
                              for field in FilteredSpec._fields))
        meta = EnvMeta(observation_space=env.observation_space,
                       action_space=env.action_space,
                       spec=spec)
        metas.append(meta)
        env.close()
    rv_dict['result'] = tuple(metas)
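
# Illustrative sketch (not part of the codebase): _get_env_meta_target writes
# its result into the shared dict it is given, so it can run in a spawned
# subprocess and keep GL/env state out of the parent. Below is one plausible
# way to drive it, assuming a 'spawn' multiprocessing context; the actual
# caller may differ.
def _example_fetch_env_metas(env_names):
    import multiprocessing as mp  # local import to keep the sketch self-contained
    ctx = mp.get_context('spawn')
    with ctx.Manager() as manager:
        rv_dict = manager.dict()
        proc = ctx.Process(target=_get_env_meta_target,
                           args=(env_names, rv_dict))
        proc.start()
        proc.join()
        # copy the result out of the proxy before the manager shuts down
        return rv_dict['result']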
def test(state_dict_or_model_path, env_name, det_pol, seed, fps, transfer_to):
    """Repeatedly roll out a policy on a given environment. Mostly useful for
    visual debugging; see `testall` for quantitative evaluation."""
    set_seeds(seed)
    import magical
    magical.register_envs()

    # build env
    if transfer_to:
        env = gym.make(transfer_to)
    else:
        env = gym.make(env_name)

    model = load_state_dict_or_model(state_dict_or_model_path)
    ft_wrapper = wrap_model_for_fixed_task(model, env_name)

    spf = 1.0 / (env.fps if fps is None else fps)
    act_range = np.arange(env.action_space.n)
    obs = env.reset()
    try:
        while env.viewer.isopen:
            # for limiting FPS
            frame_start = time.time()
            # return value is actions, values, states, neglogp
            torch_obs = torch.from_numpy(obs)
            with torch.no_grad():
                (pi_torch, ), _ = ft_wrapper(torch_obs[None], None, None)
                pi = pi_torch.cpu().numpy()
            if det_pol:
                action = np.argmax(pi)
            else:
                # numpy is super complain-y about things "not summing to 1"
                pi = pi / sum(pi)
                action = np.random.choice(act_range, p=pi)
            obs, rew, done, info = env.step(action)
            obs = np.asarray(obs)
            env.render(mode='human')
            if done:
                print(f"Done, score {info['eval_score']:.4g}/1.0")
                obs = env.reset()
            elapsed = time.time() - frame_start
            if elapsed < spf:
                time.sleep(spf - elapsed)
    finally:
        env.viewer.close()
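
# Illustrative sketch (not part of the codebase): np.random.choice validates
# that `p` sums to 1, and float32 policy outputs are often off by a rounding
# error, which is why the loop above renormalises `pi` first. The helper name
# below is made up for illustration.
def _example_sample_action(pi, rng=np.random):
    pi = np.asarray(pi, dtype=np.float64)
    pi = pi / pi.sum()  # wash out accumulated rounding error
    return rng.choice(len(pi), p=pi)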
def main(ntraj, env_name, seed):
    """Very simple script to benchmark performance of a particular
    environment. Runs `ntraj` trajectories under cProfile and records the
    resulting profile."""
    magical.register_envs()
    env = gym.make(env_name)
    env.seed(seed)
    env.action_space.seed(seed)
    dtime = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    out_filename = f'profile-{env_name.lower()}-{dtime}.cprofile'
    print(f"Will write profile to '{out_filename}'")
    cProfile.runctx('do_eval(env, ntraj)',
                    globals(),
                    locals(),
                    filename=out_filename)
    print("Done")
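
# Illustrative sketch (not part of the codebase): the .cprofile file written
# above can be inspected offline with the standard library's pstats module.
# The path argument is whatever filename the run printed.
def _example_read_profile(profile_path):
    import pstats  # local import to keep the sketch self-contained
    stats = pstats.Stats(profile_path)
    # show the 20 most expensive calls by cumulative time
    stats.sort_stats('cumulative').print_stats(20)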
def main(model_dirs, out_root, nprocs, job_ngpus):
    """Load the latest snapshot in each of the given model directories (if
    any), and render a video for each test variant corresponding to the train
    variants in the models, grouping output directories by algorithm."""
    register_envs()
    ray.init(num_cpus=nprocs)
    remote_process_dirs = ray.remote(num_gpus=job_ngpus)(process_directories)
    grouped_dirs = collections.OrderedDict()
    for model_dir in sorted(set(model_dirs)):
        base_dir = os.path.basename(model_dir.strip('/'))
        parsed_alg = parse_run_dir(base_dir)
        grouped_dirs.setdefault(parsed_alg, []).append(model_dir)
    handles = []
    for group_alg, model_dirs in grouped_dirs.items():
        out_dir = os.path.join(out_root, group_alg)
        print(f"Writing videos for group '{group_alg}' ({len(model_dirs)} "
              f"directories) to '{out_dir}'")
        remote_handle = remote_process_dirs.remote(model_dirs, out_dir)
        handles.append(remote_handle)
    ray.get(handles)
def train(demos, add_preproc, seed, batch_size, total_n_batches,
          eval_every_n_batches, out_dir, run_name, gpu_idx, cpu_list,
          eval_n_traj, snapshot_gap, omit_noop, net_width_mul, net_use_bn,
          net_dropout, net_coord_conv, net_attention, net_task_spec_layers,
          load_policy, aug_mode, min_bc):
    # TODO: abstract setup code. Seeds & GPUs should go in one function. Env
    # setup should go in another function (or maybe the same function).
    # Dataset loading should be simplified by having a single class that can
    # provide whatever form of data the current IL method needs, without
    # having to do unnecessary copies in memory. Maybe also just use Sacred,
    # because YOLO.

    with contextlib.ExitStack() as exit_stack:
        # set up seeds & devices
        set_seeds(seed)
        mp.set_start_method('spawn')
        use_gpu = gpu_idx is not None and torch.cuda.is_available()
        dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
        print(f"Using device {dev}, seed {seed}")
        if cpu_list is None:
            cpu_list = sample_cpu_list()
        affinity = dict(
            cuda_idx=gpu_idx if use_gpu else None,
            workers_cpus=cpu_list,
        )

        # register original envs
        import magical
        magical.register_envs()

        # TODO: split out part of the dataset for validation.
        demos_metas_dict = get_demos_meta(demo_paths=demos,
                                          omit_noop=omit_noop,
                                          transfer_variants=[],
                                          preproc_name=add_preproc)
        dataset_mt = demos_metas_dict['dataset_mt']
        loader_mt = make_loader_mt(dataset_mt, batch_size)
        variant_groups = demos_metas_dict['variant_groups']
        env_metas = demos_metas_dict['env_metas']
        num_demo_sources = demos_metas_dict['num_demo_sources']
        task_ids_and_demo_env_names = demos_metas_dict[
            'task_ids_and_demo_env_names']
        sampler_batch_B = batch_size
        # this doesn't really matter
        sampler_batch_T = 5
        sampler, sampler_batch_B = make_mux_sampler(
            variant_groups=variant_groups,
            num_demo_sources=num_demo_sources,
            env_metas=env_metas,
            use_gpu=use_gpu,
            batch_B=sampler_batch_B,
            batch_T=sampler_batch_T,
            # TODO: instead of doing this, try sampling in proportion to
            # length of horizon; that should get more samples from harder envs
            task_var_weights=None)
        if load_policy is not None:
            try:
                pol_path = get_latest_path(load_policy)
            except ValueError:
                pol_path = load_policy
            policy_ctor = functools.partial(
                adapt_pol_loader,
                pol_path=pol_path,
                task_ids_and_demo_env_names=task_ids_and_demo_env_names)
            policy_kwargs = {}
        else:
            policy_kwargs = {
                'env_ids_and_names': task_ids_and_demo_env_names,
                'width': net_width_mul,
                'use_bn': net_use_bn,
                'dropout': net_dropout,
                'coord_conv': net_coord_conv,
                'attention': net_attention,
                'n_task_spec_layers': net_task_spec_layers,
                **get_policy_spec_magical(env_metas),
            }
            policy_ctor = MultiHeadPolicyNet
        agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))
        sampler.initialize(agent=agent,
                           seed=np.random.randint(1 << 31),
                           affinity=affinity)
        exit_stack.callback(lambda: sampler.shutdown())
        model_mt = policy_ctor(**policy_kwargs).to(dev)
        if min_bc:
            num_tasks = len(task_ids_and_demo_env_names)
            weight_mod = MinBCWeightingModule(num_tasks, num_demo_sources) \
                .to(dev)
            all_params = it.chain(model_mt.parameters(),
                                  weight_mod.parameters())
        else:
            weight_mod = None
            all_params = model_mt.parameters()
        # Adam mostly works fine, but in very loose informal tests it seems
        # like SGD had fewer weird failures where mean loss would jump up by a
        # factor of 2x for a period (?). (I don't think that was solely due to
        # high LR; probably an architectural issue.)
        # opt_mt = torch.optim.Adam(model_mt.parameters(), lr=3e-4)
        opt_mt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

        try:
            aug_opts = MILBenchAugmentations.PRESETS[aug_mode]
        except KeyError:
            raise ValueError(f"unsupported mode '{aug_mode}'")
        if aug_opts:
            print("Augmentations:", ", ".join(aug_opts))
            aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
                .to(dev)
        else:
            print("No augmentations")
            aug_model = None

        n_uniq_envs = len(task_ids_and_demo_env_names)
        log_params = {
            'n_uniq_envs': n_uniq_envs,
            'n_demos': len(demos),
            'net_use_bn': net_use_bn,
            'net_width_mul': net_width_mul,
            'net_dropout': net_dropout,
            'net_coord_conv': net_coord_conv,
            'net_attention': net_attention,
            'aug_mode': aug_mode,
            'seed': seed,
            'omit_noop': omit_noop,
            'batch_size': batch_size,
            'eval_n_traj': eval_n_traj,
            'eval_every_n_batches': eval_every_n_batches,
            'total_n_batches': total_n_batches,
            'snapshot_gap': snapshot_gap,
            'add_preproc': add_preproc,
            'net_task_spec_layers': net_task_spec_layers,
        }
        with make_logger_ctx(out_dir,
                             "mtbc",
                             f"mt{n_uniq_envs}",
                             run_name,
                             snapshot_gap=snapshot_gap,
                             log_params=log_params):
            # initial save
            torch.save(
                model_mt,
                os.path.join(logger.get_snapshot_dir(), 'full_model.pt'))

            # train for a while
            n_batches_done = 0
            n_rounds = int(np.ceil(total_n_batches / eval_every_n_batches))
            rnd = 1
            assert eval_every_n_batches > 0
            while n_batches_done < total_n_batches:
                batches_left_now = min(total_n_batches - n_batches_done,
                                       eval_every_n_batches)
                print(f"Done {n_batches_done}/{total_n_batches} "
                      f"({n_batches_done/total_n_batches*100:.2f}%, "
                      f"{rnd}/{n_rounds} rounds) batches; doing another "
                      f"{batches_left_now}")
                model_mt.train()
                loss_ewma, losses, per_task_losses = do_training_mt(
                    loader=loader_mt,
                    model=model_mt,
                    opt=opt_mt,
                    dev=dev,
                    aug_model=aug_model,
                    min_bc_module=weight_mod,
                    n_batches=batches_left_now)

                # TODO: record accuracy on a random subset of the train and
                # validation sets (both in eval mode, not train mode)

                print(f"Evaluating {eval_n_traj} trajectories on "
                      f"{variant_groups.num_tasks} tasks")
                record_misc_calls = []
                model_mt.eval()

                copy_model_into_agent_eval(model_mt, sampler.agent)
                scores_by_tv = eval_model(
                    sampler,
                    # shouldn't be any exploration
                    itr=0,
                    n_traj=eval_n_traj)
                for (task_id, variant_id), scores in scores_by_tv.items():
                    tv_id = (task_id, variant_id)
                    env_name = variant_groups.env_name_by_task_variant[tv_id]
                    tag = make_env_tag(strip_mb_preproc_name(env_name))
                    logger.record_tabular_misc_stat("Score%s" % tag, scores)
                    env_losses = per_task_losses.get(tv_id, [])
                    record_misc_calls.append((f"Loss{tag}", env_losses))

                # we record score AFTER loss so that losses are all in one
                # place, and scores are all in another
                for args in record_misc_calls:
                    logger.record_tabular_misc_stat(*args)

                # finish logging for this epoch
                logger.record_tabular("Round", rnd)
                logger.record_tabular("LossEWMA", loss_ewma)
                logger.record_tabular_misc_stat("Loss", losses)
                logger.dump_tabular()
                logger.save_itr_params(
                    rnd, {
                        'model_state': model_mt.state_dict(),
                        'opt_state': opt_mt.state_dict(),
                    })

                # advance ctrs
                rnd += 1
                n_batches_done += batches_left_now
    def __init__(self, env_name, **kwargs):
        register_envs()
        env = gym.make(env_name)
        super().__init__(env, **kwargs)
"""Test rollouts in every environment.""" import gym import pytest import magical N_ROLLOUTS = 2 magical.register_envs() def test_registered_envs(): # make sure we registered at least some environments assert len(magical.ALL_REGISTERED_ENVS) > 8 @pytest.mark.parametrize('env_name', magical.ALL_REGISTERED_ENVS) def test_rollouts(env_name): """Simple smoke test to make sure environments can roll out trajectories of the right length.""" env = gym.make(env_name) try: env.seed(7) env.action_space.seed(42) env.reset() for _ in range(N_ROLLOUTS): done = False traj_len = 0 while not done: action = env.action_space.sample() obs, rew, done, info = env.step(action)
def main(
        demos, add_preproc, seed, sampler_batch_B, sampler_batch_T,
        disc_batch_size, out_dir, run_name, gpu_idx, disc_up_per_iter,
        total_n_steps, log_interval_steps, cpu_list, snapshot_gap, load_policy,
        bc_loss, omit_noop, disc_replay_mult, disc_aug, ppo_aug, disc_use_bn,
        disc_net_attn, disc_use_sn, disc_gp_weight, disc_al, disc_al_dim,
        disc_al_nsamples, disc_ae_pretrain_iters, wgan, transfer_variants,
        transfer_disc_loss_weight, transfer_pol_loss_weight,
        transfer_disc_anneal, transfer_pol_batch_weight,
        danger_debug_reward_weight, danger_override_env_name,
        # new sweep hyperparams:
        disc_lr, disc_use_act, disc_all_frames, ppo_lr, ppo_gamma, ppo_lambda,
        ppo_ent, ppo_adv_clip, ppo_norm_adv, ppo_use_bn, ppo_minibatches,
        ppo_epochs):
    # set up seeds & devices
    # TODO: also seed child envs, when rlpyt supports it
    set_seeds(seed)
    # 'spawn' is necessary to use GL envs in subprocesses. For whatever reason
    # they don't play nice after a fork. (But what about set_seeds() in
    # subprocesses? May need to hack CpuSampler and GpuSampler.)
    mp.set_start_method('spawn')
    use_gpu = gpu_idx is not None and torch.cuda.is_available()
    dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
    if cpu_list is None:
        cpu_list = sample_cpu_list()
    # FIXME: I suspect the current solution will set torch_num_threads
    # suboptimally
    affinity = dict(cuda_idx=gpu_idx if use_gpu else None,
                    workers_cpus=cpu_list)
    print(f"Using device {dev}, seed {seed}, affinity {affinity}")

    # register original envs
    import magical
    magical.register_envs()

    if danger_override_env_name:
        raise NotImplementedError(
            "haven't re-implemented env name override for multi-task GAIL")

    demos_metas_dict = get_demos_meta(demo_paths=demos,
                                      omit_noop=omit_noop,
                                      transfer_variants=transfer_variants,
                                      preproc_name=add_preproc)
    dataset_mt = demos_metas_dict['dataset_mt']
    variant_groups = demos_metas_dict['variant_groups']
    env_metas = demos_metas_dict['env_metas']
    task_ids_and_demo_env_names = demos_metas_dict[
        'task_ids_and_demo_env_names']
    task_var_weights = {
        (task, variant): 1.0 if variant == 0 else transfer_pol_batch_weight
        for task, variant in variant_groups.env_name_by_task_variant
    }
    sampler, sampler_batch_B = make_mux_sampler(
        variant_groups=variant_groups,
        task_var_weights=task_var_weights,
        env_metas=env_metas,
        use_gpu=use_gpu,
        num_demo_sources=0,  # not important for now
        batch_B=sampler_batch_B,
        batch_T=sampler_batch_T)

    policy_kwargs = {
        'use_bn': ppo_use_bn,
        'env_ids_and_names': task_ids_and_demo_env_names,
        **get_policy_spec_magical(env_metas),
    }
    policy_ctor = MultiHeadPolicyNet
    ppo_agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))

    print("Setting up discriminator/reward model")
    disc_fc_dim = 256
    disc_final_feats_dim = disc_al_dim if disc_al else disc_fc_dim
    discriminator_mt = MILBenchDiscriminatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        in_chans=policy_kwargs['in_chans'],
        act_dim=policy_kwargs['n_actions'],
        use_all_chans=disc_all_frames,
        use_actions=disc_use_act,
        # can supply any argument that goes to MILBenchFeatureNetwork (e.g.
        # dropout, use_bn, width, etc.)
        attention=disc_net_attn,
        use_bn=disc_use_bn,
        use_sn=disc_use_sn,
        fc_dim=disc_fc_dim,
        final_feats_dim=disc_final_feats_dim,
    ).to(dev)

    if (not transfer_variants
            and (transfer_disc_loss_weight or transfer_pol_loss_weight)):
        print("No xfer variants supplied, setting xfer disc loss term to zero")
        transfer_disc_loss_weight = 0.0
        transfer_pol_loss_weight = 0.0
    if transfer_pol_loss_weight > 0:
        assert transfer_disc_loss_weight > 0
    if transfer_variants and transfer_disc_loss_weight:
        xfer_adv_module = BinaryDomainLossModule(
            discriminator_mt.ret_feats_dim).to(dev)
    else:
        xfer_adv_module = None

    reward_model_mt = RewardModel(
        discriminator_mt,
        xfer_adv_module,
        transfer_pol_loss_weight,
        # In apprenticeship learning we can just pass the model outputs
        # straight through, just like in WGAN.
        use_wgan=wgan or disc_al).to(dev)
    reward_evaluator_mt = RewardEvaluatorMT(
        task_ids_and_names=task_ids_and_demo_env_names,
        reward_model=reward_model_mt,
        obs_dims=3,
        batch_size=disc_batch_size,
        normalise=True,
        # I think I had rewards in [0,0.01] in the PPO run that I got to run
        # with a manually-defined reward.
        target_std=0.01)

    ppo_hyperparams = dict(
        learning_rate=ppo_lr,
        discount=ppo_gamma,
        entropy_loss_coeff=ppo_ent,  # was working at 0.003 and 0.001
        gae_lambda=ppo_lambda,
        ratio_clip=ppo_adv_clip,
        minibatches=ppo_minibatches,
        epochs=ppo_epochs,
        value_loss_coeff=1.0,
        clip_grad_norm=1.0,
        normalize_advantage=ppo_norm_adv,
    )
    if bc_loss:
        # TODO: make this batch size configurable
        ppo_loader_mt = make_loader_mt(
            dataset_mt, max(16, min(64, sampler_batch_T * sampler_batch_B)))
    else:
        ppo_loader_mt = None

    # FIXME: abstract code for constructing augmentation model from presets
    try:
        ppo_aug_opts = MILBenchAugmentations.PRESETS[ppo_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{ppo_aug}'")
    if ppo_aug_opts:
        print("Policy augmentations:", ", ".join(ppo_aug_opts))
        ppo_aug_model = MILBenchAugmentations(
            **{k: True for k in ppo_aug_opts}).to(dev)
    else:
        print("No policy augmentations")
        ppo_aug_model = None

    ppo_algo = BCCustomRewardPPO(bc_loss_coeff=bc_loss,
                                 expert_traj_loader=ppo_loader_mt,
                                 true_reward_weight=danger_debug_reward_weight,
                                 aug_model=ppo_aug_model,
                                 **ppo_hyperparams)
    ppo_algo.set_reward_evaluator(reward_evaluator_mt)

    print("Setting up optimiser")
    try:
        aug_opts = MILBenchAugmentations.PRESETS[disc_aug]
    except KeyError:
        raise ValueError(f"unsupported augmentation mode '{disc_aug}'")
    if aug_opts:
        print("Discriminator augmentations:", ", ".join(aug_opts))
        aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
            .to(dev)
    else:
        print("No discriminator augmentations")
        aug_model = None
    gail_optim = GAILOptimiser(
        dataset_mt=dataset_mt,
        discrim_model=discriminator_mt,
        buffer_num_samples=max(
            disc_batch_size,
            disc_replay_mult * sampler_batch_T * sampler_batch_B),
        batch_size=disc_batch_size,
        updates_per_itr=disc_up_per_iter,
        gp_weight=disc_gp_weight,
        dev=dev,
        aug_model=aug_model,
        lr=disc_lr,
        xfer_adv_weight=transfer_disc_loss_weight,
        xfer_adv_anneal=transfer_disc_anneal,
        xfer_adv_module=xfer_adv_module,
        final_layer_only_mode=disc_al,
        final_layer_only_mode_n_samples=disc_al_nsamples,
        use_wgan=wgan)

    if disc_ae_pretrain_iters:
        # FIXME(sam): pass n_acts, obs_chans, lr to AETrainer
        ae_trainer = AETrainer(discriminator=discriminator_mt,
                               disc_out_size=disc_final_feats_dim,
                               data_batch_iter=gail_optim.expert_batch_iter,
                               dev=dev)

    print("Setting up RL algorithm")
    # signature for arg: reward_model(obs_tensor, act_tensor) -> rewards
    runner = GAILMinibatchRl(
        seed=seed,
        gail_optim=gail_optim,
        variant_groups=variant_groups,
        algo=ppo_algo,
        agent=ppo_agent,
        sampler=sampler,
        # n_steps controls total number of environment steps we take
        n_steps=total_n_steps,
        # log_interval_steps controls how many environment steps we take
        # between making log outputs (doing N environment steps takes roughly
        # the same amount of time no matter what batch_B, batch_T, etc. are,
        # so this gives us a fairly constant interval between log outputs)
        log_interval_steps=log_interval_steps,
        affinity=affinity)

    # TODO: factor out this callback
    def init_policy_cb(runner):
        """Callback which gets called once after Runner startup to save an
        initial policy model, and optionally load saved parameters."""
        # get state of newly-initialised model
        wrapped_model = runner.algo.agent.model
        assert wrapped_model is not None, "has ppo_agent been initialised?"
        unwrapped_model = wrapped_model.model
        if load_policy:
            print(f"Loading policy from '{load_policy}'")
            saved_model = load_state_dict_or_model(load_policy)
            saved_dict = saved_model.state_dict()
            unwrapped_model.load_state_dict(saved_dict)
        real_state = unwrapped_model.state_dict()

        # make a clone model so we can pickle it, and copy across weights
        policy_copy_mt = policy_ctor(**policy_kwargs).to('cpu')
        policy_copy_mt.load_state_dict(real_state)

        # save it here
        init_pol_snapshot_path = os.path.join(logger.get_snapshot_dir(),
                                              'full_model.pt')
        torch.save(policy_copy_mt, init_pol_snapshot_path)

    print("Training!")
    n_uniq_envs = variant_groups.num_tasks
    log_params = {
        'add_preproc': add_preproc,
        'seed': seed,
        'sampler_batch_T': sampler_batch_T,
        'sampler_batch_B': sampler_batch_B,
        'disc_batch_size': disc_batch_size,
        'disc_up_per_iter': disc_up_per_iter,
        'total_n_steps': total_n_steps,
        'bc_loss': bc_loss,
        'omit_noop': omit_noop,
        'disc_aug': disc_aug,
        'danger_debug_reward_weight': danger_debug_reward_weight,
        'disc_lr': disc_lr,
        'disc_use_act': disc_use_act,
        'disc_all_frames': disc_all_frames,
        'disc_net_attn': disc_net_attn,
        'disc_use_bn': disc_use_bn,
        'ppo_lr': ppo_lr,
        'ppo_gamma': ppo_gamma,
        'ppo_lambda': ppo_lambda,
        'ppo_ent': ppo_ent,
        'ppo_adv_clip': ppo_adv_clip,
        'ppo_norm_adv': ppo_norm_adv,
        'transfer_variants': transfer_variants,
        'transfer_pol_batch_weight': transfer_pol_batch_weight,
        'transfer_pol_loss_weight': transfer_pol_loss_weight,
        'transfer_disc_loss_weight': transfer_disc_loss_weight,
        'transfer_disc_anneal': transfer_disc_anneal,
        'ndemos': len(demos),
        'n_uniq_envs': n_uniq_envs,
    }
    with make_logger_ctx(out_dir,
                         "mtgail",
                         f"mt{n_uniq_envs}",
                         run_name,
                         snapshot_gap=snapshot_gap,
                         log_params=log_params):
        torch.save(
            discriminator_mt,
            os.path.join(logger.get_snapshot_dir(), 'full_discrim_model.pt'))

        if disc_ae_pretrain_iters:
            # FIXME(sam): come up with a better solution for creating these
            # montages (can I do it regularly? Should I put them somewhere
            # other than the snapshot dir?).
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-before.png'))
            ae_trainer.do_full_training(disc_ae_pretrain_iters)
            ae_trainer.make_montage(
                os.path.join(logger.get_snapshot_dir(), 'ae-after.png'))

        # note that periodic snapshots get saved by GAILMinibatchRl, thanks to
        # the overridden get_itr_snapshot() method
        runner.train(cb_startup=init_policy_cb)