# Atari games (snake_case names) whose MONet checkpoints are evaluated, in the
# order they are processed.
# Disabled entries: 'adventure', 'phoenix', 'pitfall' and 'defender'
# ("apparently, this is really broken"). 'demon_attack' is listed after
# 'name_this_game' rather than in its alphabetical slot.
# NOTE(review): the enabled/disabled split was reconstructed from a copy of
# this file whose line breaks were lost (assuming one game per line) --
# confirm against version control.
ATARI_GAMES = (
    'air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids',
    'atlantis', 'bank_heist', 'battle_zone', 'beam_rider', 'berzerk',
    'bowling', 'boxing', 'breakout', 'carnival', 'centipede',
    'chopper_command', 'crazy_climber', 'double_dunk', 'elevator_action',
    'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar',
    'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull',
    'kung_fu_master', 'montezuma_revenge', 'ms_pacman', 'name_this_game',
    'demon_attack', 'pong', 'pooyan', 'private_eye', 'qbert', 'riverraid',
    'road_runner', 'robotank', 'seaquest', 'skiing', 'solaris',
    'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham',
    'up_n_down', 'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge',
    'zaxxon',
)


def gym_env_id(game):
    """Turn a snake_case game name into a gym id, e.g. 'air_raid' -> 'AirRaid-v0'."""
    return ''.join(part.capitalize() for part in game.split('_')) + '-v0'


def main():
    """Evaluate the trained MONet baseline on every game in ``ATARI_GAMES``.

    For each game the per-game checkpoint is loaded into a fresh
    ``nn.DataParallel(Monet(...))`` on the GPU, a ``BasicDataSet`` is built
    from the gym environment, and ``MONetEvaluator.evaluate()`` is run. Its
    four results are pickled to ``eval-monet/<gym id>/`` as ``mse.pkl``,
    ``recons.pkl``, ``imgs.pkl`` and ``masks.pkl``.

    Command-line arguments (``argv[1:]``) are handed to the config generator.

    Only the MONet baseline is evaluated. The spatial-MONet (MaskedAIR),
    GENESIS and BroadcastVAE baselines are disabled; to re-enable one, load
    its config and checkpoint the same way and add it to ``models`` below.
    """
    monet_config = ConfigGenerator(
        '/home/cvoelcker/thesis/master_experiments/MONetTraining/'
        'experiments/monet-baseline/run_003/config.yml')(argv[1:])

    for game in (gym_env_id(name) for name in ATARI_GAMES):
        print(f'\n\nRunning {game}')

        # Checkpoint directories are keyed by the gym id, not the snake_case name.
        monet_checkpoint = (
            f'experiments/monet-baseline/run_003/checkpoints_{game}/'
            'model_state_0000100.save')
        monet = nn.DataParallel(Monet(**monet_config.MODULE._asdict())).cuda()
        monet.load_state_dict(torch.load(monet_checkpoint))
        models = {'monet': monet}

        env = gym.make(game)
        try:
            source_loader = generators.FunctionLoader(
                lambda **kwargs: generate_envs_data(**kwargs)['X'].squeeze(),
                {'env': env, 'num_runs': 1, 'run_len': 100})
            data_transformers = [
                transformers.TorchVisionTransformerComposition(
                    monet_config.DATA.transform, monet_config.DATA.shape),
                transformers.TypeTransformer(monet_config.EXPERIMENT.device),
            ]
            data = BasicDataSet(source_loader, data_transformers)
            evaluator = MONetEvaluator(data, models)
            losses, recons, imgs, masks = evaluator.evaluate()
        finally:
            # One emulator is created per game; release it before the next one.
            env.close()

        out_dir = f'eval-monet/{game}'
        os.makedirs(out_dir, exist_ok=True)
        results = {
            'mse': losses,
            'recons': recons,
            'imgs': imgs,
            'masks': masks,
        }
        for stem, payload in results.items():
            # Context manager so every result file is flushed and closed.
            with open(f'{out_dir}/{stem}.pkl', 'wb') as handle:
                pickle.dump(payload, handle)
env = gym.make(game) source_loader = generators.FunctionLoader( lambda **kwargs: generate_envs_data(**kwargs)['X'].squeeze(), { 'env': env, 'num_runs': 1, 'run_len': 25000 }) data_transformers = [ transformers.TorchVisionTransformerComposition(config.DATA.transform, config.DATA.shape), transformers.TypeTransformer(config.EXPERIMENT.device) ] print('Loading data') data = BasicDataSet(source_loader, data_transformers) print('Setting up trainer') trainer = setup_trainer(MONetTrainer, monet, training_config, data) check_path = os.path.join(run_path, 'checkpoints_{}'.format(game)) if not os.path.exists(check_path): os.mkdir(check_path) checkpointing = file_handler.EpochCheckpointHandler(check_path) trainer.register_handler(checkpointing) log_path = os.path.join(run_path, 'logging_{}'.format(game)) if not os.path.exists(log_path): os.mkdir(log_path) tb_logger = tb_handler.NStepTbHandler( config.EXPERIMENT.log_every, run_path, 'logging_{}'.format(game), log_name_list=['loss', 'kl_loss', 'mask_loss', 'p_x_loss'])