def load(folder, rng_seed, model, optimizer, architect=None, s3_bucket=None):
    # Try to download log and ckpt from s3 first to see if a ckpt exists.
    ckpt = os.path.join(folder, "model.ckpt")
    history_file = os.path.join(folder, "history.pkl")
    history = {}

    if s3_bucket is not None:
        aws_utils.download_from_s3(ckpt, s3_bucket, ckpt)
        try:
            aws_utils.download_from_s3(history_file, s3_bucket, history_file)
        except Exception:
            logging.info("history.pkl not in s3 bucket")

    if os.path.exists(history_file):
        with open(history_file, "rb") as f:
            history = pickle.load(f)
        # TODO: update architecture history
        if architect is not None:
            architect.load_history(history)

    checkpoint = torch.load(ckpt)
    epochs = checkpoint["epochs"]
    rng_seed.load_states(checkpoint["rng_seed"])
    model.load_states(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    if architect is not None:
        architect.load_states(checkpoint["architect"])
    logging.info("Resumed model trained for %d epochs", epochs)
    return epochs, history
def load(folder, rng_seed, model, optimizer, s3_bucket=None):
    # Try to download log and ckpt from s3 first to see if a ckpt exists.
    ckpt = os.path.join(folder, "model.ckpt")
    history_file = os.path.join(folder, "history.pkl")
    history = None

    if s3_bucket is not None:
        aws_utils.download_from_s3(ckpt, s3_bucket, ckpt)
        try:
            aws_utils.download_from_s3(history_file, s3_bucket, history_file)
        except Exception:
            logging.info("history.pkl not in s3 bucket")

    if os.path.exists(history_file):
        with open(history_file, "rb") as f:
            history = pickle.load(f)

    checkpoint = torch.load(ckpt)
    epochs = checkpoint["epochs"]
    rng_seed.load_states(checkpoint["rng_seed"])
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Restore the architecture parameters in place; a plain `p = s` would only
    # rebind the loop variable and leave the model's parameters untouched.
    module = model.module
    params = [
        module.alphas_normal,
        module.alphas_reduce,
        module.betas_normal,
        module.betas_reduce,
    ]
    for p, s in zip(params, checkpoint["arch_params"]):
        p.data.copy_(s)

    logging.info("Resumed model trained for %d epochs", epochs)
    return epochs, history
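# Minimal sketch of the in-place restore used above (illustrative only; it assumes
# torch is imported at module level, as in the surrounding code, and the tensors are
# stand-ins, not the actual checkpoint format). Rebinding the loop variable does not
# change the parameter; `.data.copy_()` does.
def _arch_param_copy_demo():
    p = torch.nn.Parameter(torch.zeros(2))
    s = torch.ones(2)

    for q in [p]:
        q = s  # rebinds the local name only; p stays all zeros
    assert torch.equal(p.detach(), torch.zeros(2))

    for q in [p]:
        q.data.copy_(s)  # writes into the parameter's storage
    assert torch.equal(p.detach(), torch.ones(2))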
def load(folder,
         rng_seed,
         model,
         optimizer,
         architect=None,
         s3_bucket=None,
         best_eval=False,
         gpu=None):
    """Loads a checkpoint.

    Args:
        folder (str): Directory that contains the checkpoint to load.
        rng_seed (RNGSeed): Random seed object that gets initialized from the
            checkpoint. Reference is modified.
        model: Model that gets initialized from the checkpoint. Reference is modified.
        optimizer: Optimizer that gets initialized from the checkpoint.
            Reference is modified.
        architect: Architect that gets initialized from the checkpoint.
            Reference is modified.
        s3_bucket: Optional AWS S3 bucket to download the checkpoint and history from.
        best_eval (bool): Whether the best checkpoint from evaluation should be
            loaded. This will search for a checkpoint called 'model_best.ckpt'
            instead of 'model.ckpt'. If no such file is found, an error is raised.
        gpu (int or None): For multi-process loading, specifies the gpu to which
            the checkpointed tensors should be mapped.

    Returns:
        int: Number of epochs the checkpointed model was trained for.
        dict: Architecture history.
        int: Current overall runtime of the run in seconds.
        dict: Properties of the currently best observed genotype (or None).
        float: Peak GPU memory allocated by PyTorch in MB.
        float: Peak GPU memory reserved by PyTorch in MB.
    """
    # Try to download log and ckpt from s3 first to see if a ckpt exists.
    ckpt = os.path.join(folder, "model_best.ckpt") if best_eval else os.path.join(
        folder, "model.ckpt")
    if not os.path.isfile(ckpt):
        raise ValueError(f"No valid checkpoint file found: {ckpt}")

    history_file = os.path.join(folder, "history.pkl")
    history = {}

    if s3_bucket is not None:
        aws_utils.download_from_s3(ckpt, s3_bucket, ckpt)
        try:
            aws_utils.download_from_s3(history_file, s3_bucket, history_file)
        except Exception:
            logging.info("history.pkl not in s3 bucket")

    if os.path.exists(history_file) and architect is not None:
        with open(history_file, "rb") as f:
            history = pickle.load(f)
        # TODO: update architecture history
        architect.load_history(history)

    if gpu is not None:
        # Remap tensors that were saved on cuda:0 to this process' device.
        map_location = {"cuda:0": f"cuda:{gpu}"}
    checkpoint = torch.load(ckpt) if gpu is None else torch.load(
        ckpt, map_location=map_location)

    epochs = checkpoint["epochs"]
    rng_seed.load_states(checkpoint["rng_seed"])
    if isinstance(model, DistributedDataParallel):
        model.module.load_states(checkpoint["model"])
    else:
        model.load_states(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    runtime = checkpoint.get("runtime", 0)

    if "best_observed" in checkpoint:
        best_observed = checkpoint["best_observed"]
        best_observed["genotype_raw"] = dict_to_genotype(
            best_observed["genotype_dict"])
    else:
        best_observed = None

    if architect is not None:
        architect.load_states(checkpoint["architect"])

    max_mem_allocated_MB = checkpoint.get("max_mem_allocated_MB", 0.)
    max_mem_reserved_MB = checkpoint.get("max_mem_reserved_MB", 0.)

    logging.info(f"Resumed model trained for {epochs} epochs")
    logging.info(
        f"Resumed model trained for {timedelta(seconds=runtime)} hh:mm:ss")
    return (epochs, history, runtime, best_observed, max_mem_allocated_MB,
            max_mem_reserved_MB)
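# Hedged usage sketch for the extended ``load`` above (illustrative only: it assumes
# that ``rng_seed``, ``model``, ``optimizer``, ``architect`` and ``scheduler`` have
# already been constructed, as in the ``main`` functions below, and that the run
# directory is the current working directory).
def _resume_from_checkpoint_example(rng_seed, model, optimizer, architect,
                                    scheduler):
    try:
        (start_epochs, history, runtime, best_observed, max_mem_allocated_MB,
         max_mem_reserved_MB) = load(os.getcwd(), rng_seed, model, optimizer,
                                     architect=architect)
        # Keep the LR schedule in sync with the restored epoch counter.
        scheduler.last_epoch = start_epochs - 1
    except Exception:
        # No usable checkpoint; start a fresh run (the main functions below do the same).
        start_epochs, history, runtime, best_observed = 0, {}, 0, None
    return start_epochs, history, runtime, best_observed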
def main(args):
    s3_bucket = args.run.s3_bucket
    log = os.path.join(os.getcwd(), "log.txt")
    if s3_bucket is not None:
        aws_utils.download_from_s3(log, s3_bucket, log)
    train_utils.set_up_logging(log)

    CIFAR_CLASSES = 10

    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    try:
        aws_utils.download_from_s3("cnn_genotypes.txt", s3_bucket,
                                   "/tmp/cnn_genotypes.txt")
        with open("/code/nas-theory/cnn/search_spaces/darts/genotypes.py",
                  "a") as f:
            with open("/tmp/cnn_genotypes.txt") as archs:
                f.write("\n")
                for line in archs:
                    if "Genotype" in line:
                        f.write(line)
        print("Downloaded genotypes from aws.")
    except Exception as e:
        print(e)

    # Import here so that the genotypes appended above are picked up.
    from search_spaces.darts import genotypes

    rng_seed = train_utils.RNGSeed(args.run.seed)

    torch.cuda.set_device(args.run.gpu)
    logging.info("gpu device = %d" % args.run.gpu)
    logging.info("args = %s", args.pretty())

    print(dir(genotypes))
    # Look up the requested architecture by name in the genotypes module.
    genotype = getattr(genotypes, args.train.arch)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    model = Network(
        args.train.init_channels,
        CIFAR_CLASSES,
        args.train.layers,
        args.train.auxiliary,
        genotype,
    )
    model = model.cuda()

    optimizer, scheduler = train_utils.setup_optimizer(model, args)

    logging.info("param size = %fMB",
                 train_utils.count_parameters_in_MB(model))
    total_params = sum(x.data.nelement() for x in model.parameters())
    logging.info("Model total parameters: {}".format(total_params))

    try:
        start_epochs, _ = train_utils.load(os.getcwd(), rng_seed, model,
                                           optimizer, s3_bucket=s3_bucket)
        scheduler.last_epoch = start_epochs - 1
    except Exception as e:
        print(e)
        start_epochs = 0

    num_train, num_classes, train_queue, valid_queue = train_utils.create_data_queues(
        args, eval_split=True)

    for epoch in range(start_epochs, args.run.epochs):
        logging.info("epoch %d lr %e", epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        train_acc, train_obj = train(args, train_queue, model, criterion,
                                     optimizer)
        logging.info("train_acc %f", train_acc)

        valid_acc, valid_obj = train_utils.infer(
            valid_queue, model, criterion, report_freq=args.run.report_freq)
        logging.info("valid_acc %f", valid_acc)

        train_utils.save(os.getcwd(), epoch + 1, rng_seed, model, optimizer,
                         s3_bucket=s3_bucket)

        scheduler.step()
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    log = os.path.join(args.save, 'log.txt')
    if args.s3_bucket is not None:
        aws_utils.download_from_s3(log, args.s3_bucket, log)

    rng_seed = train_utils.RNGSeed(args.seed, deterministic=False)

    logging.info("args = %s", args)

    #dataset_dir = '/cache/'
    #pre.split_dataset(dataset_dir)
    #sys.exit(1)

    # dataset prepare
    data_dir = os.path.join(args.tmp_data_dir, 'imagenet_search')
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # dataset split
    train_data1 = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_data2 = dset.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    num_train = len(train_data1)
    num_val = len(train_data2)
    print('# images to train network: %d' % num_train)
    print('# images to validate network: %d' % num_val)

    model = Network(args.init_channels, CLASSES, args.layers, criterion)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    architect = Architect(model, criterion, args)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    history = []
    try:
        start_epochs, history = train_utils.load(args.save, rng_seed, model,
                                                 optimizer, args.s3_bucket)
        print(history)
        scheduler.last_epoch = start_epochs - 1
    except Exception as e:
        print(e)
        start_epochs = 0

    test_queue = torch.utils.data.DataLoader(valid_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    train_queue = torch.utils.data.DataLoader(train_data1,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(train_data2,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    if args.debug:
        train_queue = valid_queue
        valid_queue = test_queue

    lr = args.learning_rate
    for epoch in range(start_epochs, args.epochs):
        scheduler.step()
        current_lr = scheduler.get_lr()[0]
        logging.info('Epoch: %d lr: %e', epoch, current_lr)
        if epoch < 5 and args.batch_size > 256:
            # Linear learning-rate warm-up over the first 5 epochs for large batches.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * (epoch + 1) / 5.0
            logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                         lr * (epoch + 1) / 5.0)
            print(optimizer)

        genotype = model.module.genotype()
        logging.info('genotype = %s', genotype)
        arch_param = model.module.arch_parameters()
        logging.info(arch_param[0])
        logging.info(arch_param[1])
        logging.info(arch_param[2])
        logging.info(arch_param[3])

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, optimizer,
                                     architect, criterion, lr, epoch)
        logging.info('Train_acc %f', train_acc)

        # validation
        if epoch >= 47:
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            #test_acc, test_obj = infer(test_queue, model, criterion)
            logging.info('Valid_acc %f', valid_acc)
            #logging.info('Test_acc %f', test_acc)

        history.append(
            [p.data.cpu().numpy() for p in model.module.arch_parameters()])
        logging.info("saving checkpoint")
        train_utils.save(args.save, epoch + 1, rng_seed, model, optimizer,
                         history, args.s3_bucket)

        if args.s3_bucket is not None:
            filename = "cnn_genotypes.txt"
            aws_utils.download_from_s3(filename, args.s3_bucket, filename)

            with open(filename, "a+") as f:
                f.write("\n")
                f.write("{}{} = {}".format('edarts', args.seed, genotype))

            aws_utils.upload_to_s3(filename, args.s3_bucket, filename)
            aws_utils.upload_to_s3(log, args.s3_bucket, log)
def main(args): """Performs NAS. """ np.set_printoptions(precision=3) save_dir = os.getcwd() log = os.path.join(save_dir, "log.txt") # Setup SummaryWriter summary_dir = os.path.join(save_dir, "summary") if not os.path.exists(summary_dir): os.mkdir(summary_dir) writer = SummaryWriter(summary_dir) # own writer that I use to keep track of interesting variables own_writer = SummaryWriter(os.path.join(save_dir, 'tensorboard')) if args.run.s3_bucket is not None: aws_utils.download_from_s3(log, args.run.s3_bucket, log) train_utils.copy_code_to_experiment_dir( "/home/julienf/git/gaea_release/cnn", save_dir) aws_utils.upload_directory(os.path.join(save_dir, "scripts"), args.run.s3_bucket) train_utils.set_up_logging(log) if not torch.cuda.is_available(): logging.info("no gpu device available") sys.exit(1) torch.cuda.set_device(args.run.gpu) logging.info("Search hyperparameters:") #logging.info("gpu device = %d" % args.run.gpu) logging.info(args.pretty()) # Set random seeds for random, numpy, torch and cuda rng_seed = train_utils.RNGSeed(args.run.seed) # Load respective architect if args.search.method in ["edarts", "gdarts", "eedarts"]: if args.search.fix_alphas: from architect.architect_edarts_edge_only import ( ArchitectEDARTS as Architect, ) else: from architect.architect_edarts import ArchitectEDARTS as Architect elif args.search.method in ["darts", "fdarts"]: from architect.architect_darts import ArchitectDARTS as Architect elif args.search.method == "egdas": from architect.architect_egdas import ArchitectEGDAS as Architect else: raise NotImplementedError # Load respective search spaces if args.search.search_space in ["darts", "darts_small"]: from search_spaces.darts.model_search import DARTSNetwork as Network elif "nas-bench-201" in args.search.search_space: from search_spaces.nasbench_201.model_search import ( NASBENCH201Network as Network, ) elif args.search.search_space == "pcdarts": from search_spaces.pc_darts.model_search import PCDARTSNetwork as Network else: raise NotImplementedError if args.train.smooth_cross_entropy: criterion = train_utils.cross_entropy_with_label_smoothing else: criterion = nn.CrossEntropyLoss() #num_train, num_classes, train_queue, valid_queue = train_utils.create_data_queues( # args #) num_classes, (train_queue, train_2_queue), valid_queue, test_queue, ( number_train, number_valid, number_test) = train_utils.create_cifar10_data_queues_own(args) logging.info(f"Dataset: {args.run.dataset}, num_classes: {num_classes}") logging.info(f"Number of training images: {number_train}") if args.search.single_level: logging.info( f"Number of validation images (unused during search): {number_valid}" ) else: logging.info( f"Number of validation images (used during search): {number_valid}" ) logging.info( f"Number of test images (unused during search): {number_test}") model = Network( args.train.init_channels, num_classes, args.search.nodes, args.train.layers, criterion, **{ "auxiliary": args.train.auxiliary, "search_space_name": args.search.search_space, "exclude_zero": args.search.exclude_zero, "track_running_stats": args.search.track_running_stats, }) #if args.run.dataset == 'cifar10': # random_img = np.random.randint(0, 255, size=(1, 3, 32, 32)) # own_writer.add_graph(model, input_to_model=torch.from_numpy(random_img)) model = model.cuda() logging.info("param size = %fMB", train_utils.count_parameters_in_MB(model)) optimizer, scheduler = train_utils.setup_optimizer(model, args) # TODO: separate args by model, architect, etc # TODO: look into using hydra for config files 
    architect = Architect(model, args, writer)

    # Try to load a previous checkpoint
    try:
        # The extended `load` above returns six values; only epochs and history
        # are needed here.
        start_epochs, history, _, _, _, _ = train_utils.load(
            save_dir, rng_seed, model, optimizer, architect, args.run.s3_bucket)
        scheduler.last_epoch = start_epochs - 1
        #(
        #    num_train,
        #    num_classes,
        #    train_queue,
        #    valid_queue,
        #) = train_utils.create_data_queues(args)
        # TODO: why are data queues reloaded?
        num_classes, (train_queue, train_2_queue), valid_queue, test_queue, (
            number_train, number_valid,
            number_test) = train_utils.create_cifar10_data_queues_own(args)
        logging.info(
            'Resumed training from a previous checkpoint. Runtime measurement will be wrong.'
        )
        train_start_time = 0
    except Exception as e:
        logging.info(e)
        start_epochs = 0
        train_start_time = timer()

    # for single-level search, best_valid corresponds to the best train accuracy
    # observed so far and epoch_best_valid to the epoch it was observed in
    best_valid = 0
    epoch_best_valid = 0
    overall_visualization_time = 0  # don't count visualization into runtime

    for epoch in range(start_epochs, args.run.epochs):
        lr = scheduler.get_lr()[0]
        logging.info(f"| Epoch: {epoch:3d} / {args.run.epochs} | lr: {lr} |")

        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        # training returns top1, loss and top5
        train_acc, train_obj, train_top5 = train(
            args,
            train_queue,
            # valid_queue for bi-level search, train_2_queue for single-level search
            valid_queue if train_2_queue is None else train_2_queue,
            model,
            architect,
            criterion,
            optimizer,
            lr,
        )
        architect.baseline = train_obj
        architect.update_history()
        architect.log_vars(epoch, writer)

        if "update_lr_state" in dir(scheduler):
            scheduler.update_lr_state(train_obj)

        logging.info(f"| train_acc: {train_acc} |")

        # History tracking
        for vs in [("alphas", architect.alphas), ("edges", architect.edges)]:
            for ct in vs[1]:
                v = vs[1][ct]
                logging.info("{}-{}".format(vs[0], ct))
                logging.info(v)

        # Calling genotypes sets alphas to best arch for EGDAS and MEGDAS
        # so calling here before infer.
        genotype = architect.genotype()
        logging.info("genotype = %s", genotype)

        # log epoch values to tensorboard
        own_writer.add_scalar('Loss/train', train_obj, epoch)
        own_writer.add_scalar('Top1/train', train_acc, epoch)
        own_writer.add_scalar('Top5/train', train_top5, epoch)
        own_writer.add_scalar('lr', lr, epoch)

        # visualize Genotype
        start_visualization = timer()
        genotype_graph_normal = visualize.plot(genotype.normal,
                                               "",
                                               return_type="graph",
                                               output_format='png')
        binary_normal = genotype_graph_normal.pipe()
        stream_normal = io.BytesIO(binary_normal)
        graph_normal = np.array(PIL.Image.open(stream_normal).convert("RGB"))
        own_writer.add_image("Normal_Cell", graph_normal, epoch,
                             dataformats="HWC")

        genotype_graph_reduce = visualize.plot(genotype.reduce,
                                               "",
                                               return_type="graph",
                                               output_format='png')
        binary_reduce = genotype_graph_reduce.pipe()
        stream_reduce = io.BytesIO(binary_reduce)
        graph_reduce = np.array(PIL.Image.open(stream_reduce).convert("RGB"))
        own_writer.add_image("Reduce_Cell", graph_reduce, epoch,
                             dataformats="HWC")
        end_visualization = timer()
        overall_visualization_time += (end_visualization - start_visualization)

        # log validation metrics, but don't utilize them for decisions during single-level search
        valid_acc, valid_obj, valid_top5 = train_utils.infer(
            valid_queue,
            model,
            criterion,
            report_freq=args.run.report_freq,
            discrete=args.search.discrete,
        )
        own_writer.add_scalar('Loss/valid', valid_obj, epoch)
        own_writer.add_scalar('Top1/valid', valid_acc, epoch)
        own_writer.add_scalar('Top5/valid', valid_top5, epoch)
        logging.info(f"| valid_acc: {valid_acc} |")

        # Track the best genotype: by validation accuracy for bi-level search,
        # by train accuracy for single-level search.
        if not args.search.single_level:
            if valid_acc > best_valid:
                best_valid = valid_acc
                best_genotype = architect.genotype()
                epoch_best_valid = epoch
        else:
            if train_acc > best_valid:
                best_valid = train_acc
                best_genotype = architect.genotype()
                epoch_best_valid = epoch

        train_utils.save(
            save_dir,
            epoch + 1,
            rng_seed,
            model,
            optimizer,
            architect,
            save_history=True,
            s3_bucket=args.run.s3_bucket,
        )

        scheduler.step()

    train_end_time = timer()
    logging.info(
        f"Visualization of cells during search took a total of "
        f"{timedelta(seconds=overall_visualization_time)} (hh:mm:ss)."
    )
    logging.info("This time is not included in the runtime given below.\n")
    logging.info(
        f"Training finished after "
        f"{timedelta(seconds=((train_end_time - train_start_time) - overall_visualization_time))} (hh:mm:ss)."
    )

    # Performing validation of final epoch...
    #valid_acc, valid_obj, valid_top5 = train_utils.infer(
    #    valid_queue,
    #    model,
    #    criterion,
    #    report_freq=args.run.report_freq,
    #    discrete=args.search.discrete,
    #)
    #own_writer.add_scalar('Loss/valid', valid_obj, args.run.epochs-1)
    #own_writer.add_scalar('Top1/valid', valid_acc, args.run.epochs-1)
    #own_writer.add_scalar('Top5/valid', valid_top5, args.run.epochs-1)
    #logging.info(f"| valid_acc: {valid_acc} |")
    #if not args.search.single_level:
    #    if valid_acc > best_valid:
    #        best_valid = valid_acc
    #        best_genotype = architect.genotype()
    #        epoch_best_valid = args.run.epochs-1
    #else:
    #    if train_acc > best_valid:
    #        best_valid = train_acc
    #        best_genotype = architect.genotype()
    #        epoch_best_valid = args.run.epochs-1

    if args.search.single_level:
        logging.info(
            "\nBecause single-level search is performed, the best genotype was not "
            "selected according to the best achieved validation accuracy "
            "but according to the best train accuracy.")

    logging.info(
        f"\nOverall best found genotype with validation accuracy of {best_valid} "
        f"(found in epoch {epoch_best_valid}):")
    logging.info(f"{best_genotype}")

    # dump best genotype to json file, so that we can load it during the
    # evaluation phase (in train_final.py)
    genotype_dict = best_genotype._asdict()
    for key, val in genotype_dict.items():
        if isinstance(val, range):
            genotype_dict[key] = list(val)

    if os.path.splitext(args.run.genotype_path)[1] != '.json':
        args.run.genotype_path += '.json'

    with open(args.run.genotype_path, 'w') as genotype_file:
        json.dump(genotype_dict, genotype_file, indent=4)

    logging.info(
        f"Search finished. Dumped best genotype into {args.run.genotype_path}")

    if args.run.s3_bucket is not None:
        filename = "cnn_genotypes.txt"
        aws_utils.download_from_s3(filename, args.run.s3_bucket, filename)

        with open(filename, "a+") as f:
            f.write("\n")
            f.write("{}{}{}{} = {}".format(
                args.search.search_space,
                args.search.method,
                args.run.dataset.replace("-", ""),
                args.run.seed,
                best_genotype,
            ))

        aws_utils.upload_to_s3(filename, args.run.s3_bucket, filename)
        aws_utils.upload_to_s3(log, args.run.s3_bucket, log)
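# Hedged sketch of reading the dumped genotype back during the evaluation phase
# (illustrative only: it assumes the standard DARTS-style namedtuple
# ``Genotype(normal, normal_concat, reduce, reduce_concat)`` and the JSON layout
# produced from ``best_genotype._asdict()`` above).
def _load_genotype_from_json(genotype_path):
    from collections import namedtuple

    Genotype = namedtuple("Genotype",
                          "normal normal_concat reduce reduce_concat")
    with open(genotype_path) as genotype_file:
        genotype_dict = json.load(genotype_file)
    # json stores the (op, input_node) pairs as lists; convert back to tuples.
    return Genotype(
        normal=[tuple(edge) for edge in genotype_dict["normal"]],
        normal_concat=genotype_dict["normal_concat"],
        reduce=[tuple(edge) for edge in genotype_dict["reduce"]],
        reduce_concat=genotype_dict["reduce_concat"],
    )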
def main(args):
    np.set_printoptions(precision=3)
    save_dir = os.getcwd()
    log = os.path.join(save_dir, "log.txt")

    # Setup SummaryWriter
    summary_dir = os.path.join(save_dir, "summary")
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)

    if args.run.s3_bucket is not None:
        aws_utils.download_from_s3(log, args.run.s3_bucket, log)

        train_utils.copy_code_to_experiment_dir("/code/nas-theory/cnn", save_dir)
        aws_utils.upload_directory(os.path.join(save_dir, "scripts"),
                                   args.run.s3_bucket)

    train_utils.set_up_logging(log)

    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    torch.cuda.set_device(args.run.gpu)
    logging.info("gpu device = %d" % args.run.gpu)
    logging.info("args = %s", args.pretty())

    rng_seed = train_utils.RNGSeed(args.run.seed)

    if args.search.method in ["edarts", "gdarts", "eedarts"]:
        if args.search.fix_alphas:
            from architect.architect_edarts_edge_only import (
                ArchitectEDARTS as Architect,
            )
        else:
            from architect.architect_edarts import ArchitectEDARTS as Architect
    elif args.search.method in ["darts", "fdarts"]:
        from architect.architect_darts import ArchitectDARTS as Architect
    elif args.search.method == "egdas":
        from architect.architect_egdas import ArchitectEGDAS as Architect
    else:
        raise NotImplementedError

    if args.search.search_space in ["darts", "darts_small"]:
        from search_spaces.darts.model_search import DARTSNetwork as Network
    elif "nas-bench-201" in args.search.search_space:
        from search_spaces.nasbench_201.model_search import (
            NASBENCH201Network as Network,
        )
    elif args.search.search_space == "pcdarts":
        from search_spaces.pc_darts.model_search import PCDARTSNetwork as Network
    else:
        raise NotImplementedError

    if args.train.smooth_cross_entropy:
        criterion = train_utils.cross_entropy_with_label_smoothing
    else:
        criterion = nn.CrossEntropyLoss()

    num_train, num_classes, train_queue, valid_queue = train_utils.create_data_queues(
        args)
    print("dataset: {}, num_classes: {}".format(args.run.dataset, num_classes))

    model = Network(
        args.train.init_channels,
        num_classes,
        args.search.nodes,
        args.train.layers,
        criterion,
        **{
            "auxiliary": args.train.auxiliary,
            "search_space_name": args.search.search_space,
            "exclude_zero": args.search.exclude_zero,
            "track_running_stats": args.search.track_running_stats,
        })
    model = model.cuda()
    logging.info("param size = %fMB",
                 train_utils.count_parameters_in_MB(model))

    optimizer, scheduler = train_utils.setup_optimizer(model, args)

    # TODO: separate args by model, architect, etc
    # TODO: look into using hydra for config files
    architect = Architect(model, args, writer)

    # Try to load a previous checkpoint
    try:
        start_epochs, history = train_utils.load(save_dir, rng_seed, model,
                                                 optimizer, architect,
                                                 args.run.s3_bucket)
        scheduler.last_epoch = start_epochs - 1
        (
            num_train,
            num_classes,
            train_queue,
            valid_queue,
        ) = train_utils.create_data_queues(args)
    except Exception as e:
        logging.info(e)
        start_epochs = 0

    best_valid = 0
    for epoch in range(start_epochs, args.run.epochs):
        lr = scheduler.get_lr()[0]
        logging.info("epoch %d lr %e", epoch, lr)

        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        # training
        train_acc, train_obj = train(
            args,
            train_queue,
            valid_queue,
            model,
            architect,
            criterion,
            optimizer,
            lr,
        )
        architect.baseline = train_obj
        architect.update_history()
        architect.log_vars(epoch, writer)

        if "update_lr_state" in dir(scheduler):
            scheduler.update_lr_state(train_obj)

        logging.info("train_acc %f", train_acc)

        # History tracking
        for vs in [("alphas", architect.alphas), ("edges", architect.edges)]:
            for ct in vs[1]:
                v = vs[1][ct]
                logging.info("{}-{}".format(vs[0], ct))
                logging.info(v)

        # Calling genotypes sets alphas to best arch for EGDAS and MEGDAS
        # so calling here before infer.
        genotype = architect.genotype()
        logging.info("genotype = %s", genotype)

        if not args.search.single_level:
            valid_acc, valid_obj = train_utils.infer(
                valid_queue,
                model,
                criterion,
                report_freq=args.run.report_freq,
                discrete=args.search.discrete,
            )
            if valid_acc > best_valid:
                best_valid = valid_acc
                best_genotype = architect.genotype()
            logging.info("valid_acc %f", valid_acc)

        train_utils.save(
            save_dir,
            epoch + 1,
            rng_seed,
            model,
            optimizer,
            architect,
            save_history=True,
            s3_bucket=args.run.s3_bucket,
        )

        scheduler.step()

    valid_acc, valid_obj = train_utils.infer(
        valid_queue,
        model,
        criterion,
        report_freq=args.run.report_freq,
        discrete=args.search.discrete,
    )
    if valid_acc > best_valid:
        best_valid = valid_acc
        best_genotype = architect.genotype()
    logging.info("valid_acc %f", valid_acc)

    if args.run.s3_bucket is not None:
        filename = "cnn_genotypes.txt"
        aws_utils.download_from_s3(filename, args.run.s3_bucket, filename)

        with open(filename, "a+") as f:
            f.write("\n")
            f.write("{}{}{}{} = {}".format(
                args.search.search_space,
                args.search.method,
                args.run.dataset.replace("-", ""),
                args.run.seed,
                best_genotype,
            ))

        aws_utils.upload_to_s3(filename, args.run.s3_bucket, filename)
        aws_utils.upload_to_s3(log, args.run.s3_bucket, log)