def train_model():
    """Run the complete training schedule for the configured model.

    Builds the model, loss function and optimizer, then either resumes from
    the most recent checkpoint (when ``cfg.TRAIN.AUTO_RESUME`` is set and a
    checkpoint exists) or loads the initial weights named by
    ``cfg.TRAIN.WEIGHTS``. Training then runs from the resulting start epoch
    up to ``cfg.OPTIM.MAX_EPOCH``; checkpoints are written on the epochs
    selected by ``cu.is_checkpoint_epoch`` and evaluation runs on the epochs
    selected by ``is_eval_epoch``.
    """
    # The model is created ahead of the data loaders to speed up debugging.
    net = model_builder.build_model()
    log_model_info(net)

    criterion = losses.get_loss_fun()
    opt = optim.construct_optimizer(net)

    # Work out where to start: resume, warm-start from weights, or scratch.
    first_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        resume_path = cu.get_last_checkpoint()
        resumed_epoch = cu.load_checkpoint(resume_path, net, opt)
        logger.info("Loaded checkpoint from: {}".format(resume_path))
        # Continue with the epoch after the one stored in the checkpoint.
        first_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        # Weights only: the optimizer is not passed to the loader here.
        cu.load_checkpoint(cfg.TRAIN.WEIGHTS, net)
        logger.info(
            "Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS)
        )

    # Precise timing is only measured when training starts at epoch 0.
    if first_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        bu.compute_precise_time(net, criterion)
        nu.reset_bn_stats(net)

    # Data loaders and the meters sized from them.
    train_data = loader.construct_train_loader()
    test_data = loader.construct_test_loader()
    train_stats = TrainMeter(len(train_data))
    test_stats = TestMeter(len(test_data))

    # Epochs are 0-based internally; the log line reports them 1-based.
    logger.info("Start epoch: {}".format(first_epoch + 1))
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        train_epoch(train_data, net, criterion, opt, train_stats, epoch)

        if cfg.BN.USE_PRECISE_STATS:
            nu.compute_precise_bn_stats(net, train_data)

        if cu.is_checkpoint_epoch(epoch):
            saved_to = cu.save_checkpoint(net, opt, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_to))

        if is_eval_epoch(epoch):
            test_epoch(test_data, net, test_stats, epoch)
def ensemble_train_model(train_loader, val_loader, model, optimizer, cfg):
    """Train ``model`` and checkpoint the state with the best validation accuracy.

    Runs ``cfg.OPTIM.MAX_EPOCH`` epochs starting from epoch 0. On every
    evaluation epoch the validation error returned by ``test_epoch`` is
    converted to an accuracy (``100 - error``); whenever that accuracy beats
    the best seen so far, a copy of the model and optimizer state is kept.
    After the last epoch the best state is written through
    ``cu.save_checkpoint`` and the loss / accuracy curves are plotted.

    Side effects: appends to the module-level ``plot_epoch_xvalues`` /
    ``plot_epoch_yvalues`` lists while training and resets those two lists and
    ``plot_it_x_values`` / ``plot_it_y_values`` to empty lists before
    returning.

    Args:
        train_loader: Loader passed to ``train_epoch`` (and to the precise-BN
            computation when enabled).
        val_loader: Loader passed to ``test_epoch`` for validation.
        model: Model to train. When ``cfg.NUM_GPUS > 1`` its ``.module``
            attribute is used to obtain the state dict.
        optimizer: Optimizer whose state is checkpointed with the model.
        cfg: Configuration node supplying ``OPTIM``, ``TRAIN``, ``BN``,
            ``DATASET``, ``NUM_GPUS`` and ``EPISODE_DIR``.

    Returns:
        Tuple ``(best_val_acc, best_val_epoch, checkpoint_file)`` where
        ``best_val_epoch`` is 1-based (0 if no epoch ever improved on the
        initial accuracy of 0.0) and ``checkpoint_file`` is whatever
        ``cu.save_checkpoint`` returned.
    """
    global plot_epoch_xvalues
    global plot_epoch_yvalues
    global plot_it_x_values
    global plot_it_y_values

    # Local import so this block does not rely on a file-level ``import copy``.
    import copy

    def _snapshot_states():
        # ``state_dict()`` returns references to the live tensors, so the
        # snapshot must be deep-copied; otherwise later training steps would
        # silently overwrite the "best" state with the most recent weights.
        net = model.module if cfg.NUM_GPUS > 1 else model
        return (
            copy.deepcopy(net.state_dict()),
            copy.deepcopy(optimizer.state_dict()),
        )

    start_epoch = 0
    loss_fun = losses.get_loss_fun()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    val_meter = ValMeter(len(val_loader))

    # Perform the training loop
    logger.info('Start epoch: {}'.format(start_epoch + 1))
    val_set_acc = 0.

    temp_best_val_acc = 0.
    temp_best_val_epoch = 0

    # Best checkpoint model and optimizer states
    best_model_state = None
    best_opt_state = None

    val_acc_epochs_x = []
    val_acc_epochs_y = []

    # NOTE(review): this divides len(train_loader) by the batch size; if
    # len(train_loader) is already a batch count the total is underestimated.
    # Left unchanged because train_epoch consumes these values - confirm.
    clf_train_iterations = cfg.OPTIM.MAX_EPOCH * int(
        len(train_loader) / cfg.TRAIN.BATCH_SIZE)
    clf_change_lr_iter = clf_train_iterations // 25
    clf_iter_count = 0

    plot_names = [
        "plot_epoch_xvalues", "plot_epoch_yvalues", "plot_it_x_values",
        "plot_it_y_values", "val_acc_epochs_x", "val_acc_epochs_y",
    ]

    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_loss, clf_iter_count = train_epoch(
            train_loader, model, loss_fun, optimizer, train_meter,
            cur_epoch, cfg, clf_iter_count, clf_change_lr_iter,
            clf_train_iterations)

        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            nu.compute_precise_bn_stats(model, train_loader)

        # Model evaluation
        if is_eval_epoch(cur_epoch):
            # Original code[PYCLS] passes on testLoader but we want to
            # compute on val Set
            val_set_err = test_epoch(val_loader, model, val_meter, cur_epoch)
            val_set_acc = 100. - val_set_err

            if temp_best_val_acc < val_set_acc:
                temp_best_val_acc = val_set_acc
                temp_best_val_epoch = cur_epoch + 1

                # Save best model and optimizer state for checkpointing
                model.eval()
                best_model_state, best_opt_state = _snapshot_states()
                model.train()

            # Since we start from 0 epoch
            val_acc_epochs_x.append(cur_epoch + 1)
            val_acc_epochs_y.append(val_set_acc)

        plot_epoch_xvalues.append(cur_epoch + 1)
        plot_epoch_yvalues.append(train_loss)

        save_plot_values(
            [plot_epoch_xvalues, plot_epoch_yvalues, plot_it_x_values,
             plot_it_y_values, val_acc_epochs_x, val_acc_epochs_y],
            plot_names, out_dir=cfg.EPISODE_DIR, isDebug=False)
        logger.info("Successfully logged numpy arrays!!")

        # Plot arrays
        plot_arrays(x_vals=plot_epoch_xvalues, y_vals=plot_epoch_yvalues,
                    x_name="Epochs", y_name="Loss",
                    dataset_name=cfg.DATASET.NAME, out_dir=cfg.EPISODE_DIR)
        plot_arrays(x_vals=val_acc_epochs_x, y_vals=val_acc_epochs_y,
                    x_name="Epochs", y_name="Validation Accuracy",
                    dataset_name=cfg.DATASET.NAME, out_dir=cfg.EPISODE_DIR)

        # Second save uses save_plot_values' default for ``isDebug``.
        save_plot_values(
            [plot_epoch_xvalues, plot_epoch_yvalues, plot_it_x_values,
             plot_it_y_values, val_acc_epochs_x, val_acc_epochs_y],
            plot_names, out_dir=cfg.EPISODE_DIR)

        print('Training Epoch: {}/{}\tTrain Loss: {}\tVal Accuracy: {}'.format(
            cur_epoch + 1, cfg.OPTIM.MAX_EPOCH, round(train_loss, 4),
            round(val_set_acc, 4)))

    if best_model_state is None:
        # No evaluation epoch ever improved on 0.0 accuracy (or none ran), so
        # there is no "best" snapshot. Fall back to the final weights instead
        # of writing a checkpoint that contains ``None`` states.
        logger.warning('No best validation state was recorded; '
                       'checkpointing the final model state instead.')
        best_model_state, best_opt_state = _snapshot_states()

    # Save the best model checkpoint (Episode level)
    checkpoint_file = cu.save_checkpoint(
        info="vlBest_acc_" + str(int(temp_best_val_acc)),
        model_state=best_model_state, optimizer_state=best_opt_state,
        epoch=temp_best_val_epoch, cfg=cfg)

    print('\nWrote Best Model Checkpoint to: {}\n'.format(
        checkpoint_file.split('/')[-1]))
    logger.info('Wrote Best Model Checkpoint to: {}\n'.format(checkpoint_file))

    plot_arrays(x_vals=plot_epoch_xvalues, y_vals=plot_epoch_yvalues,
                x_name="Epochs", y_name="Loss",
                dataset_name=cfg.DATASET.NAME, out_dir=cfg.EPISODE_DIR)
    plot_arrays(x_vals=plot_it_x_values, y_vals=plot_it_y_values,
                x_name="Iterations", y_name="Loss",
                dataset_name=cfg.DATASET.NAME, out_dir=cfg.EPISODE_DIR)
    plot_arrays(x_vals=val_acc_epochs_x, y_vals=val_acc_epochs_y,
                x_name="Epochs", y_name="Validation Accuracy",
                dataset_name=cfg.DATASET.NAME, out_dir=cfg.EPISODE_DIR)

    # Reset the module-level plot buffers so the next call starts clean.
    plot_epoch_xvalues = []
    plot_epoch_yvalues = []
    plot_it_x_values = []
    plot_it_y_values = []

    best_val_acc = temp_best_val_acc
    best_val_epoch = temp_best_val_epoch

    return best_val_acc, best_val_epoch, checkpoint_file
def _auto_match_target():
    """Look up the auto-match target for the current config.

    Returns:
        Tuple ``(pre_repeat, stats_baseline)``. Either entry is ``None`` when
        the dataset / model combination has no hard-coded baseline.
    """
    pre_repeat = None
    stats_baseline = None
    if cfg.TRAIN.DATASET == 'cifar10':
        pre_repeat = 15
        if cfg.MODEL.TYPE == 'resnet':  # ResNet20
            stats_baseline = 40813184
        elif cfg.MODEL.TYPE == 'mlpnet':
            # 5-layer MLP. cfg.MODEL.LAYERS exclude stem and head layers
            if cfg.MODEL.LAYERS == 3:
                if cfg.RGRAPH.DIM_LIST[0] == 256:
                    stats_baseline = 985600
                elif cfg.RGRAPH.DIM_LIST[0] == 512:
                    stats_baseline = 2364416
                elif cfg.RGRAPH.DIM_LIST[0] == 1024:
                    stats_baseline = 6301696
        elif cfg.MODEL.TYPE == 'cnn':
            if cfg.MODEL.LAYERS == 3:
                if cfg.RGRAPH.DIM_LIST[0] == 512:
                    stats_baseline = 806884352
                elif cfg.RGRAPH.DIM_LIST[0] == 16:
                    stats_baseline = 1216672
            elif cfg.MODEL.LAYERS == 6:
                if '64d' in cfg.OUT_DIR:
                    stats_baseline = 48957952
                elif '16d' in cfg.OUT_DIR:
                    stats_baseline = 3392128
    elif cfg.TRAIN.DATASET == 'imagenet':
        pre_repeat = 9
        if cfg.MODEL.TYPE == 'resnet':
            if 'basic' in cfg.RESNET.TRANS_FUN:  # ResNet34
                stats_baseline = 3663761408
            elif 'sep' in cfg.RESNET.TRANS_FUN:  # ResNet34-sep
                stats_baseline = 553614592
            elif 'bottleneck' in cfg.RESNET.TRANS_FUN:  # ResNet50
                stats_baseline = 4089184256
        elif cfg.MODEL.TYPE == 'efficientnet':  # EfficientNet
            stats_baseline = 385824092
        elif cfg.MODEL.TYPE == 'cnn':  # CNN
            if cfg.MODEL.LAYERS == 6:
                if '64d' in cfg.OUT_DIR:
                    stats_baseline = 166438912
    return pre_repeat, stats_baseline


def _set_first_dim(new_first):
    """Set ``cfg.RGRAPH.DIM_LIST[0]`` to ``new_first``, keeping stage ratios.

    The ratio of every entry to the current first entry is computed, then the
    whole list is rebuilt as ``int(round(new_first * ratio))``.
    """
    first = cfg.RGRAPH.DIM_LIST[0]
    ratio_list = [dim / first for dim in cfg.RGRAPH.DIM_LIST]
    cfg.RGRAPH.DIM_LIST = [int(round(new_first * ratio)) for ratio in ratio_list]


def _auto_match_dims(mode='flops'):
    """Adjust ``cfg.RGRAPH.DIM_LIST`` so the model stats match the baseline.

    Args:
        mode: Statistic handed to ``model_builder.build_model_stats``
            ('flops' or 'params').

    Raises:
        ValueError: If the current dataset / model configuration has no
            hard-coded baseline to match against.
    """
    pre_repeat, stats_baseline = _auto_match_target()
    if pre_repeat is None or stats_baseline is None:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError(
            'TRAIN.AUTO_MATCH has no stats baseline for dataset {!r} with '
            'model type {!r} and the current layer/dim settings'.format(
                cfg.TRAIN.DATASET, cfg.MODEL.TYPE))

    cfg.defrost()
    stats = model_builder.build_model_stats(mode)
    if stats == stats_baseline:
        return

    # NOTE(review): the nesting of everything below under the "stats differ"
    # case (including round 2 and the FINAL print) was reconstructed from a
    # whitespace-flattened source - confirm against the formatted file.

    # 1st round: set first stage dim
    for _ in range(pre_repeat):
        # Stats are treated as growing with the square of the width here,
        # hence the square root when deriving the scale factor.
        scale = round(math.sqrt(stats_baseline / stats), 2)
        _set_first_dim(int(round(cfg.RGRAPH.DIM_LIST[0] * scale)))
        stats = model_builder.build_model_stats(mode)

    # Walk the first dim one step at a time towards the baseline until the
    # stats hit it exactly or cross over to the other side.
    flag_init = 1 if stats < stats_baseline else -1
    step = 1
    while True:
        _set_first_dim(cfg.RGRAPH.DIM_LIST[0] + flag_init * step)
        stats = model_builder.build_model_stats(mode)
        if stats == stats_baseline:
            break
        flag = 1 if stats < stats_baseline else -1
        if flag != flag_init:
            if cfg.RGRAPH.UPPER == False:  # noqa: E712 - kept as written
                # make sure the stats is SMALLER than baseline
                wrong_side = flag < 0
            else:
                # UPPER set: undo the step when the stats fell below baseline
                wrong_side = flag > 0
            if wrong_side:
                _set_first_dim(cfg.RGRAPH.DIM_LIST[0] - flag_init * step)
            break

    # 2nd round: set other stage dim
    first = cfg.RGRAPH.DIM_LIST[0]
    ratio_list = [int(round(dim / first)) for dim in cfg.RGRAPH.DIM_LIST]
    stats = model_builder.build_model_stats(mode)
    flag_init = 1 if stats < stats_baseline else -1
    if 'share' not in cfg.RESNET.TRANS_FUN:
        for i in range(1, len(cfg.RGRAPH.DIM_LIST)):
            # Each later stage may move by at most its (rounded) ratio to the
            # first stage, one unit at a time, stopping once the sign flips.
            for _ in range(ratio_list[i]):
                cfg.RGRAPH.DIM_LIST[i] += flag_init
                stats = model_builder.build_model_stats(mode)
                flag = 1 if stats < stats_baseline else -1
                if flag_init != flag:
                    cfg.RGRAPH.DIM_LIST[i] -= flag_init
                    break
    stats = model_builder.build_model_stats(mode)
    print('FINAL', cfg.RGRAPH.GROUP_NUM, cfg.RGRAPH.DIM_LIST, stats,
          stats_baseline, stats < stats_baseline)


def train_model(writer_train=None, writer_eval=None, is_master=False):
    """Trains the model.

    When ``cfg.TRAIN.AUTO_MATCH`` is set and ``cfg.RGRAPH.SEED_TRAIN`` equals
    ``cfg.RGRAPH.SEED_TRAIN_START``, the graph dimensions are first tuned so
    the model's FLOPs match a hard-coded baseline. The model is then built,
    optionally resumed from the last checkpoint, evaluated once before any
    training, and trained up to ``cfg.OPTIM.MAX_EPOCH`` with periodic
    checkpointing and evaluation.

    Args:
        writer_train: Forwarded to ``train_epoch``.
        writer_eval: Forwarded to ``log_model_info`` and ``eval_epoch``.
        is_master: Forwarded to ``train_epoch`` / ``eval_epoch``.

    Raises:
        ValueError: If auto-matching is requested for a configuration without
            a known baseline.
        SystemExit: If the loaded checkpoint epoch equals
            ``cfg.OPTIM.MAX_EPOCH``.
    """
    # Local import so this block does not rely on a file-level ``import sys``.
    import sys

    # Fit flops/params
    if cfg.TRAIN.AUTO_MATCH and cfg.RGRAPH.SEED_TRAIN == cfg.RGRAPH.SEED_TRAIN_START:
        _auto_match_dims('flops')  # flops or params

    # Build the model (before the loaders to ease debugging)
    model = model_builder.build_model()
    params, flops = log_model_info(model, writer_eval)

    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load a checkpoint if applicable
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_checkpoint_last()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info('Loaded checkpoint from: {}'.format(last_checkpoint))
        if checkpoint_epoch == cfg.OPTIM.MAX_EPOCH:
            # Same early exit as before, via sys.exit() rather than the
            # interactive-only exit() helper. The assignment that used to
            # follow the exit call could never run and was dropped.
            sys.exit()
        start_epoch = checkpoint_epoch + 1

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))

    if cfg.ONLINE_FLOPS:
        # Measure on a throwaway copy of the model at a fixed 224x224 size.
        model_dummy = model_builder.build_model()
        IMAGE_SIZE = 224
        n_flops, n_params = mu.measure_model(model_dummy, IMAGE_SIZE, IMAGE_SIZE)
        logger.info('FLOPs: %.2fM, Params: %.2fM' %
                    (n_flops / 1e6, n_params / 1e6))
        del model_dummy

    # Perform the training loop
    logger.info('Start epoch: {}'.format(start_epoch + 1))

    # do eval at initialization
    eval_epoch(test_loader, model, test_meter, -1, writer_eval, params, flops,
               is_master=is_master)

    if start_epoch == cfg.OPTIM.MAX_EPOCH:
        # Resumed from the final epoch's checkpoint: nothing left to train,
        # so only re-run evaluation for that last epoch.
        cur_epoch = start_epoch - 1
        eval_epoch(test_loader, model, test_meter, cur_epoch, writer_eval,
                   params, flops, is_master=is_master)
    else:
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            # Train for one epoch
            train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                        cur_epoch, writer_train, is_master=is_master)
            # Compute precise BN stats
            if cfg.BN.USE_PRECISE_STATS:
                nu.compute_precise_bn_stats(model, train_loader)
            # Save a checkpoint
            if cu.is_checkpoint_epoch(cur_epoch):
                checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
                logger.info('Wrote checkpoint to: {}'.format(checkpoint_file))
            # Evaluate the model
            if is_eval_epoch(cur_epoch):
                eval_epoch(test_loader, model, test_meter, cur_epoch,
                           writer_eval, params, flops, is_master=is_master)