Пример #1
0
def _gym_env_id(game):
    """Turn a snake_case game name into its CamelCase ``-v0`` gym id.

    Example: ``'demon_attack'`` becomes ``'DemonAttack-v0'``.
    """
    return ''.join(part.capitalize() for part in game.split('_')) + '-v0'


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file even if dumping fails."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def main():
    """Evaluate a stored MONet checkpoint per game, then set up a trainer.

    For every enabled entry in the game list this:

    1. loads ``model_state_0000100.save`` for that game into a
       ``DataParallel``-wrapped ``Monet`` moved to CUDA,
    2. collects one run of 100 steps from the gym environment,
    3. runs ``MONetEvaluator.evaluate()`` and pickles the four returned
       objects to ``eval-monet/<game>/{mse,recons,imgs,masks}.pkl``.

    After the loop it builds a 25000-step dataset for the *last* game,
    creates a ``MONetTrainer`` and attaches checkpoint handling.

    Command-line arguments (``argv[1:]``) are forwarded to each config
    generator.

    NOTE(review): the section after the loop reads ``config``,
    ``training_config`` and ``run_path``, none of which are bound inside
    this function in the visible source. It looks like a fragment of a
    training script; confirm these exist at module level before running.
    """
    all_games = [
        # 'adventure', 'air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis', 'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout', 'carnival', 'centipede', 'chopper_command', 'crazy_climber',
        # 'defender',  ## apparently, this is really broken
        # 'demon_attack', 'double_dunk', 'elevator_action', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar', 'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull', 'kung_fu_master', 'montezuma_revenge', 'ms_pacman', 'name_this_game',
        # 'phoenix',
        'demon_attack',
        # 'pitfall', 'pong', 'pooyan', 'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing', 'solaris', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down', 'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon'
    ]

    # Rebind to the gym ids; the snake_case names are not needed afterwards.
    all_games = [_gym_env_id(g) for g in all_games]

    monet_config = ConfigGenerator('/home/cvoelcker/thesis/master_experiments/MONetTraining/experiments/monet-baseline/run_003/config.yml')
    spatial_monet_config = ConfigGenerator('experiments/baseline-vae-simple/run_003/config.yml')
    # NOTE(review): vae_config is built from the MONet baseline config. An
    # earlier assignment from 'experiments/baseline-vae-simple/run_003/config.yml'
    # was immediately overwritten by this one and has been dropped as a dead
    # store; confirm which file the VAE is really meant to use.
    vae_config = ConfigGenerator('/home/cvoelcker/thesis/master_experiments/MONetTraining/experiments/monet-baseline/run_003/config.yml')
    monet_config = monet_config(argv[1:])
    spatial_monet_config = spatial_monet_config(argv[1:])
    # genesis_config =
    vae_config = vae_config(argv[1:])

    for game in all_games:
        print()
        print()
        print(f'Running {game}')

        # monet_checkpoint = f'trained_baselines/monet/checkpoints_{game}/model_state_0000050.save'
        # spatial_monet_checkpoint = f'experiments/demon-attack/run_000/checkpoints_{game}/model_state_0000020.save'
        # genesis_checkpoint = f'trained_baselines/monet/checkpoints_{game}/model_state_0000050.save'
        monet_checkpoint = f'experiments/monet-baseline/run_003/checkpoints_{game}/model_state_0000100.save'

        # The model is wrapped in DataParallel before the state dict is
        # loaded, so the load happens on the wrapper rather than on Monet.
        monet = nn.DataParallel(Monet(**monet_config.MODULE._asdict())).cuda()
        monet.load_state_dict(torch.load(monet_checkpoint))

        # spatial_monet = MaskedAIR(**spatial_monet_config.MODULE._asdict()).cuda()
        # # spatial_monet = nn.DataParallel(MaskedAIR(**spatial_monet_config.MODULE._asdict())).cuda()
        # spatial_monet.load_state_dict(torch.load(spatial_monet_checkpoint))

        # genesis = GENESIS(**genesis_config._asdict()).cuda
        # genesis.load_state_dict(torch.load(genesis_checkpoint))

        # vae = nn.DataParallel(BroadcastVAE(**vae_config.MODULE._asdict())).cuda()
        # vae.load_state_dict(torch.load(vae_checkpoint))

        all_models = {
            'monet': monet,
            # 'spatial_monet': spatial_monet,
            # 'genesis': genesis,
            # 'vae': vae
        }

        env = gym.make(game)

        # Short rollout for evaluation: a single run of 100 steps.
        source_loader = generators.FunctionLoader(
            lambda **kwargs: generate_envs_data(**kwargs)['X'].squeeze(),
            {'env': env, 'num_runs': 1, 'run_len': 100})

        data_transformers = [
            transformers.TorchVisionTransformerComposition(monet_config.DATA.transform, monet_config.DATA.shape),
            transformers.TypeTransformer(monet_config.EXPERIMENT.device),
        ]
        data = BasicDataSet(source_loader, data_transformers)

        evaluator = MONetEvaluator(data, all_models)
        losses, recons, imgs, masks = evaluator.evaluate()

        # exist_ok=True already covers the "directory is there" case, so no
        # separate existence check is needed.
        out_dir = f'eval-monet/{game}'
        os.makedirs(out_dir, exist_ok=True)
        _dump_pickle(losses, f'{out_dir}/mse.pkl')
        _dump_pickle(recons, f'{out_dir}/recons.pkl')
        _dump_pickle(imgs, f'{out_dir}/imgs.pkl')
        _dump_pickle(masks, f'{out_dir}/masks.pkl')

    # NOTE(review): everything below runs once, after the loop, so `game` and
    # `monet` refer to the last evaluated game (and are unbound if the game
    # list is empty). `config`, `training_config` and `run_path` are not
    # defined in this function -- confirm they are module-level names.
    env = gym.make(game)

    # Long rollout for training: a single run of 25000 steps.
    source_loader = generators.FunctionLoader(
        lambda **kwargs: generate_envs_data(**kwargs)['X'].squeeze(), {
            'env': env,
            'num_runs': 1,
            'run_len': 25000
        })

    data_transformers = [
        transformers.TorchVisionTransformerComposition(config.DATA.transform,
                                                       config.DATA.shape),
        transformers.TypeTransformer(config.EXPERIMENT.device)
    ]
    print('Loading data')
    data = BasicDataSet(source_loader, data_transformers)
    print('Setting up trainer')
    trainer = setup_trainer(MONetTrainer, monet, training_config, data)

    # makedirs(exist_ok=True) replaces the check-then-mkdir pair: no race
    # between the check and the creation, and missing parents are created.
    check_path = os.path.join(run_path, f'checkpoints_{game}')
    os.makedirs(check_path, exist_ok=True)
    checkpointing = file_handler.EpochCheckpointHandler(check_path)
    trainer.register_handler(checkpointing)

    log_path = os.path.join(run_path, f'logging_{game}')
    os.makedirs(log_path, exist_ok=True)
    tb_logger = tb_handler.NStepTbHandler(
        config.EXPERIMENT.log_every,
        run_path,
        f'logging_{game}',
        log_name_list=['loss', 'kl_loss', 'mask_loss', 'p_x_loss'])