Example #1
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
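    # OHEM keeps at least min_kept of the hardest pixels per batch; the budget
    # scales with the batch size and image area, reduced by a factor of 16 and
    # by the squared ground-truth down-sampling.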
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
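    # KL divergence serves as the distillation loss between student and teacher
    # predictions (see the teacher/student optimizer note further below).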
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': (config.train_eval_source
                         if config.is_test else config.train_source),
        'eval_source': config.eval_source,
        'test_source': config.test_source,
        'down_sampling': config.down_sampling
    }

    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
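        # Compare the two fused head pairings recorded during search (the
        # *02 and *12 entries, i.e. 8s+32s vs 16s+32s as saved in Example #3)
        # by the accuracy-latency objective; the winner fixes which two
        # branches the derived network keeps.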
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=None,
                                 show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes,
                           config.image_mean,
                           config.image_std,
                           model,
                           config.eval_scale_array,
                           config.eval_flip,
                           0,
                           out_idx=0,
                           config=config,
                           verbose=False,
                           save_path=None,
                           show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        sys.exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx],
                                  epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx],
                                  epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx],
                                          epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx],
                                          epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))
        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))
Example #2
with Engine(custom_parser=parser) as engine:
    args = parser.parse_args()

    cudnn.benchmark = True
    if engine.distributed:
        torch.cuda.set_device(engine.local_rank)

    # data loader
    train_loader, train_sampler = get_train_loader(engine, Cityscapes)

    # config network and criterion
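    # min_kept is a per-GPU OHEM budget: the global batch size is split across
    # engine.devices before being scaled by the image area / 64.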
    min_kept = int(config.batch_size // len(engine.devices) *
                   config.image_height * config.image_width // 64)
    criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                       thresh=0.7,
                                       min_kept=min_kept,
                                       use_weight=False)

    if engine.distributed:
        logger.info('Use the Multi-Process-SyncBatchNorm')
        BatchNorm2d = SyncBatchNorm
    # else:
    #     BatchNorm2d = BatchNorm2d
    model = CPNet(config.num_classes,
                  criterion=criterion,
                  pretrained_model=config.pretrained_model,
                  norm_layer=BatchNorm2d)
    init_weight(model.business_layer,
                nn.init.kaiming_normal_,
                BatchNorm2d,
                config.bn_eps,
                config.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')
Example #3
def main(pretrain=True):
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert isinstance(pretrain, (bool, str))
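    # pretrain=True trains only the supernet weights and keeps the architecture
    # parameters frozen; passing a checkpoint path instead loads those weights
    # and leaves architecture updates enabled.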
    update_arch = True
    if pretrain == True:
        update_arch = False
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes,
                    config.layers,
                    ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()
    if type(pretrain) == str:
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model,
                    nn.init.kaiming_normal_,
                    nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
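    # Only the weight modules listed above go into SGD; the architecture
    # parameters (alphas/betas/ratios) are presumably updated by the Architect
    # created earlier.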
    optimizer = torch.optim.SGD(parameters,
                                lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    train_loader_model = get_train_loader(config,
                                          EGTEA,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config,
                                         EGTEA,
                                         portion=config.train_portion - 1)
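    # The training set is shared between the two loaders: train_portion of it
    # feeds the weight updates, and the negative portion passed here presumably
    # selects the complementary part for the architecture updates.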

    evaluator = SegEvaluator(EGTEA(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}
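    # The five validation entries correspond to the 1/8, 1/16 and 1/32 stride
    # heads plus the fused 8s+32s and 16s+32s outputs (inferred from the names).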
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              ohem_criterion,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain == True:
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i in range(5):
                    logger.add_scalar('mIoU/val_min_%s' % valid_names[i],
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_names[i], valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_max_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar(
                            'mIoU/val_random_%s' % valid_names[i],
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
            else:
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i in range(5):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' %
                            (arch_names[idx], valid_names[i]), valid_mIoUs[i],
                            epoch)
                        logging.info("Epoch %d: valid_mIoU_%s_%s %.3f" %
                                     (epoch, arch_names[idx], valid_names[i],
                                      valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[3], 1000. / fps0),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_8s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[3], 1000. / fps0)))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[4], 1000. / fps1),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_16s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[4], 1000. / fps1)))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))
        if type(pretrain) == str:
            # contains arch_param names: {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for name in arch_name['alphas']:
                    state[name] = getattr(model, name)
                for name in arch_name['betas']:
                    state[name] = getattr(model, name)
                for name in arch_name['ratios']:
                    state[name] = getattr(model, name)
                state["mIoU02"] = valid_mIoUs[3]
                state["mIoU12"] = valid_mIoUs[4]
                if pretrain is not True:
                    state["latency02"] = 1000. / fps0
                    state["latency12"] = 1000. / fps1
                torch.save(
                    state,
                    os.path.join(config.save, "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))

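        # Feedback control on the latency penalty: if either fused head of an
        # architecture already runs at or above FPS_max its latency weight is
        # halved, and if either drops to or below FPS_min it is doubled.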
        if update_arch:
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    if (int(FPSs[idx][0] >= config.FPS_max[idx]) +
                            int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1:
                        architect.latency_weight[idx] /= 2
                    elif (int(FPSs[idx][0] <= config.FPS_min[idx]) +
                          int(FPSs[idx][1] <= config.FPS_min[idx])) > 0:
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))
Example #4
    seed = config.seed
    if engine.distributed:
        seed = engine.local_rank
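    # Each distributed rank seeds from its own local rank, so the processes
    # draw different random streams.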
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # data loader
    train_loader, train_sampler = get_train_loader(engine, Cityscapes)

    # config network and criterion
    criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
    ohem_criterion = ProbOhemCrossEntropy2d(
        ignore_label=255,
        thresh=0.7,
        min_kept=int(config.batch_size // len(engine.devices) *
                     config.image_height * config.image_width //
                     (16 * config.gt_down_sampling**2)),
        use_weight=False)

    if engine.distributed:
        BatchNorm2d = SyncBatchNorm

    model = BiSeNet(config.num_classes,
                    is_training=True,
                    criterion=criterion,
                    ohem_criterion=ohem_criterion,
                    pretrained_model=config.pretrained_model,
                    norm_layer=BatchNorm2d)
    init_weight(model.business_layer,
                nn.init.kaiming_normal_,
                BatchNorm2d,
                config.bn_eps,
                config.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')
Example #5
def main():
    args, args_text = _parse_args()

    # dist init
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()
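    # Note: torch.distributed.get_rank() returns the global rank; reusing it as
    # local_rank assumes single-node training where the two coincide.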
    args.save = args.save + args.exp_name

    # detectron2 data loader ###########################
    # det2_args = default_argument_parser().parse_args()
    det2_args = args
    det2_args.config_file = args.det2_cfg
    cfg = setup(det2_args)
    mapper = DatasetMapper(cfg, augmentations=build_sem_seg_train_aug(cfg))
    det2_dataset = iter(build_detection_train_loader(cfg, mapper=mapper))
    det2_val = build_batch_test_loader(cfg, cfg.DATASETS.TEST[0])
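    # Iterations per epoch; 20210 is a hard-coded training-set size (it matches
    # the ADE20K train split, which is presumably the dataset in use here).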
    len_det2_train = 20210 // cfg.SOLVER.IMS_PER_BATCH

    if args.local_rank == 0:
        create_exp_dir(args.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(args.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", str(args))
    else:
        logger = None

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
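    # Deterministic cuDNN with benchmarking disabled trades throughput for
    # reproducibility, unlike the benchmark=True setting in the earlier examples.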

    # config network and criterion ################
    gt_down_sampling = 1
    min_kept = int(args.batch_size * args.image_height * args.image_width //
                   (16 * gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # data loader ###########################

    num_classes = args.num_classes

    with open(args.json_file, 'r') as f:
        # dict_a = json.loads(f, cls=NpEncoder)
        model_dict = json.loads(f.read())

    width_mult_list = [
        4. / 12,
        6. / 12,
        8. / 12,
        10. / 12,
        1.,
    ]
    model = Network(Fch=args.Fch,
                    num_classes=num_classes,
                    stem_head_width=(args.stem_head_width,
                                     args.stem_head_width))
    last = model_dict["lasts"]

    if args.local_rank == 0:
        with torch.cuda.device(0):
            macs, params = get_model_complexity_info(
                model, (3, args.eval_height, args.eval_width),
                as_strings=True,
                print_per_layer_stat=True,
                verbose=True)
            logging.info('{:<30}  {:<8}'.format('Computational complexity: ',
                                                macs))
            logging.info('{:<30}  {:<8}'.format('Number of parameters: ',
                                                params))

        with open(os.path.join(args.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    init_weight(model,
                nn.init.kaiming_normal_,
                torch.nn.BatchNorm2d,
                args.bn_eps,
                args.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')

    if args.pretrain:
        model.backbone = load_pretrain(model.backbone, args.pretrain)
    model = model.cuda()

    # if args.sync_bn:
    #     if has_apex:
    #         model = apex.parallel.convert_syncbn_model(model)
    #     else:
    #         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Optimizer ###################################
    base_lr = args.base_lr

    if args.opt == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=base_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)
    elif args.opt == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=base_lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=args.weight_decay)
    else:
        optimizer = create_optimizer(args, model)

    if args.sched == "raw":
        lr_scheduler = None
    else:
        max_iteration = args.epochs * len_det2_train
        lr_scheduler = Iter_LR_Scheduler(args, max_iteration, len_det2_train)

    start_epoch = 0
    if os.path.exists(os.path.join(args.save, 'last.pth.tar')):
        args.resume = os.path.join(args.save, 'last.pth.tar')

    if args.resume:
        model_state_file = args.resume
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['start_epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logging.info('Loaded checkpoint (resuming from epoch {})'.format(
                checkpoint['start_epoch']))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=None)

    if model_ema:
        eval_model = model_ema.ema
    else:
        eval_model = model

    if has_apex:
        model = DDP(model, delay_allreduce=True)
    else:
        model = DDP(model, device_ids=[args.local_rank])

    best_valid_iou = 0.
    best_epoch = 0
    temp_iou = 0.
    avg_loss = -1

    logging.info("rank: {} world_size: {}".format(args.local_rank,
                                                  args.world_size))
    for epoch in range(start_epoch, args.epochs):
        if args.local_rank == 0:
            logging.info(args.load_path)
            logging.info(args.save)
            logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        drop_prob = args.drop_path_prob * epoch / args.epochs
        # model.module.drop_path_prob(drop_prob)

        train_mIoU = train(len_det2_train, det2_dataset, model, model_ema,
                           ohem_criterion, num_classes, lr_scheduler,
                           optimizer, logger, epoch, args, cfg)

        # torch.cuda.empty_cache()

        # if epoch > args.epochs // 3:
        if epoch >= 0:
            temp_iou, avg_loss = validation(det2_val, eval_model,
                                            ohem_criterion, num_classes, args,
                                            cfg)

        torch.cuda.empty_cache()
        if args.local_rank == 0:
            logging.info("Epoch: {} train miou: {:.2f}".format(
                epoch + 1, 100 * train_mIoU))
            if temp_iou > best_valid_iou:
                best_valid_iou = temp_iou
                best_epoch = epoch

                if model_ema is not None:
                    torch.save(
                        {
                            'start_epoch': epoch + 1,
                            'state_dict': model_ema.ema.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            # 'lr_scheduler': lr_scheduler.state_dict(),
                        },
                        os.path.join(args.save, 'best_checkpoint.pth.tar'))
                else:
                    torch.save(
                        {
                            'start_epoch': epoch + 1,
                            'state_dict': model.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            # 'lr_scheduler': lr_scheduler.state_dict(),
                        },
                        os.path.join(args.save, 'best_checkpoint.pth.tar'))

            logger.add_scalar("mIoU/val", temp_iou, epoch)
            logging.info("[Epoch %d/%d] valid mIoU %.4f eval loss %.4f" %
                         (epoch + 1, args.epochs, temp_iou, avg_loss))
            logging.info("Best valid mIoU %.4f Epoch %d" %
                         (best_valid_iou, best_epoch))

            if model_ema is not None:
                torch.save(
                    {
                        'start_epoch': epoch + 1,
                        'state_dict': model_ema.ema.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        # 'lr_scheduler': lr_scheduler.state_dict(),
                    },
                    os.path.join(args.save, 'last.pth.tar'))
            else:
                torch.save(
                    {
                        'start_epoch': epoch + 1,
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        # 'lr_scheduler': lr_scheduler.state_dict(),
                    },
                    os.path.join(args.save, 'last.pth.tar'))