def _select_trainer_class(method):
    """Return the trainer class that implements ``method``.

    Only the class for the requested method is looked up, so a trainer that
    is not needed is never touched.

    Args:
        method: Name of the representation-learning method (``args.method``).

    Returns:
        The trainer class to instantiate for ``method``.

    Raises:
        ValueError: If no trainer is registered for ``method``.
    """
    if method == 'cpc':
        return CPCTrainer
    if method == 'jsd-stdim':
        return SpatioTemporalTrainer
    if method == 'vae':
        return VAETrainer
    if method == "naff":
        return NaFFPredictorTrainer
    if method == "infonce-stdim":
        return InfoNCESpatioTemporalTrainer
    if method == "global-infonce-stdim":
        return GlobalInfoNCESpatioTemporalTrainer
    if method == "global-local-infonce-stdim":
        return GlobalLocalInfoNCESpatioTemporalTrainer
    if method == "dim":
        return DIMTrainer
    if method == "ae":
        return AETrainer
    if method == 'ib':
        return InfoBottleneck
    if method == 'global-t-dim':
        return GlobalTemporalDIMTrainer
    if method == 'ib-ae-nce':
        return IBAENCETrainer
    if method == 'ib-stdim':
        return IBSTDIMTrainer
    # A real exception instead of ``assert False``: asserts are stripped when
    # Python runs with -O, which would let an unknown method slip through.
    raise ValueError("method {} has no trainer".format(method))


def train_encoder(args):
    """Collect episodes, build an encoder and train it with ``args.method``.

    The configuration is validated before any episodes are collected, so a
    misspelled encoder type or method fails immediately rather than after
    data collection.

    Args:
        args: Namespace of run options. Reads ``encoder_type``, ``method``,
            ``cuda_id``, ``pretraining_steps``, ``env_name``, ``seed``,
            ``num_processes``, ``num_frame_stack``, ``no_downsample``,
            ``color``, ``entropy_threshold``, ``probe_collect_mode``,
            ``checkpoint_index`` and ``batch_size``; the full namespace is
            also forwarded to the encoder and (as a dict) to the trainer.

    Returns:
        The encoder, after ``trainer.train`` has run on it.

    Raises:
        ValueError: If ``args.encoder_type`` is neither ``"Nature"`` nor
            ``"Impala"``, or if ``args.method`` has no trainer.
    """
    # Fail fast on bad configuration (string checks only, no side effects).
    if args.encoder_type not in ("Nature", "Impala"):
        raise ValueError(
            "encoder type {} is not supported".format(args.encoder_type))
    trainer_cls = _select_trainer_class(args.method)

    device = torch.device("cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(steps=args.pretraining_steps,
                                 env_name=args.env_name,
                                 seed=args.seed,
                                 num_processes=args.num_processes,
                                 num_frame_stack=args.num_frame_stack,
                                 downsample=not args.no_downsample,
                                 color=args.color,
                                 entropy_threshold=args.entropy_threshold,
                                 collect_mode=args.probe_collect_mode,
                                 train_mode="train_encoder",
                                 checkpoint_index=args.checkpoint_index,
                                 min_episode_length=args.batch_size)

    # Shape of one observation; its first entry is handed to the encoder
    # (the original code annotates the layout as NCHW).
    observation_shape = tr_eps[0][0].shape
    encoder_cls = NatureCNN if args.encoder_type == "Nature" else ImpalaCNN
    encoder = encoder_cls(observation_shape[0], args)
    encoder.to(device)
    torch.set_num_threads(1)

    # Trainers take a plain dict: every CLI option plus the observation shape.
    config = {}
    config.update(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)

    return encoder
def train_encoder(args, p=None):
    """Collect episodes and train an encoder with the extended InfoNCE ST-DIM trainer.

    Args:
        args: Namespace of run options. Reads ``encoder_type``, ``cuda_id``,
            ``pretraining_steps``, ``env_name``, ``seed``,
            ``num_frame_stack``, ``no_downsample``, ``color``,
            ``entropy_threshold``, ``probe_collect_mode``,
            ``checkpoint_index`` and ``batch_size``; the full namespace is
            also forwarded to the encoder and (as a dict) to the trainer.
        p: Optional value forwarded to ``InfoNCESpatioTemporalTrainerExt`` as
            its ``p`` keyword. It is only forwarded when truthy; otherwise the
            trainer is built without it and uses its own default.

    Returns:
        The encoder, after ``trainer.train`` has run on it.

    Raises:
        ValueError: If ``args.encoder_type`` is neither ``"Nature"`` nor
            ``"Impala"``.
    """
    # Fail fast: without this check an unknown encoder type only surfaced
    # after episode collection, as an unbound ``encoder`` variable.
    if args.encoder_type not in ("Nature", "Impala"):
        raise ValueError(
            "encoder type {} is not supported".format(args.encoder_type))

    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    # Episodes are always collected with a single process here.
    tr_eps, val_eps = get_episodes(steps=args.pretraining_steps,
                                   env_name=args.env_name,
                                   seed=args.seed,
                                   num_processes=1,
                                   num_frame_stack=args.num_frame_stack,
                                   downsample=not args.no_downsample,
                                   color=args.color,
                                   entropy_threshold=args.entropy_threshold,
                                   collect_mode=args.probe_collect_mode,
                                   train_mode="train_encoder",
                                   checkpoint_index=args.checkpoint_index,
                                   min_episode_length=args.batch_size)

    observation_shape = tr_eps[0][0].shape
    encoder_cls = NatureCNN if args.encoder_type == "Nature" else ImpalaCNN
    encoder = encoder_cls(observation_shape[0], args)
    encoder.to(device)
    torch.set_num_threads(1)

    # Trainers take a plain dict: every CLI option plus the observation shape.
    config = {}
    config.update(vars(args))
    config['obs_space'] = observation_shape

    trainer_kwargs = {"device": device, "wandb": wandb}
    # NOTE(review): truthiness is kept from the original, so a falsy ``p``
    # (e.g. 0) is treated the same as "not given" -- confirm this is intended.
    if p:
        trainer_kwargs["p"] = p
    trainer = InfoNCESpatioTemporalTrainerExt(encoder, config, **trainer_kwargs)

    trainer.train(tr_eps, val_eps)

    return encoder
def run_probe(args):
    """Obtain an encoder for ``args.method``, probe it and log the test scores.

    The encoder comes from one of three places:

    * trained from scratch via ``train_encoder`` when ``args.train_encoder``
      is set and the method is in ``train_encoder_methods``;
    * ``None`` for the ``"pretrained-rl-agent"`` and ``"majority"`` methods;
    * otherwise a freshly built encoder, with weights loaded from
      ``args.weights_path`` unless that is the string ``"None"``.

    Args:
        args: Namespace of run options; it is also pushed to ``wandb.config``.

    Raises:
        ValueError: If an encoder has to be built and ``args.encoder_type``
            is neither ``"Nature"`` nor ``"Impala"``.
    """
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        steps=args.probe_steps,
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        num_frame_stack=args.num_frame_stack,
        downsample=not args.no_downsample,
        color=args.color,
        entropy_threshold=args.entropy_threshold,
        collect_mode=args.probe_collect_mode,
        train_mode="probe",
        checkpoint_index=args.checkpoint_index,
        min_episode_length=args.batch_size)

    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()

    elif args.method in ["pretrained-rl-agent", "majority"]:
        encoder = None

    else:
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)
        else:
            # Previously this fell through and crashed later on an unbound
            # ``encoder`` variable.
            raise ValueError(
                "encoder type {} is not supported".format(args.encoder_type))

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                # print() rather than sys.stderr.write() so the warning ends
                # with a newline and does not run into the next output.
                print(
                    "Probing without loading in encoder weights! Are sure you want to do that??",
                    file=sys.stderr)
        else:
            print(
                "Print loading in encoder weights from probe of type {} from the following path: {}"
                .format(args.method, args.weights_path))
            # map_location="cpu" lets a checkpoint saved on a GPU be read on a
            # machine without one; load_state_dict copies the values into the
            # encoder's own parameters afterwards.
            state_dict = torch.load(args.weights_path, map_location="cpu")
            encoder.load_state_dict(state_dict)
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels,
                                                   wandb)

    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)

        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc, test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
# Example #4
# 0
# Extra command-line options on top of the ones already registered on `parser`.
parser.add_argument('--entity', default='neurips-challenge', dest='entity')
parser.add_argument('--project', default='atari-ari', dest='project')
parser.add_argument('--receptive-field',
                    default=16,
                    type=int,
                    dest='receptive_field')
args = parser.parse_args(sys.argv[1:])

# Use the requested CUDA device when one is available, otherwise the CPU.
device = torch.device(
    "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
# Episodes are always collected with a single process here.
tr_eps, val_eps = get_episodes(steps=args.pretraining_steps,
                               env_name=args.env_name,
                               seed=args.seed,
                               num_processes=1,
                               num_frame_stack=args.num_frame_stack,
                               downsample=not args.no_downsample,
                               color=args.color,
                               entropy_threshold=args.entropy_threshold,
                               collect_mode=args.probe_collect_mode,
                               train_mode="train_encoder",
                               checkpoint_index=args.checkpoint_index,
                               min_episode_length=args.batch_size)

# Shape of one observation; its first entry is handed to the encoder.
observation_shape = tr_eps[0][0].shape
if args.encoder_type == "Nature":
    encoder = AdjustableNatureCNN(observation_shape[0], args)
elif args.encoder_type == "Impala":
    encoder = ImpalaCNN(observation_shape[0], args)
else:
    # Without this branch an unknown type left `encoder` undefined and the
    # script died on the next statement with a NameError.
    raise ValueError(
        "encoder type {} is not supported".format(args.encoder_type))

encoder.to(device)
torch.set_num_threads(1)