def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    # enable mixed-precision computation if desired
    amp = ""
    if args.amp:
        amp = "torch"
        if args.apex:
            print("Error: Cannot use both --amp and --apex.")
            exit()
    if args.apex:
        amp = "apex"
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path,
            training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
                      tclip=args.tclip, n_depth=args.n_depth,
                      encoder_size=encoder_size, use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)
    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args.output_dir, torch_device, amp)
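# ---------------------------------------------------------------------------
# Hypothetical CLI setup (a minimal sketch, not taken from the original file).
# The main() functions in this file read a module-level `args` object that is
# never defined in this snippet. The argparse parser sketched here mirrors the
# flag names actually used in main() (output_dir, amp, apex, seed, ...); the
# types, defaults, and help strings are illustrative assumptions only.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description='encoder training script (sketch)')
# run bookkeeping
parser.add_argument('--output_dir', type=str, default='./runs')
parser.add_argument('--input_dir', type=str, default='./data')
parser.add_argument('--run_name', type=str, default='default_run')
parser.add_argument('--cpt_load_path', type=str, default=None)
# data and optimization
parser.add_argument('--dataset', type=str, default='STL10')
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--learning_rate', type=float, default=2e-4)
parser.add_argument('--seed', type=int, default=1)
# mixed precision: --amp selects native torch amp, --apex selects NVIDIA apex
parser.add_argument('--amp', action='store_true')
parser.add_argument('--apex', action='store_true')
# task selection: train evaluation classifiers instead of the encoder
parser.add_argument('--classifiers', action='store_true')
# model hyperparameters forwarded to Model(...)
parser.add_argument('--ndf', type=int, default=128)
parser.add_argument('--n_rkhs', type=int, default=1024)
parser.add_argument('--n_depth', type=int, default=3)
parser.add_argument('--tclip', type=float, default=20.0)
parser.add_argument('--use_bn', type=int, default=0)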
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args['output_dir']):
        os.mkdir(args['output_dir'])

    # enable mixed-precision computation if desired
    if args['amp']:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed(args['seed'])

    # get the dataset
    dataset = get_dataset(args['dataset'])
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args['output_dir'], args['run_name'])
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args['batch_size'],
                      input_dir=args['input_dir'],
                      labeled_only=args['classifiers'])

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args['output_dir'])
    if args['cpt_load_path']:
        model = checkpointer.restore_model_from_checkpoint(
            args['cpt_load_path'],
            training_classifier=args['classifiers'])
    else:
        # create new model with random parameters
        model = Model(ndf=args['ndf'], n_classes=num_classes,
                      n_rkhs=args['n_rkhs'], tclip=args['tclip'],
                      n_depth=args['n_depth'], encoder_size=encoder_size,
                      use_bn=(args['use_bn'] == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)
    model = model.to(torch_device)

    # select which type of training to do
    if args['classifiers']:
        task = train_classifiers
    elif args['decoder']:
        task = train_decoder
    else:
        task = train_self_supervised
    task(model, args['learning_rate'], dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args['output_dir'], torch_device)
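# ---------------------------------------------------------------------------
# Hypothetical entry point (a sketch, assuming the parser defined above).
# The dict-style main() also checks args['decoder'] to pick train_decoder, so
# that flag is registered here; its name and help text are assumptions. vars()
# converts the parsed Namespace into a plain dict so that the args['key']
# indexing used by main() works.
# ---------------------------------------------------------------------------
parser.add_argument('--decoder', action='store_true',
                    help='train a decoder on top of the encoder (assumed flag)')

if __name__ == '__main__':
    args = vars(parser.parse_args())
    main()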