def main():
    """Build, then train or evaluate, the network(s) listed in ``config.arch_idx``.

    For every architecture index the searched parameters are loaded from
    ``config.load_path``, the output branch combination with the better
    ``objective_acc_lat`` score is selected, the structure is plotted and
    profiled, and an evaluator/tester pair is created.  ``arch_idx == 0`` is
    reported as the "teacher", every other index as the "student".

    With ``config.is_eval`` set, a single validation pass is run and the
    process exits with status 0.  Otherwise the models are trained for
    ``config.nepochs`` epochs with periodic validation (or testing when
    ``config.is_test`` is set), and the weights are written to
    ``config.save`` at the end of every epoch.

    Raises:
        RuntimeError: if training is requested but no entry of
            ``config.arch_idx`` produced an optimizer.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Log both to stdout and to <save dir>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    # NOTE(review): KLDivLoss is built with its default reduction; confirm
    # whether reduction='batchmean' is what the distillation loss intends.
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    # Only the training source differs between test mode and normal mode.
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': (config.train_eval_source
                         if config.is_test else config.train_source),
        'eval_source': config.eval_source,
        'test_source': config.test_source,
        'down_sampling': config.down_sampling,
    }
    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    base_lr = config.lr
    optimizer = None
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            arch_file = "arch_%d.pt" % arch_idx
        else:
            arch_file = "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))
        state = torch.load(os.path.join(config.load_path, arch_file))

        alphas = [
            state["alpha_%d_%d" % (arch_idx, b)].detach() for b in range(3)
        ]
        # There is no beta entry for branch 0.
        betas = [None] + [
            state["beta_%d_%d" % (arch_idx, b)].detach() for b in (1, 2)
        ]
        ratios = [
            state["ratio_%d_%d" % (arch_idx, b)].detach() for b in range(3)
        ]
        model = Network(alphas,
                        betas,
                        ratios,
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Keep whichever branch pair scored better during the search.
        obj02 = objective_acc_lat(state["mIoU02"], state["latency02"])
        obj12 = objective_acc_lat(state["mIoU12"], state["latency12"])
        last = [2, 0] if obj02 > obj12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        for b in last:
            plot_kwargs = {"F_base": config.Fch}
            if len(config.width_mult_list) > 1:
                plot_kwargs["width"] = getattr(model, "widths%d" % b)
                plot_kwargs["head_width"] = config.stem_head_width[idx][1]
            figure = plot_op(getattr(model, "ops%d" % b),
                             getattr(model, "path%d" % b), **plot_kwargs)
            figure.savefig(os.path.join(config.save,
                                        "ops_%d_%d.png" % (arch_idx, b)),
                           bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width(
            [2, 1, 0], [model.path2, model.path1, model.path0],
            [model.widths2, model.widths1, model.widths0]).savefig(
                os.path.join(config.save, "path_width_all%d.png" % arch_idx))

        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        # Optionally overwrite the fresh initialisation with saved weights:
        # the teacher's weights when a student is trained alongside it,
        # otherwise the evaluation weights when running in eval mode.
        if arch_idx == 0 and len(config.arch_idx) > 1:
            weights_dir = config.teacher_path
        elif config.is_eval:
            weights_dir = config.eval_path
        else:
            weights_dir = None
        if weights_dir is not None:
            partial = torch.load(
                os.path.join(weights_dir, "weights%d.pt" % arch_idx))
            model_state = model.state_dict()
            # Ignore saved entries the current model does not have.
            model_state.update(
                {k: v for k, v in partial.items() if k in model_state})
            model.load_state_dict(model_state)

        evaluators.append(
            SegEvaluator(Cityscapes(data_setting, 'val', None),
                         config.num_classes,
                         config.image_mean,
                         config.image_std,
                         model,
                         config.eval_scale_array,
                         config.eval_flip,
                         0,
                         out_idx=0,
                         config=config,
                         verbose=False,
                         save_path=None,
                         show_image=False))
        testers.append(
            SegTester(Cityscapes(data_setting, 'test', None),
                      config.num_classes,
                      config.image_mean,
                      config.image_std,
                      model,
                      config.eval_scale_array,
                      config.eval_flip,
                      0,
                      out_idx=0,
                      config=config,
                      verbose=False,
                      save_path=None,
                      show_image=False))

        # Optimizer ###################################
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    def _report_mious(scalar_split, log_split, mious, step):
        # Write one scalar and one log line per architecture.
        for pos, arch in enumerate(config.arch_idx):
            role = "teacher" if arch == 0 else "student"
            logger.add_scalar("mIoU/%s_%s" % (scalar_split, role), mious[pos],
                              step)
            logging.info("%s's %s_mIoU %.3f" % (role, log_split, mious[pos]))

    def _save_all():
        # Persist the current weights of every model.
        for pos, arch in enumerate(config.arch_idx):
            save(models[pos], os.path.join(config.save,
                                           "weights%d.pt" % arch))

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            _report_mious("val", "valid", valid_mIoUs, 0)
        # sys.exit instead of the interactive-only ``exit`` builtin.
        sys.exit(0)

    if optimizer is None:
        raise RuntimeError(
            "no optimizer was created: config.arch_idx must contain 1 or "
            "hold exactly one entry (got %r)" % (config.arch_idx, ))

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        _report_mious("train", "train", train_mIoUs, epoch)
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation: on the first epoch and then every 10th epoch
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                _report_mious("val", "valid", valid_mIoUs, epoch)
                _save_all()

        # test: every 10th epoch from epoch 250 onwards
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        _save_all()
with Engine(custom_parser=parser) as engine: args = parser.parse_args() cudnn.benchmark = True if engine.distributed: torch.cuda.set_device(engine.local_rank) # data loader train_loader, train_sampler = get_train_loader(engine, Cityscapes) # config network and criterion min_kept = int(config.batch_size // len(engine.devices) * config.image_height * config.image_width // 64) criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7, min_kept=min_kept, use_weight=False) if engine.distributed: logger.info('Use the Multi-Process-SyncBatchNorm') BatchNorm2d = SyncBatchNorm # else: # BatchNorm2d = BatchNorm2d model = CPNet(config.num_classes, criterion=criterion, pretrained_model=config.pretrained_model, norm_layer=BatchNorm2d) init_weight(model.business_layer, nn.init.kaiming_normal_, BatchNorm2d, config.bn_eps,
def main(pretrain=True):
    """Pretrain the supernet or run the architecture search.

    Args:
        pretrain: ``True`` trains the supernet weights only; architecture
            parameters are not updated and validation is run per pruning
            mode.  A ``str`` is treated as a directory containing
            ``weights.pt``; matching tensors are loaded from it, the
            architecture is updated during training, and per-architecture
            parameters plus metrics are written to ``arch_<idx>.pt`` /
            ``arch_<idx>_<epoch>.pt`` after every epoch.  ``False`` updates
            the architecture starting from freshly initialised weights and
            does not write the ``arch_*.pt`` files.

    Raises:
        TypeError: if ``pretrain`` is neither a ``bool`` nor a ``str``.
    """
    # Validate before touching config or the filesystem.
    if not isinstance(pretrain, (bool, str)):
        raise TypeError("pretrain must be a bool or a str, got %s" %
                        type(pretrain).__name__)
    update_arch = pretrain is not True

    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Log both to stdout and to <save dir>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes,
                    config.layers,
                    ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()

    if isinstance(pretrain, str):
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        # Only take tensors that exist in the model with an identical shape.
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model,
                    nn.init.kaiming_normal_,
                    nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

    architect = Architect(model, config)

    # Optimizer ###################################
    # Only these sub-modules are handed to the weight optimizer.
    base_lr = config.lr
    weight_modules = (model.stem, model.cells, model.refine32,
                      model.refine16, model.head0, model.head1, model.head2,
                      model.head02, model.head12)
    parameters = []
    for module in weight_modules:
        parameters += list(module.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling,
    }
    train_loader_model = get_train_loader(config,
                                          EGTEA,
                                          portion=config.train_portion)
    # NOTE(review): presumably "portion - 1" selects the complementary part
    # of the training set for the architecture step - confirm in
    # get_train_loader.
    train_loader_arch = get_train_loader(config,
                                         EGTEA,
                                         portion=config.train_portion - 1)

    evaluator = SegEvaluator(EGTEA(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    # These histories are appended to below but not read in this function.
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}

    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              ohem_criterion,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain is True:
                # "max" and "random" are only evaluated when more than one
                # width multiplier is configured.
                prun_modes = ["min"]
                if len(model._width_mult_list) > 1:
                    prun_modes += ["max", "random"]
                for mode in prun_modes:
                    model.prun_mode = mode
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i, pred_name in enumerate(valid_names):
                        logger.add_scalar(
                            'mIoU/val_%s_%s' % (mode, pred_name),
                            valid_mIoUs[i], epoch)
                        logging.info(
                            "Epoch %d: valid_mIoU_%s_%s %.3f" %
                            (epoch, mode, pred_name, valid_mIoUs[i]))
            else:
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i, pred_name in enumerate(valid_names):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' % (arch_names[idx], pred_name),
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_%s_%s %.3f" %
                                     (epoch, arch_names[idx], pred_name,
                                      valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        # Latency is passed as 1000 / FPS.
                        obj02 = objective_acc_lat(valid_mIoUs[3],
                                                  1000. / fps0)
                        obj12 = objective_acc_lat(valid_mIoUs[4],
                                                  1000. / fps1)
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            obj02, epoch)
                        logging.info("Epoch %d: Objective_%s_8s_32s %.3f" %
                                     (epoch, arch_names[idx], obj02))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            obj12, epoch)
                        logging.info("Epoch %d: Objective_%s_16s_32s %.3f" %
                                     (epoch, arch_names[idx], obj12))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(
                        architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))

        if isinstance(pretrain, str):
            # contains arch_param names: {"alphas": alphas, "betas": betas,
            # "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for group in ('alphas', 'betas', 'ratios'):
                    for name in arch_name[group]:
                        state[name] = getattr(model, name)
                # Use this architecture's own metrics.  The loop variables
                # left over from validation only describe the last
                # architecture evaluated.
                arch_mIoUs = valid_mIoUss[idx]
                arch_fps0, arch_fps1 = FPSs[idx]
                state["mIoU02"] = arch_mIoUs[3]
                state["mIoU12"] = arch_mIoUs[4]
                state["latency02"] = 1000. / arch_fps0
                state["latency12"] = 1000. / arch_fps1
                torch.save(
                    state,
                    os.path.join(config.save,
                                 "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))

        if update_arch:
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    arch_fps = FPSs[idx]
                    # Halve the weight once either branch reaches the FPS
                    # ceiling; otherwise double it if either branch is at or
                    # below the FPS floor.
                    if any(fps >= config.FPS_max[idx] for fps in arch_fps):
                        architect.latency_weight[idx] /= 2
                    elif any(fps <= config.FPS_min[idx] for fps in arch_fps):
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))
seed = config.seed if engine.distributed: seed = engine.local_rank torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) # data loader train_loader, train_sampler = get_train_loader(engine, Cityscapes) # config network and criterion criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=255) ohem_criterion = ProbOhemCrossEntropy2d( ignore_label=255, thresh=0.7, min_kept=int(config.batch_size // len(engine.devices) * config.image_height * config.image_width // (16 * config.gt_down_sampling**2)), use_weight=False) if engine.distributed: BatchNorm2d = SyncBatchNorm model = BiSeNet(config.num_classes, is_training=True, criterion=criterion, ohem_criterion=ohem_criterion, pretrained_model=config.pretrained_model, norm_layer=BatchNorm2d) init_weight(model.business_layer, nn.init.kaiming_normal_,
def main():
    """Run distributed training with a detectron2 data pipeline.

    Initialises the NCCL process group, builds the detectron2 train/val
    loaders, constructs the network and optimizer, optionally resumes from
    ``<args.save>/last.pth.tar`` (or ``args.resume``), wraps the model in
    DDP and then alternates training and validation for ``args.epochs``
    epochs.  Rank 0 writes logs, TensorBoard scalars, ``last.pth.tar`` after
    every epoch and ``best_checkpoint.pth.tar`` whenever the validation mIoU
    improves.
    """
    args, args_text = _parse_args()

    # dist init
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # Keep the per-node device index: args.local_rank is overwritten with
    # the global rank below, and DDP needs the local index.
    local_device = args.local_rank
    torch.cuda.set_device(local_device)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()
    is_main_process = args.local_rank == 0

    args.save = args.save + args.exp_name

    # detectron2 data loader ###########################
    det2_args = args
    det2_args.config_file = args.det2_cfg
    cfg = setup(det2_args)
    mapper = DatasetMapper(cfg, augmentations=build_sem_seg_train_aug(cfg))
    det2_dataset = iter(build_detection_train_loader(cfg, mapper=mapper))
    det2_val = build_batch_test_loader(cfg, cfg.DATASETS.TEST[0])
    # NOTE(review): 20210 is hard-coded; presumably the number of training
    # images in the dataset - confirm before switching datasets.
    len_det2_train = 20210 // cfg.SOLVER.IMS_PER_BATCH

    if is_main_process:
        create_exp_dir(args.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(args.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", str(args))
    else:
        logger = None

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # config network and criterion ################
    gt_down_sampling = 1
    min_kept = int(args.batch_size * args.image_height * args.image_width //
                   (16 * gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    num_classes = args.num_classes

    with open(args.json_file, 'r') as f:
        model_dict = json.loads(f.read())
    # NOTE(review): "lasts" is read from the JSON file but not used below.
    last = model_dict["lasts"]

    model = Network(Fch=args.Fch,
                    num_classes=num_classes,
                    stem_head_width=(args.stem_head_width,
                                     args.stem_head_width))

    if is_main_process:
        with torch.cuda.device(0):
            macs, params = get_model_complexity_info(
                model, (3, args.eval_height, args.eval_width),
                as_strings=True,
                print_per_layer_stat=True,
                verbose=True)
            logging.info('{:<30}  {:<8}'.format('Computational complexity: ',
                                                macs))
            logging.info('{:<30}  {:<8}'.format('Number of parameters: ',
                                                params))
        with open(os.path.join(args.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    init_weight(model,
                nn.init.kaiming_normal_,
                torch.nn.BatchNorm2d,
                args.bn_eps,
                args.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')

    if args.pretrain:
        model.backbone = load_pretrain(model.backbone, args.pretrain)
    model = model.cuda()
    # NOTE: BatchNorm layers are not converted to SyncBatchNorm here.

    # Optimizer ###################################
    base_lr = args.base_lr
    if args.opt == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=base_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)
    elif args.opt == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=base_lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=args.weight_decay)
    else:
        optimizer = create_optimizer(args, model)

    if args.sched == "raw":
        lr_scheduler = None
    else:
        max_iteration = args.epochs * len_det2_train
        lr_scheduler = Iter_LR_Scheduler(args, max_iteration, len_det2_train)

    # Resume: a checkpoint in the save directory takes precedence over
    # args.resume.
    start_epoch = 0
    last_ckpt_path = os.path.join(args.save, 'last.pth.tar')
    best_ckpt_path = os.path.join(args.save, 'best_checkpoint.pth.tar')
    if os.path.exists(last_ckpt_path):
        args.resume = last_ckpt_path

    if args.resume:
        model_state_file = args.resume
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['start_epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logging.info('Loaded checkpoint (starting from iter {})'.format(
                checkpoint['start_epoch']))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=None)

    # Validate (and checkpoint) the EMA weights when EMA is enabled.
    eval_model = model_ema.ema if model_ema is not None else model

    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        # device_ids must be the local GPU index, not the global rank.
        model = DDP(model, device_ids=[local_device])

    best_valid_iou = 0.
    best_epoch = 0
    logging.info("rank: {} world_size: {}".format(args.local_rank,
                                                  args.world_size))

    for epoch in range(start_epoch, args.epochs):
        if is_main_process:
            logging.info(args.load_path)
            logging.info(args.save)
            logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        train_mIoU = train(len_det2_train, det2_dataset, model, model_ema,
                           ohem_criterion, num_classes, lr_scheduler,
                           optimizer, logger, epoch, args, cfg)

        # validation after every epoch
        temp_iou, avg_loss = validation(det2_val, eval_model, ohem_criterion,
                                        num_classes, args, cfg)
        torch.cuda.empty_cache()

        if is_main_process:
            logging.info("Epoch: {} train miou: {:.2f}".format(
                epoch + 1, 100 * train_mIoU))

            weights_source = (model_ema.ema
                              if model_ema is not None else model.module)
            checkpoint = {
                'start_epoch': epoch + 1,
                'state_dict': weights_source.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            if temp_iou > best_valid_iou:
                best_valid_iou = temp_iou
                best_epoch = epoch
                torch.save(checkpoint, best_ckpt_path)

            logger.add_scalar("mIoU/val", temp_iou, epoch)
            logging.info("[Epoch %d/%d] valid mIoU %.4f eval loss %.4f" %
                         (epoch + 1, args.epochs, temp_iou, avg_loss))
            logging.info("Best valid mIoU %.4f Epoch %d" %
                         (best_valid_iou, best_epoch))

            torch.save(checkpoint, last_ckpt_path)