def time_model():
    """Benchmark the model together with its train and test data loaders.

    Prepares the environment, builds the model, the CUDA loss and both
    loaders, then passes everything to ``benchmark.compute_time_full``.
    """
    setup_env()
    net_model = setup_model()
    criterion = builders.build_loss_fun().cuda()
    # Tuple elements are evaluated left to right: train loader first, then test.
    data_loaders = (loader.construct_train_loader(), loader.construct_test_loader())
    benchmark.compute_time_full(net_model, criterion, *data_loaders)
def train_model():
    """Train a model with the standard (non-search) pipeline.

    Builds the model, loss and optimizer. Training starts from one of:
    the most recent checkpoint (when ``cfg.TRAIN.AUTO_RESUME`` is set and a
    checkpoint exists), initial weights from ``cfg.TRAIN.WEIGHTS``, or scratch.
    It then runs epochs ``[start, cfg.OPTIM.MAX_EPOCH)`` with optional
    precise-BN statistics, periodic checkpointing and periodic evaluation.
    """
    setup_env()
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)

    # Resume takes precedence over initial weights; otherwise start at epoch 0.
    start_epoch = 0
    resuming = cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint()
    if resuming:
        ckpt_path = checkpoint.get_last_checkpoint()
        resumed_epoch = checkpoint.load_checkpoint(ckpt_path, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(ckpt_path))
        start_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))

    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))

    # Timing is only measured on a fresh run, and only when iterations are requested.
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)

    logger.info("Start epoch: {}".format(start_epoch + 1))
    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        epochs_done = epoch + 1
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, epoch)
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        if epochs_done % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            saved_to = checkpoint.save_checkpoint(model, optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_to))
        # Evaluate every EVAL_PERIOD epochs and always after the final epoch.
        if epochs_done % cfg.TRAIN.EVAL_PERIOD == 0 or epochs_done == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(test_loader, model, test_meter, epoch)
        # Per-epoch memory housekeeping; see
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
def darts_train_model():
    """Run a DARTS architecture search.

    Builds the search space, a ``DartsCNNController`` and its ``Architect``.
    Two optimizers are used: SGD for the network weights and Adam for the
    architecture parameters (alphas). Training starts from the latest
    checkpoint (``cfg.SEARCH.AUTO_RESUME``), from initial weights
    (``cfg.SEARCH.WEIGHTS``) or from scratch. Each epoch of
    ``[start, cfg.OPTIM.MAX_EPOCH)`` trains, checkpoints periodically and
    steps the LR scheduler. Evaluation logs the current genotype and alphas.
    """
    setup_env()
    search_space = build_space()
    # TODO: fix the complexity function so setup_model() can replace build_space().
    loss_fun = build_loss_fun().cuda()
    controller = DartsCNNController(search_space, loss_fun)
    controller.cuda()
    architect = Architect(controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)

    train_loader, val_loader = construct_loader(
        cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE)

    # SGD over the network weights, Adam over the architecture alphas.
    weight_opt = torch.optim.SGD(
        controller.weights(),
        cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
    )
    alpha_opt = torch.optim.Adam(
        controller.alphas(),
        cfg.DARTS.ALPHA_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY,
    )
    scheduler = lr_scheduler_builder(weight_opt)

    train_meter = meters.TrainMeter(len(train_loader))
    val_meter = meters.TestMeter(len(val_loader))

    first_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        ckpt_path = checkpoint.get_last_checkpoint()
        resumed_epoch = darts_load_checkpoint(ckpt_path, controller, weight_opt, alpha_opt)
        logger.info("Loaded checkpoint from: {}".format(ckpt_path))
        first_epoch = resumed_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))

    if first_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(controller, loss_fun, train_loader, val_loader)

    stopwatch = Timer()
    logger.info("Start epoch: {}".format(first_epoch + 1))
    stopwatch.tic()
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        current_lr = scheduler.get_last_lr()[0]
        train_epoch(train_loader, val_loader, controller, architect, loss_fun,
                    weight_opt, alpha_opt, current_lr, train_meter, epoch)
        epochs_done = epoch + 1
        if epochs_done % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            saved_to = darts_save_checkpoint(controller, weight_opt, alpha_opt, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_to))
        scheduler.step()
        # Evaluate every EVAL_PERIOD epochs and always after the final epoch.
        if epochs_done % cfg.SEARCH.EVAL_PERIOD == 0 or epochs_done == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_loader, controller, val_meter, epoch, writer)
            logger.info(
                "###############Optimal genotype at epoch: {}############".format(epoch))
            logger.info(controller.genotype())
            logger.info("########################################################")
            if cfg.SPACE.NAME == "nasbench301":
                logger.info("Evaluating with nasbench301")
                EvaluateNasbench(controller.alpha, controller.net, logger, "nasbench301")
            controller.print_alphas(logger)
        # Per-epoch memory housekeeping; see
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
    stopwatch.toc()
    # NOTE(review): the "(hr)" label is only accurate if Timer.total_time is
    # measured in hours -- confirm against the Timer implementation.
    logger.info("Overall training time (hr) is:{}".format(str(stopwatch.total_time)))
def pcdarts_train_model():
    """Run a PC-DARTS architecture search on CIFAR-10.

    The CIFAR-10 *training* set is split at fraction ``cfg.SEARCH.SPLIT[0]``.
    The first part feeds weight training and the rest feeds architecture
    validation; both loaders use the training transform. The rest mirrors
    the DARTS search loop: SGD for weights, Adam for alphas, optional
    resume or initial weights, periodic checkpointing, and evaluation that
    logs the current genotype and alphas.
    """
    setup_env()
    search_space = build_space()
    # TODO: fix the complexity function so setup_model() can replace build_space().
    loss_fun = build_loss_fun().cuda()
    controller = PCDartsCNNController(search_space, loss_fun)
    controller.cuda()
    architect = Architect(controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)

    # Only the training transform is used; the validation transform is discarded.
    train_transform, _ = data_transforms_cifar10(cutout_length=0)
    cifar_train = dset.CIFAR10(root=cfg.SEARCH.DATASET, train=True,
                               download=True, transform=train_transform)
    n_samples = len(cifar_train)
    all_idx = list(range(n_samples))
    cut = int(np.floor(cfg.SEARCH.SPLIT[0] * n_samples))
    loader_kwargs = dict(batch_size=cfg.SEARCH.BATCH_SIZE, pin_memory=True, num_workers=2)
    train_loader = torch.utils.data.DataLoader(
        cifar_train,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(all_idx[:cut]),
        **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        cifar_train,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(all_idx[cut:n_samples]),
        **loader_kwargs)

    # SGD over the network weights, Adam over the architecture alphas.
    weight_opt = torch.optim.SGD(
        controller.weights(),
        cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
    )
    alpha_opt = torch.optim.Adam(
        controller.alphas(),
        cfg.DARTS.ALPHA_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY,
    )
    scheduler = lr_scheduler_builder(weight_opt)

    train_meter = meters.TrainMeter(len(train_loader))
    val_meter = meters.TestMeter(len(val_loader))

    first_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        ckpt_path = checkpoint.get_last_checkpoint()
        resumed_epoch = darts_load_checkpoint(ckpt_path, controller, weight_opt, alpha_opt)
        logger.info("Loaded checkpoint from: {}".format(ckpt_path))
        first_epoch = resumed_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))

    if first_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(controller, loss_fun, train_loader, val_loader)

    stopwatch = Timer()
    logger.info("Start epoch: {}".format(first_epoch + 1))
    stopwatch.tic()
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        current_lr = scheduler.get_last_lr()[0]
        train_epoch(train_loader, val_loader, controller, architect, loss_fun,
                    weight_opt, alpha_opt, current_lr, train_meter, epoch)
        epochs_done = epoch + 1
        if epochs_done % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            saved_to = darts_save_checkpoint(controller, weight_opt, alpha_opt, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_to))
        scheduler.step()
        # Evaluate every EVAL_PERIOD epochs and always after the final epoch.
        if epochs_done % cfg.SEARCH.EVAL_PERIOD == 0 or epochs_done == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_loader, controller, val_meter, epoch, writer)
            logger.info(
                "###############Optimal genotype at epoch: {}############".format(epoch))
            logger.info(controller.genotype())
            logger.info("########################################################")
            controller.print_alphas(logger)
        # Per-epoch memory housekeeping; see
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
    stopwatch.toc()
    # NOTE(review): the "(hr)" label is only accurate if Timer.total_time is
    # measured in hours -- confirm against the Timer implementation.
    logger.info("Overall training time (hr) is:{}".format(str(stopwatch.total_time)))
def pdarts_train_model():
    """Run the progressive, multi-stage P-DARTS architecture search.

    The search runs ``len(num_to_keep)`` stages. Each stage:

    * writes its extra layers/width/dropout and the current per-edge
      candidate-op lists (``cfg.SPACE.BASIC_OP``) into ``cfg``;
    * builds a fresh search space, ``PDartsCNNController`` and ``Architect``;
    * trains for epochs ``[start_epoch, cfg.OPTIM.MAX_EPOCH)``. The first
      ``eps_no_archs[sp]`` epochs call ``train_epoch(..., train_arch=False)``;
      later epochs use ``train_arch=True``;
    * passes ``num_to_keep[sp]`` to ``get_topk_op``, and the result becomes
      the candidate-op lists for the next stage.

    After the last stage the final genotype is logged and skip-connects are
    pruned step by step (see ``_pdarts_restrict_skip_connections``).
    """
    # Per-stage schedules; index ``sp`` selects the current stage.
    num_to_keep = [5, 3, 1]      # argument to get_topk_op after each stage
    eps_no_archs = [10, 10, 10]  # leading epochs trained with train_arch=False
    drop_rate = [0.1, 0.4, 0.7]  # cfg.SEARCH.dropout_rate and base value for update_p
    add_layers = [0, 6, 12]      # written to cfg.SEARCH.add_layers
    add_width = [0, 0, 0]        # written to cfg.SEARCH.add_width
    scale_factor = 0.2           # exponential decay rate of update_p after warm-up

    primitives = cfg.SPACE.PRIMITIVES
    # NODES * (NODES + 3) / 2 == sum of (i + 2) for i in range(NODES), i.e. the
    # edge count of a DARTS-style cell where intermediate node i has i + 2 inputs.
    num_edges = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    # One candidate-op list per edge for two cells' worth of edges (presumably
    # normal + reduction -- confirm against build_space). All entries share the
    # same ``primitives`` object, exactly as the original append loop did.
    basic_op = [primitives] * (num_edges * 2)

    setup_env()
    loss_fun = build_loss_fun().cuda()
    # The data loaders are built once and shared by all stages.
    [train_, val_] = construct_loader(
        cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE)

    for sp in range(len(num_to_keep)):
        # Push this stage's supernet settings into the config.
        # NOTE(review): cfg is defrosted here and never re-frozen.
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op
        search_space = build_space()
        # TODO: fix the complexity function so setup_model() can replace build_space().
        pdarts_controller = PDartsCNNController(search_space, loss_fun)
        pdarts_controller.cuda()
        architect = Architect(
            pdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
        # Weights optimizer over the current sub-network's weights only.
        w_optim = torch.optim.SGD(
            pdarts_controller.subnet_weights(), cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM, weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # Alphas optimizer.
        a_optim = torch.optim.Adam(
            pdarts_controller.alphas(), cfg.DARTS.ALPHA_LR,
            betas=(0.5, 0.999), weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = lr_scheduler_builder(w_optim)
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))

        # Load checkpoint or initial weights.
        # NOTE(review): this runs once per stage and always picks the latest
        # checkpoint, which may belong to a different stage's supernet.
        start_epoch = 0
        if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
            last_checkpoint = checkpoint.get_last_checkpoint()
            checkpoint_epoch = darts_load_checkpoint(
                last_checkpoint, pdarts_controller, w_optim, a_optim)
            logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
            start_epoch = checkpoint_epoch + 1
        elif cfg.SEARCH.WEIGHTS:
            darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pdarts_controller)
            logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))

        if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
            benchmark.compute_time_full(pdarts_controller, loss_fun, train_, val_)

        # TODO: add a Timer as in the DARTS / PC-DARTS search loops.
        logger.info("Start epoch: {}".format(start_epoch + 1))
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            logger.info("cur_epoch {}".format(cur_epoch))
            lr = lr_scheduler.get_last_lr()[0]
            warming_up = cur_epoch < eps_no_archs[sp]
            if warming_up:
                # Linear decay of update_p's argument over the stage.
                pdarts_controller.update_p(
                    float(drop_rate[sp]) * (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) / cfg.OPTIM.MAX_EPOCH)
            else:
                # Exponential decay of update_p's argument after warm-up.
                pdarts_controller.update_p(
                    float(drop_rate[sp]) * np.exp(-(cur_epoch - eps_no_archs[sp]) * scale_factor))
            # Architecture parameters are trained only after the warm-up epochs.
            train_epoch(train_, val_, pdarts_controller, architect, loss_fun,
                        w_optim, a_optim, lr, train_meter, cur_epoch,
                        train_arch=not warming_up)
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = darts_save_checkpoint(
                    pdarts_controller, w_optim, a_optim, cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()
            # Evaluate while next_epoch is in [MAX_EPOCH - 5, MAX_EPOCH], i.e. in
            # the last six epochs of each stage; cfg.SEARCH.EVAL_PERIOD is unused.
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                logger.info("Start testing")
                test_epoch(val_, pdarts_controller, val_meter, cur_epoch, writer)
                logger.info(
                    "###############Optimal genotype at epoch: {}############".format(cur_epoch))
                logger.info(pdarts_controller.genotype())
                logger.info("########################################################")
                pdarts_controller.print_alphas(logger)
            # Per-epoch memory housekeeping; see
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()

        # Query the top-k ops once; the result is logged and, except in the
        # last stage, becomes the candidate-op lists for the next stage.
        topk_op = pdarts_controller.get_topk_op(num_to_keep[sp])
        logger.info("Top-{} primitive: {}".format(num_to_keep[sp], topk_op))
        if sp == len(num_to_keep) - 1:
            _pdarts_log_final_genotype(pdarts_controller)
            _pdarts_restrict_skip_connections(pdarts_controller)
        else:
            basic_op = topk_op

    # Final report on the last stage's controller. Pruning has already run in
    # the last stage, so this pass presumably only re-logs the genotype; it is
    # kept to preserve the original output.
    _pdarts_log_final_genotype(pdarts_controller)
    _pdarts_restrict_skip_connections(pdarts_controller)


def _pdarts_log_final_genotype(controller):
    """Log ``controller``'s final genotype, then its architecture alphas."""
    logger.info("###############final Optimal genotype############")
    logger.info(controller.genotype(final=True))
    logger.info("########################################################")
    controller.print_alphas(logger)


def _pdarts_restrict_skip_connections(controller):
    """Cap the skip-connect count at 8, 7, ..., 0 in turn.

    For each cap that the controller currently exceeds, ``delete_skip`` is
    called until ``get_skip_number`` is within the cap. The cap and the
    resulting final genotype are then logged.
    """
    logger.info('Restricting skipconnect...')
    for max_sk in range(8, -1, -1):
        num_sk = controller.get_skip_number()
        if num_sk <= max_sk:
            continue
        while num_sk > max_sk:
            controller.delete_skip()
            num_sk = controller.get_skip_number()
        logger.info('Number of skip-connect: %d', max_sk)
        logger.info(controller.genotype(final=True))