def regnety(name, pretrained=False, nc=1000):
    """Constructs a RegNetY model.

    Args:
        name: Model identifier; must appear in both _REGNETY_URLS and _REGNETY_CFGS.
        pretrained: If True, download cached pretrained weights and load them.
        nc: Number of output classes.

    Returns:
        An AnyNet instance configured as the requested RegNetY.
    """
    # Validate the name against both lookup tables before touching either.
    assert name in _REGNETY_URLS and name in _REGNETY_CFGS, \
        "RegNetY-{} not found in the model zoo.".format(name)
    # Build the network directly from the per-model spec.
    spec = _REGNETY_CFGS[name]
    net = AnyNet(
        stem_type="simple_stem_in",
        stem_w=32,
        block_type="res_bottleneck_block",
        ss=[2, 2, 2, 2],
        bms=[1.0, 1.0, 1.0, 1.0],
        se_r=0.25,
        nc=nc,
        ds=spec["ds"],
        ws=spec["ws"],
        gws=[spec["g"]] * 4,  # same group width for all four stages
    )
    # Optionally fetch and load the pretrained checkpoint.
    if pretrained:
        weights_url = os.path.join(_URL_PREFIX, _REGNETY_URLS[name])
        local_path = cache_url(weights_url, _DOWNLOAD_CACHE)
        checkpoint.load_checkpoint(local_path, net)
    return net
def resnext(name, pretrained=False, nc=1000):
    """Constructs a ResNeXt model.

    Args:
        name: Model identifier; must appear in both _RESNEXT_URLS and _RESNEXT_CFGS.
        pretrained: If True, download cached pretrained weights and load them.
        nc: Number of output classes.

    Returns:
        An AnyNet instance configured as the requested ResNeXt.
    """
    # Dict membership tests keys directly; no need for .keys().
    is_valid = name in _RESNEXT_URLS and name in _RESNEXT_CFGS
    # Fixed: the error message previously said "ResNet-{}" in this ResNeXt builder.
    assert is_valid, "ResNeXt-{} not found in the model zoo.".format(name)
    # Construct the model
    cfg = _RESNEXT_CFGS[name]
    kwargs = {
        "stem_type": "res_stem_in",
        "stem_w": 64,
        "block_type": "res_bottleneck_block",
        "ss": [1, 2, 2, 2],
        "bms": [0.5, 0.5, 0.5, 0.5],
        "se_r": None,  # ResNeXt uses no squeeze-and-excitation
        "nc": nc,
        "ds": cfg["ds"],
        "ws": [256, 512, 1024, 2048],
        "gws": [4, 8, 16, 32],
    }
    model = AnyNet(**kwargs)
    # Download and load the weights
    if pretrained:
        url = os.path.join(_URL_PREFIX, _RESNEXT_URLS[name])
        ws_path = cache_url(url, _DOWNLOAD_CACHE)
        checkpoint.load_checkpoint(ws_path, model)
    return model
def test_model():
    """Evaluates a trained model (semi-supervised variant).

    Loads weights from cfg.TEST.WEIGHTS, falls back to the last checkpoint,
    and aborts if neither is available. Writes top-1 error and the two
    cross-entropy error values to <OUT_DIR>/result.txt.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights: explicit weights take priority over auto-resume.
    if cfg.TEST.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
        logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    elif checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint.load_checkpoint(last_checkpoint, model)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
    else:
        print("ERROR: NO checkpoint! ")
        # Fixed: os._exit requires a status code; the bare os._exit() call
        # raised TypeError instead of terminating cleanly.
        os._exit(1)
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    # Evaluate the model. Both branches of the original cfg.TASK check called
    # test_epoch_semi identically, so the dead conditional was removed.
    result, ce_error = test_epoch_semi(test_loader, model, test_meter, 0)
    with open(os.path.join(cfg.OUT_DIR, 'result.txt'), 'w') as f:
        f.write(str(result["top1_err"]) + '\n')
        f.write(str(ce_error[0]) + '\n')
        f.write(str(ce_error[1]) + '\n')
    print(result["top1_err"], ce_error)
def train_kd_model():
    """Trains the model with knowledge distillation from a teacher network.

    Resumes from the last checkpoint when AUTO_RESUME is set, otherwise
    optionally starts from cfg.TRAIN.WEIGHTS. Saves a checkpoint every epoch,
    flagging whether the current epoch achieved the best top-1 error so far.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        file = cp.get_last_checkpoint()
        # load_checkpoint presumably returns the epoch stored in the
        # checkpoint -- TODO confirm against the cp helper's contract.
        epoch = cp.load_checkpoint(file, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(file))
        start_epoch = epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        # strict=False: allow a partial match (e.g. weights from a different head).
        cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model, strict=False)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    train_loader = data_loader.construct_train_loader()
    test_loader = data_loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute model and loader timings (only on a fresh run)
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    best_err = np.inf
    # Create the teacher model (built once, reused for every epoch)
    teacher = setup_teacher_model()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        params = (train_loader, model, loss_fun, optimizer, scaler, train_meter, teacher)
        train_kd_epoch(*params, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Evaluate the model
        test_epoch(test_loader, model, test_meter, cur_epoch)
        # Check if checkpoint is best so far (note: should checkpoint meters as well)
        stats = test_meter.get_epoch_stats(cur_epoch)
        best = stats["top1_err"] <= best_err
        best_err = min(stats["top1_err"], best_err)
        # Save a checkpoint (every epoch; `best` marks the best-so-far one)
        file = cp.save_checkpoint(model, optimizer, cur_epoch, best)
        logger.info("Wrote checkpoint to: {}".format(file))
def train_model():
    """Trains the model, reporting progress to NNI for hyper-parameter search.

    Each evaluated epoch's top-1 error goes to NNI as an intermediate result;
    the best top-1 error over the run is reported as the final result.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Compute model and loader timings (only on a fresh run)
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Optional early stop: `search_epoch` caps training during NAS/HPO runs.
        if hasattr(cfg, 'search_epoch'):
            if cur_epoch >= cfg.search_epoch:
                break
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            stats = test_epoch(test_loader, model, test_meter, cur_epoch)
            nni.report_intermediate_result(stats['top1_err'])
    # Report the best top-1 error seen across the run to NNI.
    nni.report_final_result(test_meter.min_top1_err)
def train_model():
    """Trains the model while maintaining an EMA (exponential moving average) copy.

    Both the live model and the EMA copy are evaluated every epoch; both
    top-1 errors are recorded in the checkpoint.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, ema, loss_fun, and optimizer
    model = setup_model()
    # EMA starts as an exact copy of the freshly built model.
    ema = deepcopy(model)
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        file = cp.get_last_checkpoint()
        # load_checkpoint returns a tuple here; element 0 is the stored epoch.
        epoch = cp.load_checkpoint(file, model, ema, optimizer)[0]
        logger.info("Loaded checkpoint from: {}".format(file))
        start_epoch = epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        train_weights = get_weights_file(cfg.TRAIN.WEIGHTS)
        cp.load_checkpoint(train_weights, model, ema)
        logger.info("Loaded initial weights from: {}".format(train_weights))
    # Create data loaders and meters
    train_loader = data_loader.construct_train_loader()
    test_loader = data_loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Separate meter so EMA stats are logged under "test_ema".
    ema_meter = meters.TestMeter(len(test_loader), "test_ema")
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute model and loader timings (only on a fresh run)
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch (train_epoch updates the EMA copy as well)
        params = (train_loader, model, ema, loss_fun, optimizer, scaler, train_meter)
        train_epoch(*params, cur_epoch)
        # Compute precise BN stats for both networks
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
            net.compute_precise_bn_stats(ema, train_loader)
        # Evaluate the model and its EMA copy
        test_epoch(test_loader, model, test_meter, cur_epoch)
        test_epoch(test_loader, ema, ema_meter, cur_epoch)
        test_err = test_meter.get_epoch_stats(cur_epoch)["top1_err"]
        ema_err = ema_meter.get_epoch_stats(cur_epoch)["top1_err"]
        # Save a checkpoint (records both error values)
        file = cp.save_checkpoint(model, ema, optimizer, cur_epoch, test_err, ema_err)
        logger.info("Wrote checkpoint to: {}".format(file))
def build_model(name, pretrained=False, cfg_list=()):
    """Constructs a predefined model (note: loads global config as well).

    Args:
        name: Name of the predefined config/model.
        pretrained: If True, load the matching pretrained weights.
        cfg_list: Optional list of config overrides merged after the file.

    Returns:
        The constructed model.
    """
    # Reset the global config, then populate it from the named preset
    # plus any caller-supplied overrides.
    reset_cfg()
    cfg.merge_from_file(get_config_file(name))
    cfg.merge_from_list(cfg_list)
    # Build the network from the now-loaded global config.
    net = builders.build_model()
    # Optionally load pretrained weights.
    if pretrained:
        cp.load_checkpoint(get_weights_file(name), net)
    return net
def test_model():
    """Evaluates a trained model.

    Loads weights from cfg.TEST.WEIGHTS and runs a single evaluation pass
    over the test set.
    """
    # Prepare the environment and build the network.
    setup_env()
    model = setup_model()
    # Restore the weights under test.
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Build the test loader and its meter, then run one evaluation epoch.
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    test_epoch(test_loader, model, test_meter, 0)
def train_model(): """Trains the model.""" # Setup training/testing environment setup_env() # Construct the model, loss_fun, and optimizer model = setup_model() loss_fun = builders.build_loss_fun().cuda() optimizer = optim.construct_optimizer(model) # Load checkpoint or initial weights start_epoch = 0 if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint(): last_checkpoint = checkpoint.get_last_checkpoint() checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer) logger.info("Loaded checkpoint from: {}".format(last_checkpoint)) start_epoch = checkpoint_epoch + 1 elif cfg.TRAIN.WEIGHTS: checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model) logger.info("Loaded initial weights from: {}".format( cfg.TRAIN.WEIGHTS)) # Compute precise time if start_epoch == 0 and cfg.PREC_TIME.ENABLED: logger.info("Computing precise time...") prec_time = net.compute_precise_time(model, loss_fun) logger.info(logging.dump_json_stats(prec_time)) net.reset_bn_stats(model) # Create data loaders and meters train_loader = loader.construct_train_loader() test_loader = loader.construct_test_loader() train_meter = meters.TrainMeter(len(train_loader)) test_meter = meters.TestMeter(len(test_loader)) # Perform the training loop logger.info("Start epoch: {}".format(start_epoch + 1)) for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): # Train for one epoch train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch) # Compute precise BN stats if cfg.BN.USE_PRECISE_STATS: net.compute_precise_bn_stats(model, train_loader) # Save a checkpoint if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0: checkpoint_file = checkpoint.save_checkpoint( model, optimizer, cur_epoch) logger.info("Wrote checkpoint to: {}".format(checkpoint_file)) # Evaluate the model next_epoch = cur_epoch + 1 if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH: test_epoch(test_loader, model, test_meter, cur_epoch)
def main(weights=None, replace=None, kd=False, lr=BASE_LEARINING_RATE,
         epochs=EPOCHS, cfg=None, batch=BATCH_SIZE, eval_only=False, out=None):
    """Trains a student model, optionally with knowledge distillation, or
    runs evaluation only.

    Args:
        weights: Path to student weights; falls back to WEIGHTS_FILE.
        replace: Forwarded to cp.load_checkpoint (presumably a key-renaming
            rule for the state dict -- confirm against the checkpoint helper).
        kd: If True, build a pretrained teacher and distill with an L2 loss.
        lr: Base learning rate.
        epochs: Number of training epochs.
        cfg: Student architecture name; defaults to the STUDENT constant.
        batch: Batch size for both loaders.
        eval_only: If True, only evaluate `weights` on the val split and return.
        out: Output destination forwarded to kd_train.
    """
    # Run only evaluation if requested
    if eval_only:
        print("Run Evaluation w/o Training")
        dataloader_test = create_dataloader(DATA_PATH, "val", batch)
        eval(weights, dataloader_test, replace, cfg)
        return
    # Build the student (and, for distillation, the teacher) model
    student_cfg = STUDENT if cfg is None else cfg
    teacher = None
    if 'EfficientNet' in student_cfg:
        print("Build student model ({})".format(student_cfg))
        student = effnet(student_cfg, pretrained=False).cuda()
        if kd:
            print("Build teacher model ({})".format(TEACHER))
            teacher = effnet(TEACHER, pretrained=True).cuda()
    else:
        print("Build student model ({})".format(student_cfg))
        student = regnety(student_cfg, pretrained=False).cuda()
        if kd:
            print("Build teacher model ({})".format(TEACHER))
            teacher = regnety(TEACHER, pretrained=True).cuda()
    # load students weights with the possible weights of the teacher
    weights = weights if weights is not None else WEIGHTS_FILE
    print("Load weights from: {}".format(weights))
    cp.load_checkpoint(weights, student, replace=replace)
    # Create data loaders
    print("Create dataloaders")
    dataloader_train = create_dataloader(DATA_PATH, "train", batch)
    dataloader_test = create_dataloader(DATA_PATH, "val", batch)
    # Select the loss. (PEP 8 E731: use a def, not a lambda assignment.)
    if kd:
        print("Create L2 Loss function")

        def loss_fn(x, y):
            # Sum of squared differences between student and teacher outputs.
            return torch.sum(torch.pow(x - y, 2))
    else:
        print("Create Cross Entropy Loss function")
        loss_fn = SoftCrossEntropyLoss().cuda()
    # run training and evaluation after training
    print("Start training")
    kd_train(teacher, student, loss_fn, dataloader_train, dataloader_test,
             kd=kd, lr=lr, epochs=epochs, batch=batch, out=out)
def test_model(): """Evaluates the model.""" # Setup logging logging.setup_logging() # Show the config logger.info("Config:\n{}".format(cfg)) # Fix the RNG seeds (see RNG comment in core/config.py for discussion) np.random.seed(cfg.RNG_SEED) torch.manual_seed(cfg.RNG_SEED) # Configure the CUDNN backend torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK # Build the model (before the loaders to speed up debugging) model = builders.build_model() logger.info("Model:\n{}".format(model)) logger.info(logging.dump_json_stats(net.complexity(model))) # Compute precise time if cfg.PREC_TIME.ENABLED: logger.info("Computing precise time...") loss_fun = builders.build_loss_fun() prec_time = net.compute_precise_time(model, loss_fun) logger.info(logging.dump_json_stats(prec_time)) net.reset_bn_stats(model) # Load model weights checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model) logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS)) # Create data loaders test_loader = loader.construct_test_loader() # Create meters test_meter = meters.TestMeter(len(test_loader)) # Evaluate the model test_epoch(test_loader, model, test_meter, 0)
def eval(model_weights, loader, replace=None, cfg=None):
    """Evaluates a student checkpoint on the given loader and prints top-1 accuracy.

    NOTE(review): this keeps the historical name `eval` (shadowing the
    builtin) because callers elsewhere in this file use it.
    """
    student_cfg = cfg if cfg is not None else STUDENT
    print("Start evaluation on {} with weights from {}...".format(student_cfg, model_weights))
    meter = meters.TestMeter(len(loader))
    # Build the matching architecture, then load the supplied weights.
    build = effnet if "EfficientNet" in student_cfg else regnety
    net = build(student_cfg, pretrained=False).cuda()
    cp.load_checkpoint(model_weights, net, replace=replace)
    net.eval()
    meter.reset()
    tic = time.time()
    total = len(loader)
    # 1-based step counter so the progress check is a plain modulo.
    for step, (inputs, labels) in enumerate(loader, start=1):
        inputs = inputs.cuda()
        labels = labels.cuda(non_blocking=True)
        preds = net(inputs)
        err1, err5 = meters.topk_errors(preds, labels, [1, 5])
        meter.update_stats(err1.item(), err5.item(), inputs.size(0))
        if step % 100 == 0:
            print("iter {}/{}".format(step, total))
    print("Total evaluation time: {}s".format(round(time.time() - tic)))
    print("**************************************")
    print("Top1 accuracy: {:.2f}".format(100 - meter.get_epoch_stats(0)["min_top1_err"]))
    print("**************************************")
def effnet(name, pretrained=False, nc=1000):
    """Constructs an EfficientNet model.

    Args:
        name: Model identifier; must appear in both _EN_URLS and _EN_CFGS.
        pretrained: If True, download cached pretrained weights and load them.
        nc: Number of output classes.

    Returns:
        An EffNet instance configured for the requested variant.
    """
    # Validate the name against both lookup tables before touching either.
    assert name in _EN_URLS and name in _EN_CFGS, \
        "EfficientNet-{} not found in the model zoo.".format(name)
    # Build the network directly from the per-model spec.
    spec = _EN_CFGS[name]
    net = EffNet(
        exp_rs=[1, 6, 6, 6, 6, 6, 6],
        se_r=0.25,
        nc=nc,
        ss=[1, 2, 2, 2, 1, 2, 1],
        ks=[3, 3, 5, 3, 5, 5, 3],
        stem_w=spec["sw"],
        ds=spec["ds"],
        ws=spec["ws"],
        head_w=spec["hw"],
    )
    # Optionally fetch and load the pretrained checkpoint.
    if pretrained:
        weights_url = os.path.join(_URL_PREFIX, _EN_URLS[name])
        local_path = cache_url(weights_url, _DOWNLOAD_CACHE)
        checkpoint.load_checkpoint(local_path, net)
    return net
def train_model():
    """Trains the model (variant with inline env setup and helper-driven
    checkpoint/eval scheduling via is_checkpoint_epoch / is_eval_epoch)."""
    # Setup logging
    logging.setup_logging()
    # Show the config
    logger.info("Config:\n{}".format(cfg))
    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
    # Build the model (before the loaders to speed up debugging)
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    logger.info(logging.dump_json_stats(net.complexity(model)))
    # Define the loss function
    loss_fun = builders.build_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Compute precise time (only on a fresh run; timing perturbs BN stats,
    # hence the reset afterwards)
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        prec_time = net.compute_precise_time(model, loss_fun)
        logger.info(logging.dump_json_stats(prec_time))
        net.reset_bn_stats(model)
    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Create meters
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if checkpoint.is_checkpoint_epoch(cur_epoch):
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        if is_eval_epoch(cur_epoch):
            test_epoch(test_loader, model, test_meter, cur_epoch)
def test_ftta_model(corruptions, levels):
    """Fine-tunes part of the model via test-time feedback and evaluates it
    over every (corruption level, corruption type) pair.

    Args:
        corruptions: Iterable of corruption type names.
        levels: Iterable of corruption severity levels.

    Returns:
        A list (one entry per level) of lists of per-corruption top-1 results.
    """
    all_results = []
    for corruption_level in levels:
        lvl_results = []
        for corruption_type in corruptions:
            # Point the global config at the current corruption for both
            # the adaptation (train) and evaluation (test) loaders.
            cfg.TRAIN.CORRUPTION = corruption_type
            cfg.TRAIN.LEVEL = corruption_level
            cfg.TEST.CORRUPTION = corruption_type
            cfg.TEST.LEVEL = corruption_level
            # Setup training/testing environment
            setup_env()
            # Construct the model, loss_fun, and optimizer (fresh per pair)
            model = setup_model()
            loss_fun = builders.build_loss_fun().cuda()
            optimizer = optim.construct_optimizer(model)
            # Load initial weights (no auto-resume in this flow)
            start_epoch = 0
            checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model, strict=cfg.TRAIN.LOAD_STRICT)
            logger.info("Loaded initial weights from: {}".format(
                cfg.TRAIN.WEIGHTS))
            # Create data loaders and meters
            train_loader = loader.construct_train_loader()
            test_loader = loader.construct_test_loader()
            train_meter = meters.TrainMeter(len(train_loader))
            test_meter = meters.TestMeter(len(test_loader))
            # Compute model and loader timings
            if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
                benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
            # Perform the training loop
            logger.info("Start epoch: {}".format(start_epoch + 1))
            for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
                if cfg.TRAIN.ADAPTATION != 'test_only':
                    if cfg.TRAIN.ADAPTATION == 'update_bn':
                        # BN-only adaptation: refresh running stats on corrupted data.
                        bn_update(model, train_loader)
                    elif cfg.TRAIN.ADAPTATION == 'min_entropy':
                        # Train for one epoch, then refresh BN stats.
                        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
                        bn_update(model, train_loader)
                # Save a checkpoint
                if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
                    checkpoint_file = checkpoint.save_checkpoint(
                        model, optimizer, cur_epoch)
                    logger.info(
                        "Wrote checkpoint to: {}".format(checkpoint_file))
                # Evaluate the model
                next_epoch = cur_epoch + 1
                if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
                    top1 = test_epoch(test_loader, model, test_meter, cur_epoch)
                    lvl_results.append(top1)
        all_results.append(lvl_results)
    # Summarize all (level, type) results in the log.
    for lvl_idx in range(len(all_results)):
        logger.info("corruption level: {}".format(levels[lvl_idx]))
        logger.info("corruption types: {}".format(corruptions))
        logger.info(all_results[lvl_idx])
    # show_parameters(model)
    return all_results
def train_model():
    """Trains the model (NAS variant).

    When cfg.MODEL.TYPE contains "search", the architecture parameters
    (named "alphas") get their own optimizer and the epoch is driven by
    search_epoch with a split train/val loader pair; otherwise this is a
    standard single-optimizer training loop.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    if "search" in cfg.MODEL.TYPE:
        # Separate weight parameters from architecture parameters ("alphas").
        params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
        params_a = [v for k, v in model.named_parameters() if "alphas" in k]
        # SGD for the network weights.
        optimizer_w = torch.optim.SGD(
            params=params_w,
            lr=cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV
        )
        # Architecture optimizer is configurable (Adam or SGD).
        if cfg.OPTIM.ARCH_OPTIM == "adam":
            optimizer_a = torch.optim.Adam(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                betas=(0.5, 0.999),
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
            )
        elif cfg.OPTIM.ARCH_OPTIM == "sgd":
            optimizer_a = torch.optim.SGD(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                momentum=cfg.OPTIM.MOMENTUM,
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
                dampening=cfg.OPTIM.DAMPENING,
                nesterov=cfg.OPTIM.NESTEROV
            )
        # NOTE(review): optimizer_a is unbound if ARCH_OPTIM is neither
        # "adam" nor "sgd" -- presumably validated upstream; confirm.
        optimizer = [optimizer_w, optimizer_a]
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters. With PORTION < 1 the training split is
    # divided: side="l" for weight updates, side="r" for the arch/val stream.
    if cfg.TRAIN.PORTION < 1:
        if "search" in cfg.MODEL.TYPE:
            train_loader = [loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            ), loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="r"
            )]
        else:
            train_loader = loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            )
        test_loader = loader._construct_loader(
            dataset_name=cfg.TRAIN.DATASET,
            split=cfg.TRAIN.SPLIT,
            batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
            shuffle=False,
            drop_last=False,
            portion=cfg.TRAIN.PORTION,
            side="r"
        )
    else:
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    # Segmentation tasks use IoU meters; classification uses error meters.
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    # When train_loader is a [l, r] pair, size the meter on the first loader.
    l = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(l))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        l = train_loader[0] if isinstance(train_loader, list) else train_loader
        benchmark.compute_time_full(model, loss_fun, l, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch (search_epoch handles the dual-loader case)
        f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
        f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
def test():
    """Evaluates a trained model and exports per-image scores.

    Writes, under a timestamped directory in cfg.OUT_DIR: a threshold plot,
    a CSV of per-image scores, and a pickle of [im_path, class_name, score]
    rows. Also feeds the outputs to hardmini for hard-example mining.

    Returns:
        The list of [im_path, class_name, score] rows.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders
    test_loader = loader.construct_test_loader()
    dataset = test_loader.dataset
    # Enable eval mode
    logs = []
    model.eval()
    for inputs, labels in test_loader:
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Normalize logits to scores: softmax or per-class sigmoid.
        if cfg.SOFTMAX:
            preds = F.softmax(preds, dim=1)
        else:
            preds = torch.sigmoid(preds)
        # Abnormal dataset format support: collapse one-hot labels to indices.
        if cfg.TRAIN.DATASET == "abnormal":
            labels = labels.argmax(dim=1)
        # Log [label, score-of-class-0, full score vector] per sample.
        for label, tail in zip(labels.tolist(), preds.tolist()):
            logs.append([label, tail[0], tail])
    # Pair the accumulated logs with image paths from the dataset index.
    # NOTE(review): relies on the loader preserving dataset order -- confirm
    # the test loader does not shuffle.
    imgs = [v["im_path"] for v in dataset._imdb]
    class_ids = dataset._class_ids
    assert len(imgs) == len(logs)
    lines = []
    outputs = []
    # CSV header: class list, row count, then column names.
    lines.append(":".join(class_ids))
    lines.append("{}".format(len(imgs)))
    lines.append("im_path,label,score,score_1_n")
    for im_path, (label, score, tail) in zip(imgs, logs):
        tail = ",".join(["{:.3f}".format(v) for v in tail])
        lines.append("{},{},{},{}".format(im_path, label, score, tail))
        outputs.append([im_path, class_ids[label], score])
    # Write all artifacts under a timestamped subdirectory of OUT_DIR.
    task_name = time.strftime("%m%d%H%M%S")
    os.makedirs(os.path.join(cfg.OUT_DIR, task_name))
    temp_file = "{}/threshold.png".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    score_thr = search_thr(logs, s1_thr=2, s2_thr=70, out_file=temp_file)
    temp_file = "{}/results.csv".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    with open(temp_file, "w") as f:
        f.write("\n".join(lines))
    print(temp_file)
    temp_file = "{}/results.pkl".format(task_name)
    temp_file = os.path.join(cfg.OUT_DIR, temp_file)
    with open(temp_file, "wb") as f:
        pickle.dump(outputs, f)
    print(temp_file)
    # Hard-example mining on the exported rows.
    hardmini(outputs, class_ids, task_name, score_thr)
    return outputs
def train_model():
    """Trains the model (variant with DALI-backed ImageNet loaders).

    For the imagenet_dataset case, a single constructed object exposes both
    train and val loaders, whose DALI iterators must be reset every epoch.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters. ImageNet uses a combined dataset object
    # carrying both loaders; other datasets build them separately.
    if cfg.TEST.DATASET == 'imagenet_dataset' or cfg.TRAIN.DATASET == 'imagenet_dataset':
        dataset = loader.construct_train_loader()
        train_loader = dataset.train_loader
        test_loader = dataset.val_loader
    else:
        dataset = None
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(test_loader, model, test_meter, cur_epoch)
        # DALI iterators must be reset after each full pass over the data.
        if dataset is not None:
            logger.info("Reset the dataset")
            train_loader._dali_iterator.reset()
            test_loader._dali_iterator.reset()
        # clear memory
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            gc.collect()