def main():
    """Build each configured network, plot its structure and log its FPS.

    For every entry of ``config.arch_idx`` the saved architecture state is
    loaded from ``config.load_path``, a ``Network`` is built from it, the
    operator / path figures are written into ``config.save``, parameter and
    FLOP counts are logged, and ``compute_latency`` is used to report frames
    per second.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py')+glob.glob('*.sh'))

    # Log to stdout and additionally to <config.save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    file_handler = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    file_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(file_handler)
    logging.info("args = %s", str(config))

    # cuDNN switches and random seeds.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    rng_seed = config.seed
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(rng_seed)

    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        # Either the final architecture file or the one of a given epoch.
        if config.load_epoch == "last":
            arch_file = "arch_%d.pt"%arch_idx
        else:
            arch_file = "arch_%d_%d.pt"%(arch_idx, int(config.load_epoch))
        state = torch.load(os.path.join(config.load_path, arch_file))

        alphas = [
            state["alpha_%d_0"%arch_idx].detach(),
            state["alpha_%d_1"%arch_idx].detach(),
            state["alpha_%d_2"%arch_idx].detach(),
        ]
        betas = [
            None,
            state["beta_%d_1"%arch_idx].detach(),
            state["beta_%d_2"%arch_idx].detach(),
        ]
        ratios = [
            state["ratio_%d_0"%arch_idx].detach(),
            state["ratio_%d_1"%arch_idx].detach(),
            state["ratio_%d_2"%arch_idx].detach(),
        ]
        model = Network(alphas, betas, ratios,
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx==0)

        # Keep whichever of the two stored candidates scores the higher
        # accuracy/latency objective.
        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        last = [2, 0] if obj02 > obj12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        # One operator figure per selected branch.
        for b in last:
            figure_path = os.path.join(config.save,
                                       "ops_%d_%d.png"%(arch_idx,b))
            if len(config.width_mult_list) > 1:
                figure = plot_op(getattr(model, "ops%d"%b),
                                 getattr(model, "path%d"%b),
                                 width=getattr(model, "widths%d"%b),
                                 head_width=config.stem_head_width[idx][1],
                                 F_base=config.Fch)
            else:
                figure = plot_op(getattr(model, "ops%d"%b),
                                 getattr(model, "path%d"%b),
                                 F_base=config.Fch)
            figure.savefig(figure_path, bbox_inches="tight")

        selected_figure = plot_path_width(model.lasts, model.paths,
                                          model.widths)
        selected_figure.savefig(
            os.path.join(config.save, "path_width%d.png"%arch_idx))
        all_figure = plot_path_width(
            [2, 1, 0],
            [model.path2, model.path1, model.path0],
            [model.widths2, model.widths1, model.widths0])
        all_figure.savefig(
            os.path.join(config.save, "path_width_all%d.png"%arch_idx))

        dummy_input = torch.randn(1, 3, 1024, 2048)
        flops, params = profile(model, inputs=(dummy_input,), verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))

        model = model.cuda()
        print(config.save)
        input_size = (1, 3, config.image_height, config.image_width)
        latency = compute_latency(model, input_size)
        # 1000 / latency is logged as FPS.
        logging.info("FPS:" + str(1000./latency))
def main():
    """Evaluate the configured networks with ``SegEvaluator`` on Cityscapes.

    For every entry of ``config.arch_idx`` the saved architecture state is
    loaded, a ``Network`` is built and plotted, weights are restored from
    ``config.eval_path`` and an evaluator on the ``'val'`` split is created.
    The validation mIoU of each network is logged at the end.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    logging.info("args = %s", str(config))

    # cuDNN switches and random seeds.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    rng_seed = config.seed
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(rng_seed)

    # NOTE: min_kept is computed but not consumed anywhere in this function.
    pixel_count = (config.batch_size * config.image_height *
                   config.image_width)
    min_kept = int(pixel_count // (16 * config.gt_down_sampling**2))

    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling,
    }

    models = []
    evaluators = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        # Either the final architecture file or the one of a given epoch.
        if config.load_epoch == "last":
            arch_file = "arch_%d.pt" % arch_idx
        else:
            arch_file = "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))
        arch_state = torch.load(os.path.join(config.load_path, arch_file))

        alphas = [
            arch_state["alpha_%d_0" % arch_idx].detach(),
            arch_state["alpha_%d_1" % arch_idx].detach(),
            arch_state["alpha_%d_2" % arch_idx].detach(),
        ]
        betas = [
            None,
            arch_state["beta_%d_1" % arch_idx].detach(),
            arch_state["beta_%d_2" % arch_idx].detach(),
        ]
        ratios = [
            arch_state["ratio_%d_0" % arch_idx].detach(),
            arch_state["ratio_%d_1" % arch_idx].detach(),
            arch_state["ratio_%d_2" % arch_idx].detach(),
        ]
        model = Network(alphas, betas, ratios,
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Keep whichever of the two stored candidates scores the higher
        # accuracy/latency objective.
        mIoU02 = arch_state["mIoU02"]
        latency02 = arch_state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = arch_state["mIoU12"]
        latency12 = arch_state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        last = [2, 0] if obj02 > obj12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)

        # One operator figure per selected branch.
        for b in last:
            figure_path = os.path.join(config.save,
                                       "ops_%d_%d.png" % (arch_idx, b))
            if len(config.width_mult_list) > 1:
                figure = plot_op(getattr(model, "ops%d" % b),
                                 getattr(model, "path%d" % b),
                                 width=getattr(model, "widths%d" % b),
                                 head_width=config.stem_head_width[idx][1],
                                 F_base=config.Fch)
            else:
                figure = plot_op(getattr(model, "ops%d" % b),
                                 getattr(model, "path%d" % b),
                                 F_base=config.Fch)
            figure.savefig(figure_path, bbox_inches="tight")

        selected_figure = plot_path_width(model.lasts, model.paths,
                                          model.widths)
        selected_figure.savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        all_figure = plot_path_width(
            [2, 1, 0],
            [model.path2, model.path1, model.path0],
            [model.widths2, model.widths1, model.widths0])
        all_figure.savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))

        dummy_input = torch.randn(1, 3, 1024, 2048)
        flops, params = profile(model, inputs=(dummy_input, ), verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        # Overlay the saved weights; only keys the model already has are
        # taken over from the file.
        saved_weights = torch.load(
            os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
        merged = model.state_dict()
        merged.update(
            {k: v for k, v in saved_weights.items() if k in merged})
        model.load_state_dict(merged)

        val_dataset = Cityscapes(data_setting, 'val', None)
        prediction_dir = os.path.join(config.save, 'predictions')
        evaluator = SegEvaluator(val_dataset,
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=prediction_dir,
                                 show_image=True,
                                 show_prediction=True)
        evaluators.append(evaluator)
        models.append(model)

    logging.info(config.load_path)
    logging.info(config.eval_path)
    logging.info(config.save)
    with torch.no_grad():
        print("[validation...]")
        valid_mIoUs = infer(models, evaluators, logger=None)
        # Architecture index 0 is reported as the teacher, the rest as
        # student.
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                message = "teacher's valid_mIoU %.3f" % (valid_mIoUs[idx])
            else:
                message = "student's valid_mIoU %.3f" % (valid_mIoUs[idx])
            logging.info(message)
def build_segmentation_model_from_cfg(config):
    """Builds segmentation model with specific configuration.

    The backbone is selected by ``config.MODEL.BACKBONE.META`` and the
    segmentation head by ``config.MODEL.META_ARCHITECTURE``. After the model
    is built, the momentum of every ``torch.nn.BatchNorm2d`` module is set to
    ``config.MODEL.BN_MOMENTUM``.

    Args:
        config: the configuration.

    Returns:
        A nn.Module segmentation model.

    Raises:
        ValueError: if ``config.MODEL.BACKBONE.META`` is not one of
            ``'resnet'``, ``'mobilenet_v2'``, ``'mnasnet'`` or ``'nas'``, or
            if the ``'nas'`` backbone is requested with an empty
            ``config.MODEL.NAS.ARCHI``.
    """
    model_map = {
        'deeplabv3': DeepLabV3,
        'deeplabv3plus': DeepLabV3Plus,
        'panoptic_deeplab': PanopticDeepLab,
    }

    # NOTE: the keyword arguments of all three architectures are evaluated
    # eagerly, so the config must provide every section referenced below
    # regardless of which architecture is selected.
    model_cfg = {
        'deeplabv3':
        dict(
            replace_stride_with_dilation=config.MODEL.BACKBONE.DILATION,
            in_channels=config.MODEL.DECODER.IN_CHANNELS,
            feature_key=config.MODEL.DECODER.FEATURE_KEY,
            decoder_channels=config.MODEL.DECODER.DECODER_CHANNELS,
            atrous_rates=config.MODEL.DECODER.ATROUS_RATES,
            num_classes=config.DATASET.NUM_CLASSES,
            semantic_loss=build_loss_from_cfg(config.LOSS.SEMANTIC),
            semantic_loss_weight=config.LOSS.SEMANTIC.WEIGHT,
        ),
        'deeplabv3plus':
        dict(
            replace_stride_with_dilation=config.MODEL.BACKBONE.DILATION,
            in_channels=config.MODEL.DECODER.IN_CHANNELS,
            feature_key=config.MODEL.DECODER.FEATURE_KEY,
            low_level_channels=config.MODEL.DEEPLABV3PLUS.LOW_LEVEL_CHANNELS,
            low_level_key=config.MODEL.DEEPLABV3PLUS.LOW_LEVEL_KEY,
            low_level_channels_project=config.MODEL.DEEPLABV3PLUS.
            LOW_LEVEL_CHANNELS_PROJECT,
            decoder_channels=config.MODEL.DECODER.DECODER_CHANNELS,
            atrous_rates=config.MODEL.DECODER.ATROUS_RATES,
            num_classes=config.DATASET.NUM_CLASSES,
            semantic_loss=build_loss_from_cfg(config.LOSS.SEMANTIC),
            semantic_loss_weight=config.LOSS.SEMANTIC.WEIGHT,
        ),
        'panoptic_deeplab':
        dict(
            replace_stride_with_dilation=config.MODEL.BACKBONE.DILATION,
            in_channels=config.MODEL.DECODER.IN_CHANNELS,
            feature_key=config.MODEL.DECODER.FEATURE_KEY,
            low_level_channels=config.MODEL.PANOPTIC_DEEPLAB.
            LOW_LEVEL_CHANNELS,
            low_level_key=config.MODEL.PANOPTIC_DEEPLAB.LOW_LEVEL_KEY,
            low_level_channels_project=config.MODEL.PANOPTIC_DEEPLAB.
            LOW_LEVEL_CHANNELS_PROJECT,
            decoder_channels=config.MODEL.DECODER.DECODER_CHANNELS,
            atrous_rates=config.MODEL.DECODER.ATROUS_RATES,
            num_classes=config.DATASET.NUM_CLASSES,
            has_instance=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.ENABLE,
            instance_low_level_channels_project=config.MODEL.PANOPTIC_DEEPLAB.
            INSTANCE.LOW_LEVEL_CHANNELS_PROJECT,
            instance_decoder_channels=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.
            DECODER_CHANNELS,
            instance_head_channels=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.
            HEAD_CHANNELS,
            instance_aspp_channels=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.
            ASPP_CHANNELS,
            instance_num_classes=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.
            NUM_CLASSES,
            instance_class_key=config.MODEL.PANOPTIC_DEEPLAB.INSTANCE.
            CLASS_KEY,
            semantic_loss=build_loss_from_cfg(config.LOSS.SEMANTIC),
            semantic_loss_weight=config.LOSS.SEMANTIC.WEIGHT,
            center_loss=build_loss_from_cfg(config.LOSS.CENTER),
            center_loss_weight=config.LOSS.CENTER.WEIGHT,
            offset_loss=build_loss_from_cfg(config.LOSS.OFFSET),
            offset_loss_weight=config.LOSS.OFFSET.WEIGHT,
        ),
    }

    # TODO: replace the ResNet with the NAS architecture.
    if config.MODEL.BACKBONE.META == 'resnet':
        backbone = resnet.__dict__[config.MODEL.BACKBONE.NAME](
            pretrained=config.MODEL.BACKBONE.PRETRAINED,
            replace_stride_with_dilation=model_cfg[
                config.MODEL.META_ARCHITECTURE]
            ['replace_stride_with_dilation'])
    elif config.MODEL.BACKBONE.META == 'mobilenet_v2':
        backbone = mobilenet.__dict__[config.MODEL.BACKBONE.NAME](
            pretrained=config.MODEL.BACKBONE.PRETRAINED, )
    elif config.MODEL.BACKBONE.META == 'mnasnet':
        backbone = mnasnet.__dict__[config.MODEL.BACKBONE.NAME](
            pretrained=config.MODEL.BACKBONE.PRETRAINED, )
    elif config.MODEL.BACKBONE.META == 'nas':
        # NOTE(review): ``backbone`` is rebound on every iteration, so only
        # the network built for the last entry of ARCHI reaches the model.
        backbone = None
        for idx, arch_idx in enumerate(config.MODEL.NAS.ARCHI):
            if config.MODEL.NAS.LOAD_EPOCH == "last":
                state = torch.load(
                    os.path.join(config.MODEL.NAS.LOAD_PATH,
                                 "arch_%d.pt" % arch_idx))
            else:
                # Read the epoch from the same NAS section as every other
                # setting here; the previous ``config.load_epoch`` does not
                # match the key compared against "last" just above.
                state = torch.load(
                    os.path.join(
                        config.MODEL.NAS.LOAD_PATH, "arch_%d_%d.pt" %
                        (arch_idx, int(config.MODEL.NAS.LOAD_EPOCH))))
            # TODO: revise the parameters here. according to FasterSeg.
            backbone = fasterseg.Network_Multi_Path_Infer(
                [
                    state["alpha_%d_0" % arch_idx].detach(),
                    state["alpha_%d_1" % arch_idx].detach(),
                    state["alpha_%d_2" % arch_idx].detach()
                ], [
                    None, state["beta_%d_1" % arch_idx].detach(),
                    state["beta_%d_2" % arch_idx].detach()
                ], [
                    state["ratio_%d_0" % arch_idx].detach(),
                    state["ratio_%d_1" % arch_idx].detach(),
                    state["ratio_%d_2" % arch_idx].detach()
                ],
                num_classes=config.MODEL.NAS.NUM_CLASSES,
                layers=config.MODEL.NAS.LAYERS,
                Fch=config.MODEL.NAS.FCH,
                width_mult_list=config.MODEL.NAS.WIDTH_MULT_LIST,
                stem_head_width=config.MODEL.NAS.STEM_HEAD_WIDTH[idx],
                ignore_skip=arch_idx == 0)

            # Keep whichever of the two stored candidates scores the higher
            # accuracy/latency objective.
            mIoU02 = state["mIoU02"]
            latency02 = state["latency02"]
            obj02 = objective_acc_lat(mIoU02, latency02)
            mIoU12 = state["mIoU12"]
            latency12 = state["latency12"]
            obj12 = objective_acc_lat(mIoU12, latency12)
            if obj02 > obj12:
                last = [2, 0]
            else:
                last = [2, 1]
            backbone.build_structure(last)

            # Architecture 0 gets its saved weights when it is configured
            # together with other architectures; only keys the backbone
            # already has are taken over from the file.
            if arch_idx == 0 and len(config.MODEL.NAS.ARCHI) > 1:
                partial = torch.load(
                    os.path.join(config.MODEL.NAS.TEACHER_PATH,
                                 "weights%d.pt" % arch_idx))
                state = backbone.state_dict()
                pretrained_dict = {
                    k: v
                    for k, v in partial.items() if k in state
                }
                state.update(pretrained_dict)
                backbone.load_state_dict(state)

        if backbone is None:
            raise ValueError(
                'config.MODEL.NAS.ARCHI is empty, no NAS backbone was built.')
    else:
        raise ValueError(
            'Unknown meta backbone {}, please first implement it.'.format(
                config.MODEL.BACKBONE.META))

    model = model_map[config.MODEL.META_ARCHITECTURE](
        backbone, **model_cfg[config.MODEL.META_ARCHITECTURE])

    # set batchnorm momentum
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.momentum = config.MODEL.BN_MOMENTUM
    return model
def main():
    """Train, validate or test the networks listed in ``config.arch_idx``.

    Each network is built from a saved architecture state found under
    ``config.load_path``. With ``config.is_eval`` set, the networks are only
    validated once and the process exits. Otherwise the function trains for
    ``config.nepochs`` epochs, validating periodically (or testing, when
    ``config.is_test`` is set) and writing ``weights<arch_idx>.pt`` files
    into ``config.save``.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    # TensorBoard writer; scalars are logged next to the text log below.
    logger = SummaryWriter(config.save)

    # Log to stdout and additionally to <config.save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    # NOTE(review): KLDivLoss is created with its default reduction -
    # confirm this is the intended scaling for the distillation term.
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    # The two settings differ only in 'train_source': test mode reads
    # config.train_eval_source instead of config.train_source.
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        # Either the final architecture file or the one of a given epoch.
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))
        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Keep whichever of the two stored candidates scores the higher
        # accuracy/latency objective.
        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        # One operator figure per selected branch; width information is only
        # passed when more than one width multiplier is configured.
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))

        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        # Overlay saved weights on the fresh initialisation: architecture 0
        # reads config.teacher_path when several architectures are
        # configured, otherwise eval mode reads config.eval_path. Only keys
        # the model already has are taken over from the file.
        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=None,
                                 show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes,
                           config.image_mean,
                           config.image_std,
                           model,
                           config.eval_scale_array,
                           config.eval_flip,
                           0,
                           out_idx=0,
                           config=config,
                           verbose=False,
                           save_path=None,
                           show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        # NOTE(review): `optimizer` is only bound when this condition holds
        # for some architecture; the training loop below relies on it.
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    # Evaluation-only mode: validate once, log the results and terminate.
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        # NOTE(review): exit() is the interactive helper installed by the
        # site module; sys.exit(0) would not depend on it.
        exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx],
                                  epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx],
                                  epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation
        # Runs after the first epoch and then every 10th epoch, unless in
        # test mode.
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx],
                                          epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx],
                                          epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))

        # test
        # In test mode, runs every 10th epoch from epoch 250 onwards.
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        # Overwrite the weight file of every network.
        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))
def main(pretrain=True):
    """Train the supernet and, unless pretraining, update its architecture.

    A time-stamped ``search-...`` directory is created for logs, TensorBoard
    scalars, ``weights.pt`` and (when ``pretrain`` is a string) the
    per-architecture ``arch_<idx>.pt`` / ``arch_<idx>_<epoch>.pt`` files.

    Args:
        pretrain: ``True`` trains with architecture updates disabled and
            validates under the "min" (and, with several width multipliers,
            "max" and "random") ``prun_mode``. A string is used as the
            directory holding a ``weights.pt`` to start from; architecture
            updates are then enabled and the architecture parameters are
            saved after every epoch. ``False`` enables architecture updates
            on freshly initialised weights without saving architecture
            files.

    Raises:
        AssertionError: if ``pretrain`` is neither a bool nor a str.
    """
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Log to stdout and additionally to <config.save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert isinstance(pretrain, (bool, str))
    # Architecture parameters are frozen only for pretrain=True.
    update_arch = pretrain is not True
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes,
                    config.layers,
                    ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()

    if isinstance(pretrain, str):
        # Start from saved weights; only entries whose name and size match
        # the current model are taken over.
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model,
                    nn.init.kaiming_normal_,
                    nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    # Only the parameters of the sub-modules listed here are optimised.
    base_lr = config.lr
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    # NOTE(review): presumably ``portion`` and ``portion - 1`` select
    # complementary splits for weight and architecture training - confirm
    # against get_train_loader.
    train_loader_model = get_train_loader(config,
                                          EGTEA,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config,
                                         EGTEA,
                                         portion=config.train_portion - 1)

    evaluator = SegEvaluator(EGTEA(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    if update_arch:
        for idx, weight in enumerate(config.latency_weight):
            logger.add_scalar("arch/latency_weight%d" % idx, weight, 0)
            logging.info("arch_latency_weight%d = " % idx + str(weight))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              ohem_criterion,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain is True:
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i, valid_name in enumerate(valid_names):
                    logger.add_scalar('mIoU/val_min_%s' % valid_name,
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_name, valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i, valid_name in enumerate(valid_names):
                        logger.add_scalar('mIoU/val_max_%s' % valid_name,
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_name, valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i, valid_name in enumerate(valid_names):
                        logger.add_scalar('mIoU/val_random_%s' % valid_name,
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_name, valid_mIoUs[i]))
            else:
                # Per-architecture results, indexed like model._arch_names.
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i, valid_name in enumerate(valid_names):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' % (arch_names[idx], valid_name),
                            valid_mIoUs[i], epoch)
                        logging.info(
                            "Epoch %d: valid_mIoU_%s_%s %.3f" %
                            (epoch, arch_names[idx], valid_name,
                             valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        # 1000 / fps is passed as the latency argument.
                        objective02 = objective_acc_lat(
                            valid_mIoUs[3], 1000. / fps0)
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective02, epoch)
                        logging.info("Epoch %d: Objective_%s_8s_32s %.3f" %
                                     (epoch, arch_names[idx], objective02))
                        objective12 = objective_acc_lat(
                            valid_mIoUs[4], 1000. / fps1)
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective12, epoch)
                        logging.info("Epoch %d: Objective_%s_16s_32s %.3f" %
                                     (epoch, arch_names[idx], objective12))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                    latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))

        if isinstance(pretrain, str):
            # contains arch_param names:
            # {"alphas": alphas, "betas": betas, "gammas": gammas,
            #  "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for group in ('alphas', 'betas', 'ratios'):
                    for name in arch_name[group]:
                        state[name] = getattr(model, name)
                # Store the metrics measured for THIS architecture; the loop
                # variables left over from validation belong to the last
                # architecture only.
                arch_mIoUs = valid_mIoUss[idx]
                arch_fps0, arch_fps1 = FPSs[idx]
                state["mIoU02"] = arch_mIoUs[3]
                state["mIoU12"] = arch_mIoUs[4]
                state["latency02"] = 1000. / arch_fps0
                state["latency12"] = 1000. / arch_fps1
                torch.save(
                    state,
                    os.path.join(config.save, "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))

        if update_arch:
            # Halve the latency weight when either measured FPS reaches
            # FPS_max, double it when either is at or below FPS_min.
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    if any(fps >= config.FPS_max[idx] for fps in FPSs[idx]):
                        architect.latency_weight[idx] /= 2
                    elif any(fps <= config.FPS_min[idx] for fps in FPSs[idx]):
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))