import sys

import torch
import wandb

# NOTE: the project-local names used below (get_episodes, train_encoder, NatureCNN,
# ImpalaCNN, ProbeTrainer, majority_baseline, train_encoder_methods,
# probe_only_methods) are assumed to be imported from the project's src package;
# the exact module paths are not shown in this snippet.


def run_probe(args):
    device = torch.device("cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    collect_mode = args.collect_mode if args.method != 'pretrained-rl-agent' else "pretrained_representations"
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        args, device, collect_mode=collect_mode, train_mode="probe")
    print("got episodes!")

    observation_shape = tr_eps[0][0].shape
    wandb.config.update(vars(args))

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()
    else:
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                sys.stderr.write(
                    "Probing without loading encoder weights! Are you sure you want to do that?\n")
        else:
            print("Loading encoder weights for method {} from the following path: {}".format(
                args.method, args.weights_path))
            encoder.load_state_dict(torch.load(args.weights_path))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        # Majority-class baseline: no encoder or probe training needed.
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels, wandb)
    else:
        trainer = ProbeTrainer(encoder, wandb,
                               epochs=args.epochs,
                               sample_label=tr_labels[0][0],
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               device=device,
                               patience=args.patience,
                               log=False)
        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc)
    print(test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
def run_probe(args):
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        steps=args.probe_steps,
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        num_frame_stack=args.num_frame_stack,
        downsample=not args.no_downsample,
        color=args.color,
        entropy_threshold=args.entropy_threshold,
        collect_mode=args.probe_collect_mode,
        train_mode="probe",
        checkpoint_index=args.checkpoint_index,
        min_episode_length=args.batch_size)
    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()
    elif args.method in ["pretrained-rl-agent", "majority"]:
        # These methods probe precomputed representations (or none at all),
        # so no encoder is constructed.
        encoder = None
    else:
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                sys.stderr.write(
                    "Probing without loading encoder weights! Are you sure you want to do that?\n")
        else:
            print("Loading encoder weights for method {} from the following path: {}".format(
                args.method, args.weights_path))
            encoder.load_state_dict(torch.load(args.weights_path))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels, wandb)
    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)
        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

        # Alternative scikit-learn based probing, kept for reference:
        # trainer = SKLearnProbeTrainer(encoder=encoder)
        # test_acc, test_f1score = trainer.train_test(tr_eps, val_eps, tr_labels, val_labels,
        #                                             test_eps, test_labels)

    print(test_acc, test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
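# A minimal sketch of how run_probe above might be invoked from the command line,
# mirroring the pretraining entry point below. The "probe-only" tag is an
# assumption; get_argparser is taken from src.utils as in the script that follows.
if __name__ == "__main__":
    parser = get_argparser()
    args = parser.parse_args()
    wandb.init(project=args.wandb_proj, entity="curl-atari", tags=["probe-only"])
    run_probe(args)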
from src.dim_baseline import DIMTrainer
from src.encoders import ImpalaCNN, NatureCNN, NatureOneCNN
from src.global_infonce_stdim import GlobalInfoNCESpatioTemporalTrainer
from src.global_local_infonce import GlobalLocalInfoNCESpatioTemporalTrainer
from src.infonce_spatio_temporal import InfoNCESpatioTemporalTrainer
from src.no_action_feedforward_predictor import NaFFPredictorTrainer
from src.spatio_temporal import SpatioTemporalTrainer
from src.utils import get_argparser
from src.vae import VAETrainer
import torch
import wandb

# NOTE: train_encoder is called below but is not defined or imported in this
# snippet; it is assumed to come from elsewhere in the project.

if __name__ == "__main__":
    parser = get_argparser()
    args = parser.parse_args()

    tags = ["pretraining-only"]
    wandb.init(project=args.wandb_proj, entity="curl-atari", tags=tags)

    config = {}
    config.update(vars(args))
    wandb.config.update(config)

    index_array = torch.randperm(823)
    train_encoder(args, index_array)
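# A hedged sketch of what the train_encoder used above might look like; the real
# implementation lives elsewhere in the project and is not shown here. The
# method-name -> trainer mapping, the trainer constructor signature
# (encoder, config, device=..., wandb=...), and the get_episodes call (assumed
# importable, with train_mode="train_encoder") are illustrative assumptions, not
# the project's confirmed API. index_array is accepted but left unused because
# its role is not visible in this snippet.
def train_encoder(args, index_array=None):
    device = torch.device("cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    tr_eps, val_eps = get_episodes(args, device, collect_mode=args.collect_mode,
                                   train_mode="train_encoder")

    # Build the encoder that the chosen pretraining objective will train.
    observation_shape = tr_eps[0][0].shape
    if args.encoder_type == "Nature":
        encoder = NatureCNN(observation_shape[0], args)
    else:
        encoder = ImpalaCNN(observation_shape[0], args)
    encoder.to(device)

    # Dispatch on args.method to one of the imported pretraining trainers.
    config = dict(vars(args), obs_space=observation_shape)
    trainers = {
        "infonce-stdim": InfoNCESpatioTemporalTrainer,
        "global-infonce-stdim": GlobalInfoNCESpatioTemporalTrainer,
        "global-local-infonce-stdim": GlobalLocalInfoNCESpatioTemporalTrainer,
        "dim": DIMTrainer,
        "spatio-temporal": SpatioTemporalTrainer,
        "naff": NaFFPredictorTrainer,
        "vae": VAETrainer,
    }
    trainer = trainers[args.method](encoder, config, device=device, wandb=wandb)
    trainer.train(tr_eps, val_eps)
    return encoder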