def train_encoder(args):
    """Collect training episodes, build an encoder and train it with the selected method.

    Args:
        args: parsed command-line namespace. Fields read directly here:
            ``cuda_id``, ``collect_mode``, ``seed``, ``encoder_type``
            ("Nature" or "Impala") and ``method`` (trainer name). The whole
            namespace is also copied into the trainer ``config`` dict.

    Returns:
        The trained encoder module (moved to the selected device before training).

    Raises:
        ValueError: if ``args.encoder_type`` or ``args.method`` is not recognized.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(args, device, collect_mode=args.collect_mode,
                                   train_mode="train_encoder", seed=args.seed)

    # The encoder's first constructor argument is observation_shape[0]
    # (presumably the channel/frame-stack dimension — confirm against get_episodes).
    observation_shape = tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Fail clearly instead of leaving `encoder` unbound.
        raise ValueError("encoder_type {} is not supported (expected 'Nature' or 'Impala')"
                         .format(args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = {}
    config.update(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Pick the trainer class; only the selected name is looked up.
    if args.method == 'cpc':
        trainer_cls = CPCTrainer
    elif args.method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif args.method == 'vae':
        trainer_cls = VAETrainer
    elif args.method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif args.method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif args.method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif args.method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    else:
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise ValueError("method {} has no trainer".format(args.method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)
    return encoder
def train_encoder(args):
    """Collect pre-training episodes for one game, build an encoder and train it.

    Args:
        args: parsed command-line namespace. Fields read directly here include
            ``cuda_id``, ``pretraining_steps``, ``env_name``, ``seed``,
            ``num_processes``, ``num_frame_stack``, ``no_downsample``,
            ``color``, ``entropy_threshold``, ``probe_collect_mode``,
            ``checkpoint_index``, ``batch_size``, ``encoder_type``
            ("Nature" or "Impala") and ``method`` (trainer name). The whole
            namespace is also copied into the trainer ``config`` dict.

    Returns:
        The trained encoder module (moved to the selected device before training).

    Raises:
        ValueError: if ``args.encoder_type`` or ``args.method`` is not recognized.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(steps=args.pretraining_steps,
                                   env_name=args.env_name,
                                   seed=args.seed,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   downsample=not args.no_downsample,
                                   color=args.color,
                                   entropy_threshold=args.entropy_threshold,
                                   collect_mode=args.probe_collect_mode,
                                   train_mode="train_encoder",
                                   checkpoint_index=args.checkpoint_index,
                                   # presumably so every episode can fill a batch — confirm
                                   min_episode_length=args.batch_size)

    # The encoder's first constructor argument is observation_shape[0]
    # (presumably the channel/frame-stack dimension — confirm against get_episodes).
    observation_shape = tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Fail clearly instead of leaving `encoder` unbound.
        raise ValueError("encoder_type {} is not supported (expected 'Nature' or 'Impala')"
                         .format(args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = {}
    config.update(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Pick the trainer class; only the selected name is looked up.
    if args.method == 'cpc':
        trainer_cls = CPCTrainer
    elif args.method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif args.method == 'vae':
        trainer_cls = VAETrainer
    elif args.method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif args.method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif args.method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif args.method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    elif args.method == "dim":
        trainer_cls = DIMTrainer
    else:
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise ValueError("method {} has no trainer".format(args.method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)
    return encoder
def run_probe(args):
    """Collect labelled episodes, obtain an encoder, fit probes on it and log test metrics.

    The encoder is obtained in one of three ways:
    * trained from scratch via ``train_encoder`` when ``args.train_encoder`` is
      set and ``args.method`` is in ``train_encoder_methods``;
    * ``None`` for the "pretrained-rl-agent" and "majority" methods;
    * otherwise built from ``args.encoder_type``, loading weights from
      ``args.weights_path`` unless it is the string "None".

    Results from ``majority_baseline`` or ``ProbeTrainer.test`` are printed and
    passed to ``wandb.log``.

    Args:
        args: parsed command-line namespace (also copied into ``wandb.config``).

    Raises:
        ValueError: if an encoder must be built and ``args.encoder_type`` is
            neither "Nature" nor "Impala".
    """
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        steps=args.probe_steps,
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        num_frame_stack=args.num_frame_stack,
        downsample=not args.no_downsample,
        color=args.color,
        entropy_threshold=args.entropy_threshold,
        collect_mode=args.probe_collect_mode,
        train_mode="probe",
        checkpoint_index=args.checkpoint_index,
        min_episode_length=args.batch_size)
    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()
    elif args.method in ["pretrained-rl-agent", "majority"]:
        encoder = None
    else:
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)
        else:
            # Fail clearly instead of leaving `encoder` unbound.
            raise ValueError("encoder_type {} is not supported (expected 'Nature' or 'Impala')"
                             .format(args.encoder_type))

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                # Trailing newline so the warning does not run into later output.
                sys.stderr.write("Probing without loading in encoder weights! "
                                 "Are you sure you want to do that??\n")
        else:
            print("Loading in encoder weights from probe of type {} from the following path: {}"
                  .format(args.method, args.weights_path))
            # load_state_dict copies values into the freshly built module, so
            # loading onto CPU is sufficient and also works for checkpoints saved
            # on a GPU when no GPU is available.
            # NOTE(review): torch.load unpickles the file; only load trusted paths.
            encoder.load_state_dict(torch.load(args.weights_path, map_location="cpu"))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels, wandb)
    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)
        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc, test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
def train_encoder(args, *, pos_steps=50000, neg_steps=5000):
    """Train an encoder on the target game plus episodes pooled from other Atari games.

    Episodes are collected from ``args.env_name`` (``pos_steps`` steps) and from
    every other game in a fixed Atari list (``neg_steps`` steps each). The
    trainer receives the target-game train/val episodes and the flattened
    other-game train/val episodes as separate arguments (the other-game pool is
    presumably used as negatives — confirm in the trainer classes).

    Args:
        args: parsed command-line namespace. Fields read directly here include
            ``cuda_id``, ``env_name``, ``seed``, ``num_processes``,
            ``num_frame_stack``, ``no_downsample``, ``color``,
            ``entropy_threshold``, ``probe_collect_mode``, ``checkpoint_index``,
            ``batch_size``, ``encoder_type`` ("Nature" or "Impala") and
            ``method`` (trainer name). The whole namespace is also copied into
            the trainer ``config`` dict.
        pos_steps: steps collected from ``args.env_name`` (default 50000).
        neg_steps: steps collected from each other game (default 5000).

    Returns:
        The trained encoder module (moved to the selected device before training).

    Raises:
        ValueError: if ``args.encoder_type`` or ``args.method`` is not recognized.
    """
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    envs = [
        'AsteroidsNoFrameskip-v4', 'BerzerkNoFrameskip-v4', 'BowlingNoFrameskip-v4',
        'BoxingNoFrameskip-v4', 'BreakoutNoFrameskip-v4', 'DemonAttackNoFrameskip-v4',
        'FreewayNoFrameskip-v4', 'FrostbiteNoFrameskip-v4', 'HeroNoFrameskip-v4',
        'MontezumaRevengeNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'PitfallNoFrameskip-v4',
        'PongNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'QbertNoFrameskip-v4',
        'RiverraidNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4',
        'TennisNoFrameskip-v4', 'VentureNoFrameskip-v4', 'VideoPinballNoFrameskip-v4',
        'YarsRevengeNoFrameskip-v4'
    ]
    # Collect the target game even when it is not in the fixed list. Otherwise
    # no target episodes would be gathered and pos_tr_eps[0][0] would raise IndexError.
    if args.env_name not in envs:
        envs.append(args.env_name)

    pos_tr_eps, pos_val_eps = [], []
    other_tr_eps, other_val_eps = [], []
    for env in envs:
        is_target = env == args.env_name
        tr_eps, val_eps = get_episodes(steps=pos_steps if is_target else neg_steps,
                                       env_name=env,
                                       seed=args.seed,
                                       num_processes=args.num_processes,
                                       num_frame_stack=args.num_frame_stack,
                                       downsample=not args.no_downsample,
                                       color=args.color,
                                       entropy_threshold=args.entropy_threshold,
                                       collect_mode=args.probe_collect_mode,
                                       train_mode="train_encoder",
                                       checkpoint_index=args.checkpoint_index,
                                       min_episode_length=args.batch_size)
        if is_target:
            pos_tr_eps, pos_val_eps = tr_eps, val_eps
        else:
            # Flatten the per-game episode lists into one pool.
            other_tr_eps.extend(tr_eps)
            other_val_eps.extend(val_eps)

    # The encoder's first constructor argument is observation_shape[0]
    # (presumably the channel/frame-stack dimension — confirm against get_episodes).
    observation_shape = pos_tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    elif args.encoder_type == "Impala":
        encoder = ImpalaCNN(observation_shape[0], args)
    else:
        # Fail clearly instead of leaving `encoder` unbound.
        raise ValueError("encoder_type {} is not supported (expected 'Nature' or 'Impala')"
                         .format(args.encoder_type))
    encoder.to(device)
    torch.set_num_threads(1)

    config = {}
    config.update(vars(args))
    config['obs_space'] = observation_shape  # weird hack

    # Pick the trainer class; only the selected name is looked up.
    if args.method == 'cpc':
        trainer_cls = CPCTrainer
    elif args.method == 'spatial-appo':
        trainer_cls = SpatioTemporalTrainer
    elif args.method == 'vae':
        trainer_cls = VAETrainer
    elif args.method == "naff":
        trainer_cls = NaFFPredictorTrainer
    elif args.method == "infonce-stdim":
        trainer_cls = InfoNCESpatioTemporalTrainer
    elif args.method == "global-infonce-stdim":
        trainer_cls = GlobalInfoNCESpatioTemporalTrainer
    elif args.method == "global-local-infonce-stdim":
        trainer_cls = GlobalLocalInfoNCESpatioTemporalTrainer
    elif args.method == "dim":
        trainer_cls = DIMTrainer
    elif args.method == "static-dim":
        trainer_cls = StaticDIMTrainer
    elif args.method == "stdim-3":
        trainer_cls = InfoNCESpatioTemporalTrainer3
    elif args.method == "static-dim-2":
        trainer_cls = StaticDIMTrainer2
    else:
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise ValueError("method {} has no trainer".format(args.method))

    trainer = trainer_cls(encoder, config, device=device, wandb=wandb)
    trainer.train(pos_tr_eps, pos_val_eps, other_tr_eps, other_val_eps)
    return encoder