def main():
    args = parse_args()
    ctx = get_contexts(args.ctx)
    data_dir, nclass = get_dataset_info(args.data)
    norm_layer, norm_kwargs = get_bn_layer(args.norm, ctx)
    model_kwargs = {
        'nclass': nclass,
        'backbone': args.backbone,
        'aux': args.aux,
        'base_size': args.base,
        'crop_size': args.crop,
        'norm_layer': norm_layer,
        'norm_kwargs': norm_kwargs,
        'dilate': args.dilate,
        'pretrained_base': False,
    }
    net = get_model_by_name(args.model, model_kwargs, args.checkpoint, ctx=ctx)
    EvalFactory.eval(net=net,
                     ctx=ctx,
                     data_name=args.data,
                     data_dir=data_dir,
                     mode=args.mode,
                     ms=args.ms,
                     nclass=nclass,
                     save_dir=args.save_dir)
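# --- Illustrative sketch, not from the source ---
# `get_contexts(args.ctx)` presumably maps a comma-separated device string to a
# list of MXNet contexts. A minimal sketch of that assumed behaviour; the real
# helper in the repository may differ.
import mxnet as mx

def get_contexts_sketch(ctx_str: str):
    """Parse e.g. '0,1' into [mx.gpu(0), mx.gpu(1)]; an empty string falls back to CPU."""
    if not ctx_str:
        return [mx.cpu()]
    return [mx.gpu(int(i)) for i in ctx_str.split(',') if i.strip()]

# Hypothetical usage:
# ctx = get_contexts_sketch('0,1')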
def get_model(cfg: dict, ctx: list):
    norm_layer, norm_kwargs = get_bn_layer(cfg.get('norm'), ctx)
    model_kwargs = {
        'nclass': get_dataset_info(cfg.get('data_name'))[1],
        'backbone': cfg.get('backbone'),
        'pretrained_base': cfg.get('backbone_pretrain'),
        'aux': cfg.get('aux'),
        'crop_size': cfg.get('crop_size'),
        'base_size': cfg.get('base_size'),
        'dilate': cfg.get('dilate'),
        'norm_layer': norm_layer,
        'norm_kwargs': norm_kwargs,
    }
    model = get_model_by_name(name=cfg.get('model_name'),
                              model_kwargs=model_kwargs,
                              resume=cfg.get('resume'),
                              lr_mult=cfg.get('lr_mult'),
                              ctx=ctx)
    model.hybridize()
    return model
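# --- Illustrative usage, not from the source ---
# get_model() takes a plain dict plus a context list, so a call site might look
# like the sketch below. `config` is the EasyDict from the config module and
# '0' is a sample device string; both are assumptions about how the pieces connect.
# ctx = get_contexts('0')
# net = get_model(dict(config), ctx)   # returns a hybridized Gluon model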
C = EasyDict()
config = C

# model name
C.model_name = 'fcn'
C.model_dir = weight_path(C.model_name)
C.record_dir = record_path(C.model_name)

# dataset:
#   COCO, VOC2012, VOCAug, SBD, PContext,
#   Cityscapes, CamVid, CamVidFull, Stanford, GATECH, KITTIZhang, KITTIXu, KITTIRos,
#   NYU, SiftFlow, SUNRGBD, ADE20K
C.data_name = 'Cityscapes'
C.crop_size = 768
C.base_size = 2048
C.data_path, C.nclass = get_dataset_info(C.data_name)

# network
C.backbone = 'resnet18'
C.backbone_pretrain = True
C.dilate = False
C.norm = 'sbn'
C.aux = False
C.aux_weight = .7 if C.aux else None
C.lr_mult = 10
C.resume = None
# import os
# C.resume = os.path.join(C.model_dir, '*.params')

# optimizer: sgd, nag, adam
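# --- Illustrative wiring, not from the source ---
# The fit() training loop reads its hyper-parameters from wandb.config, so this
# config object is presumably handed to wandb.init. The project name below is a
# placeholder, not taken from the repository.
import wandb

run = wandb.init(project='gluon-segmentation', config=dict(config))
print(run.config.model_name)   # -> 'fcn'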
def _validate_checkpoint(model_name, checkpoint):
    checkpoint = os.path.join(weight_path(model_name), checkpoint)
    if not os.path.isfile(checkpoint):
        raise RuntimeError(f"No model params found at {checkpoint}")
    return checkpoint


if __name__ == '__main__':
    """args"""
    args = parse_args()
    logger = get_logger()

    """context"""
    ctx = get_contexts(args.ctx)

    """ load model """
    logger.info(
        f"Loading model [{args.model}] for [{args.eval}] on {args.data} ...")
    data_path, nclass = get_dataset_info(args.data)
    norm_layer, norm_kwargs = get_bn_layer(args.norm, ctx)
    model_kwargs = {
        'nclass': nclass,
        'backbone': args.backbone,
        'aux': args.aux,
        'base_size': args.base,
        'crop_size': args.crop,
        'pretrained_base': False,
        'norm_layer': norm_layer,
        'norm_kwargs': norm_kwargs,
        'dilate': args.dilate,
    }
    model = get_model_by_name(args.model, ctx, model_kwargs)
    resume = _validate_checkpoint(args.model, args.resume)
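# --- Illustrative continuation, not from the source ---
# After _validate_checkpoint resolves the parameter file, the script presumably
# restores the weights before evaluation; load_parameters is the standard Gluon
# API for that. `model`, `resume`, and `ctx` are defined in the script above.
# model.load_parameters(resume, ctx=ctx)
# model.hybridize()
# logger.info(f"Loaded checkpoint from {resume}")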
def fit(run, ctx, log_interval=5, no_val=False, logger=None):
    net = FitFactory.get_model(wandb.config, ctx)
    train_iter, num_train = FitFactory.data_iter(
        wandb.config.data_name, wandb.config.bs_train,
        root=get_dataset_info(wandb.config.data_name)[0],
        split='train',  # sometimes would be 'trainval'
        mode='train',
        base_size=wandb.config.base_size,
        crop_size=wandb.config.crop_size)
    val_iter, num_valid = FitFactory.data_iter(
        wandb.config.data_name, wandb.config.bs_val,
        shuffle=False, last_batch='keep',
        root=get_dataset_info(wandb.config.data_name)[0],
        split='val',
        base_size=wandb.config.base_size,
        crop_size=wandb.config.crop_size)
    criterion = FitFactory.get_criterion(
        wandb.config.aux, wandb.config.aux_weight,
        # focal_kwargs={'alpha': 1.0, 'gamma': 0.5},
        # sensitive_kwargs={
        #     'nclass': get_dataset_info(wandb.config.data_name)[1],
        #     'alpha': 1.0,
        #     'gamma': 1.0}
    )
    trainer = FitFactory.create_trainer(net, wandb.config, iters_per_epoch=len(train_iter))
    metric = SegmentationMetric(nclass=get_dataset_info(wandb.config.data_name)[1])

    wandb.config.num_train = num_train
    wandb.config.num_valid = num_valid
    t_start = get_strftime()
    logger.info(f'Training start: {t_start}')
    for k, v in wandb.config.items():
        logger.info(f'{k}: {v}')
    logger.info('-----> end hyper-parameters <-----')
    wandb.config.start_time = get_strftime()

    best_score = .0
    for epoch in range(wandb.config.epochs):
        train_loss = .0
        tbar = tqdm(train_iter)
        for i, (data, target) in enumerate(tbar):
            gpu_datas = split_and_load(data, ctx_list=ctx)
            gpu_targets = split_and_load(target, ctx_list=ctx)
            with autograd.record():
                loss_gpus = [criterion(*net(gpu_data), gpu_target)
                             for gpu_data, gpu_target in zip(gpu_datas, gpu_targets)]
            for loss in loss_gpus:
                autograd.backward(loss)
            trainer.step(wandb.config.bs_train)
            nd.waitall()
            loss_temp = .0  # sum up all sample loss
            for loss in loss_gpus:
                loss_temp += loss.sum().asscalar()
            train_loss += (loss_temp / wandb.config.bs_train)
            tbar.set_description('Epoch %d, training loss %.5f'
                                 % (epoch, train_loss / (i + 1)))
            if (i % log_interval == 0) or (i + 1 == len(train_iter)):
                wandb.log({f'train_loss_batch, interval={log_interval}':
                           train_loss / (i + 1)})
        wandb.log({'train_loss_epoch': train_loss / (len(train_iter) + 1),
                   'custom_step': epoch})

        if not no_val:
            cudnn_auto_tune(False)
            val_loss = .0
            vbar = tqdm(val_iter)
            for i, (data, target) in enumerate(vbar):
                gpu_datas = split_and_load(data=data, ctx_list=ctx, even_split=False)
                gpu_targets = split_and_load(data=target, ctx_list=ctx, even_split=False)
                loss_temp = .0
                for gpu_data, gpu_target in zip(gpu_datas, gpu_targets):
                    loss_gpu = criterion(*net(gpu_data), gpu_target)
                    loss_temp += loss_gpu.sum().asscalar()
                    metric.update(gpu_target, net.evaluate(gpu_data))
                vbar.set_description('Epoch %d, val PA %.4f, mIoU %.4f'
                                     % (epoch, metric.get()[0], metric.get()[1]))
                val_loss += (loss_temp / wandb.config.bs_val)
            nd.waitall()
            pix_acc, mean_iou = metric.get()
            wandb.log({'val_PA': pix_acc,
                       'val_mIoU': mean_iou,
                       'val_loss': val_loss / (len(val_iter) + 1)})
            metric.reset()
            if mean_iou > best_score:
                save_checkpoint(model=net,
                                model_name=wandb.config.model_name.lower(),
                                backbone=wandb.config.backbone.lower(),
                                data_name=wandb.config.data_name.lower(),
                                time_stamp=wandb.config.start_time,
                                is_best=True)
                best_score = mean_iou
            cudnn_auto_tune(True)

    save_checkpoint(model=net,
                    model_name=wandb.config.model_name.lower(),
                    backbone=wandb.config.backbone.lower(),
                    data_name=wandb.config.data_name.lower(),
                    time_stamp=wandb.config.start_time,
                    is_best=False)
    run.finish()
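# --- Illustrative entry point, not from the source ---
# A sketch of how fit() might be driven; the wandb project name is a
# placeholder, and `config`, `get_contexts`, and `get_logger` are the helpers
# used elsewhere in these snippets.
# if __name__ == '__main__':
#     ctx = get_contexts('0,1')                       # sample GPU ids
#     run = wandb.init(project='gluon-segmentation',  # placeholder project name
#                      config=dict(config))
#     fit(run, ctx, log_interval=5, no_val=False, logger=get_logger())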