def train_encoder(args):
    """Collect episodes, build an encoder and train it with ``args.method``.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``cuda_id``, ``collect_mode``, ``seed``, ``encoder_type``
            (``"Nature"`` or ``"Impala"``) and ``method``; the whole
            namespace is also copied into the trainer's config dict.

    Returns:
        The trained encoder module (moved to the selected device).

    Raises:
        ValueError: If ``args.encoder_type`` or ``args.method`` is not
            recognised.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(args,
                                   device,
                                   collect_mode=args.collect_mode,
                                   train_mode="train_encoder",
                                   seed=args.seed)

    # NOTE(review): assumes the first dimension of a single frame is the
    # input-channel count expected by the encoder constructors — confirm.
    observation_shape = tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Without this branch `encoder` would be unbound and the next line
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            "unknown encoder_type {!r}; expected 'Nature' or 'Impala'".format(
                args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = dict(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Select the trainer class first; every trainer is constructed with the
    # same arguments. An if/elif chain (rather than a dict literal) keeps each
    # class name resolved lazily, so only the chosen trainer must exist.
    method = args.method
    if method == 'cpc':
        trainer_cls = CPCTrainer
    elif method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif method == 'vae':
        trainer_cls = VAETrainer
    elif method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    else:
        # `assert False` is stripped under `python -O`, which would leave the
        # trainer unbound; raise explicitly instead.
        raise ValueError("method {} has no trainer".format(method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)

    return encoder
def train_encoder(args):
    """Collect pre-training episodes, build an encoder and train it.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``cuda_id``, ``pretraining_steps``, ``env_name``, ``seed``,
            ``num_processes``, ``num_frame_stack``, ``no_downsample``,
            ``color``, ``entropy_threshold``, ``probe_collect_mode``,
            ``checkpoint_index``, ``batch_size``, ``encoder_type``
            (``"Nature"`` or ``"Impala"``) and ``method``; the whole
            namespace is also copied into the trainer's config dict.

    Returns:
        The trained encoder module (moved to the selected device).

    Raises:
        ValueError: If ``args.encoder_type`` or ``args.method`` is not
            recognised.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(steps=args.pretraining_steps,
                                   env_name=args.env_name,
                                   seed=args.seed,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   downsample=not args.no_downsample,
                                   color=args.color,
                                   entropy_threshold=args.entropy_threshold,
                                   collect_mode=args.probe_collect_mode,
                                   train_mode="train_encoder",
                                   checkpoint_index=args.checkpoint_index,
                                   min_episode_length=args.batch_size)

    # NOTE(review): assumes the first dimension of a single frame is the
    # input-channel count expected by the encoder constructors — confirm.
    observation_shape = tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Without this branch `encoder` would be unbound and the next line
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            "unknown encoder_type {!r}; expected 'Nature' or 'Impala'".format(
                args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = dict(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Select the trainer class first; every trainer is constructed with the
    # same arguments. An if/elif chain (rather than a dict literal) keeps each
    # class name resolved lazily, so only the chosen trainer must exist.
    method = args.method
    if method == 'cpc':
        trainer_cls = CPCTrainer
    elif method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif method == 'vae':
        trainer_cls = VAETrainer
    elif method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    elif method == "dim":
        trainer_cls = DIMTrainer
    else:
        # `assert False` is stripped under `python -O`, which would leave the
        # trainer unbound; raise explicitly instead.
        raise ValueError("method {} has no trainer".format(method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)

    return encoder
# Example no. 3
# 0
def run_probe(args):
    """Collect probe episodes, obtain an encoder, train a probe and log results.

    How the encoder is obtained:

    * trained from scratch via ``train_encoder`` when ``args.train_encoder``
      is set and ``args.method`` is in ``train_encoder_methods``;
    * ``None`` for the ``"pretrained-rl-agent"`` and ``"majority"`` methods;
    * otherwise built from ``args.encoder_type`` and, unless
      ``args.weights_path`` is the string ``"None"``, loaded from that path.

    For ``"majority"`` a majority-class baseline is evaluated instead of
    training a probe. Test accuracy and F1 are printed and logged to wandb.

    Args:
        args: Parsed command-line namespace; it is also recorded in
            ``wandb.config``.

    Raises:
        ValueError: If an encoder must be built and ``args.encoder_type`` is
            neither ``"Nature"`` nor ``"Impala"``.
    """
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(steps=args.probe_steps,
                                                                                 env_name=args.env_name,
                                                                                 seed=args.seed,
                                                                                 num_processes=args.num_processes,
                                                                                 num_frame_stack=args.num_frame_stack,
                                                                                 downsample=not args.no_downsample,
                                                                                 color=args.color,
                                                                                 entropy_threshold=args.entropy_threshold,
                                                                                 collect_mode=args.probe_collect_mode,
                                                                                 train_mode="probe",
                                                                                 checkpoint_index=args.checkpoint_index,
                                                                                 min_episode_length=args.batch_size)

    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()

    elif args.method in ["pretrained-rl-agent", "majority"]:
        encoder = None

    else:
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)
        else:
            # Without this branch `encoder` would be unbound and fail later
            # with an opaque UnboundLocalError.
            raise ValueError(
                "unknown encoder_type {!r}; expected 'Nature' or 'Impala'".format(
                    args.encoder_type))

        # `weights_path` uses the string "None" as its no-weights sentinel.
        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                # Trailing newline so later output does not run onto this line.
                sys.stderr.write("Probing without loading in encoder weights! Are sure you want to do that??\n")
        else:
            print("Print loading in encoder weights from probe of type {} from the following path: {}"
                  .format(args.method, args.weights_path))
            # The encoder built above has not been moved to any device, so
            # load tensors onto the CPU; this also lets checkpoints saved on
            # a GPU load on CPU-only machines.
            encoder.load_state_dict(torch.load(args.weights_path, map_location="cpu"))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels, wandb)

    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)

        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc, test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
def train_encoder(args):
    """Train an encoder on the target game plus episodes from other Atari games.

    Episodes for ``args.env_name`` (50000 steps) become the positive
    train/validation sets. Episodes from every other game in ``envs``
    (5000 steps each) are flattened into two additional lists. All four lists
    are passed to the trainer selected by ``args.method``. The additional
    lists presumably serve as negatives; confirm against the trainer classes.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``cuda_id``, ``env_name``, ``seed``, ``num_processes``,
            ``num_frame_stack``, ``no_downsample``, ``color``,
            ``entropy_threshold``, ``probe_collect_mode``,
            ``checkpoint_index``, ``batch_size``, ``encoder_type``
            (``"Nature"`` or ``"Impala"``) and ``method``; the whole
            namespace is also copied into the trainer's config dict.

    Returns:
        The trained encoder module (moved to the selected device).

    Raises:
        ValueError: If ``args.encoder_type`` or ``args.method`` is not
            recognised.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    envs = [
        'AsteroidsNoFrameskip-v4', 'BerzerkNoFrameskip-v4',
        'BowlingNoFrameskip-v4', 'BoxingNoFrameskip-v4',
        'BreakoutNoFrameskip-v4', 'DemonAttackNoFrameskip-v4',
        'FreewayNoFrameskip-v4', 'FrostbiteNoFrameskip-v4',
        'HeroNoFrameskip-v4', 'MontezumaRevengeNoFrameskip-v4',
        'MsPacmanNoFrameskip-v4', 'PitfallNoFrameskip-v4',
        'PongNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4',
        'QbertNoFrameskip-v4', 'RiverraidNoFrameskip-v4',
        'SeaquestNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4',
        'TennisNoFrameskip-v4', 'VentureNoFrameskip-v4',
        'VideoPinballNoFrameskip-v4', 'YarsRevengeNoFrameskip-v4'
    ]
    # An unlisted target game would never be collected. pos_tr_eps would
    # stay empty, and pos_tr_eps[0] below would raise IndexError only after
    # every other game had been collected. Collect the target game last.
    if args.env_name not in envs:
        envs.append(args.env_name)

    target_steps = 50000  # steps collected for args.env_name
    other_steps = 5000  # steps collected for each other game

    all_tr_episodes, all_val_episodes = [], []
    pos_tr_eps, pos_val_eps = [], []
    for env in envs:
        is_target = env == args.env_name
        tr_eps, val_eps = get_episodes(
            steps=target_steps if is_target else other_steps,
            env_name=env,
            seed=args.seed,
            num_processes=args.num_processes,
            num_frame_stack=args.num_frame_stack,
            downsample=not args.no_downsample,
            color=args.color,
            entropy_threshold=args.entropy_threshold,
            collect_mode=args.probe_collect_mode,
            train_mode="train_encoder",
            checkpoint_index=args.checkpoint_index,
            min_episode_length=args.batch_size)
        if is_target:
            pos_tr_eps, pos_val_eps = tr_eps, val_eps
        else:
            all_tr_episodes.append(tr_eps)
            all_val_episodes.append(val_eps)

    # Flatten the per-game collections into single lists.
    all_tr_episodes = list(chain.from_iterable(all_tr_episodes))
    all_val_episodes = list(chain.from_iterable(all_val_episodes))

    # NOTE(review): assumes the first dimension of a single frame is the
    # input-channel count expected by the encoder constructors — confirm.
    observation_shape = pos_tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Without this branch `encoder` would be unbound and the next line
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            "unknown encoder_type {!r}; expected 'Nature' or 'Impala'".format(
                args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = dict(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Select the trainer class first; every trainer is constructed with the
    # same arguments. An if/elif chain (rather than a dict literal) keeps each
    # class name resolved lazily, so only the chosen trainer must exist.
    method = args.method
    if method == 'cpc':
        trainer_cls = CPCTrainer
    elif method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif method == 'vae':
        trainer_cls = VAETrainer
    elif method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    elif method == "dim":
        trainer_cls = DIMTrainer
    elif method == "static-dim":
        trainer_cls = StaticDIMTrainer
    elif method == "stdim-3":
        trainer_cls = InfoNCESpatioTemporalTrainer3
    elif method == "static-dim-2":
        trainer_cls = StaticDIMTrainer2
    else:
        # `assert False` is stripped under `python -O`, which would leave the
        # trainer unbound; raise explicitly instead.
        raise ValueError("method {} has no trainer".format(method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(pos_tr_eps, pos_val_eps, all_tr_episodes, all_val_episodes)

    return encoder