Example #1
0
def main():
    """Rebuild each searched architecture, plot it and report its cost.

    For every entry of ``config.arch_idx``:

    - The architecture checkpoint is read from ``config.load_path``:
      ``arch_<idx>.pt`` when ``config.load_epoch == "last"``, otherwise
      ``arch_<idx>_<epoch>.pt``.
    - A ``Network`` is built from its alpha/beta/ratio tensors.
    - The network is reduced to the branch pair (``[2, 0]`` or ``[2, 1]``)
      whose ``objective_acc_lat`` score is higher.
    - Operator and path/width plots are written to ``config.save``.
    - Params/FLOPs (profiled on a 1x3x1024x2048 input) are logged, along
      with the FPS implied by ``compute_latency`` at
      ``config.image_height`` x ``config.image_width``.
    """
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))

    # Logging: stdout plus a copy of every record in <save>/log.txt.
    fmt = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=fmt,
                        datefmt='%m/%d %I:%M:%S %p')
    log_file = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    log_file.setFormatter(logging.Formatter(fmt))
    logging.getLogger().addHandler(log_file)
    logging.info("args = %s", str(config))

    # cuDNN and RNG setup ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    rng_seed = config.seed
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(rng_seed)

    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            ckpt_name = "arch_%d.pt" % arch_idx
        else:
            ckpt_name = "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))
        state = torch.load(os.path.join(config.load_path, ckpt_name))

        # Architecture parameters, slots 0..2; the first beta slot is None.
        alphas = [state["alpha_%d_%d" % (arch_idx, k)].detach() for k in (0, 1, 2)]
        betas = [None] + [state["beta_%d_%d" % (arch_idx, k)].detach() for k in (1, 2)]
        ratios = [state["ratio_%d_%d" % (arch_idx, k)].detach() for k in (0, 1, 2)]
        model = Network(alphas, betas, ratios,
                        num_classes=config.num_classes, layers=config.layers,
                        Fch=config.Fch, width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=(arch_idx == 0))

        # Keep the branch pair with the better accuracy/latency objective.
        score_02 = objective_acc_lat(state["mIoU02"], state["latency02"])
        score_12 = objective_acc_lat(state["mIoU12"], state["latency12"])
        last = [2, 0] if score_02 > score_12 else [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        for branch in last:
            plot_kwargs = {"F_base": config.Fch}
            if len(config.width_mult_list) > 1:
                # Widths are only plotted when several width multipliers exist.
                plot_kwargs["width"] = getattr(model, "widths%d" % branch)
                plot_kwargs["head_width"] = config.stem_head_width[idx][1]
            ops_fig = plot_op(getattr(model, "ops%d" % branch),
                              getattr(model, "path%d" % branch), **plot_kwargs)
            ops_fig.savefig(os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, branch)),
                            bbox_inches="tight")

        selected_fig = plot_path_width(model.lasts, model.paths, model.widths)
        selected_fig.savefig(os.path.join(config.save, "path_width%d.png" % arch_idx))
        all_fig = plot_path_width([2, 1, 0],
                                  [model.path2, model.path1, model.path0],
                                  [model.widths2, model.widths1, model.widths0])
        all_fig.savefig(os.path.join(config.save, "path_width_all%d.png" % arch_idx))

        dummy_input = torch.randn(1, 3, 1024, 2048)
        flops, params = profile(model, inputs=(dummy_input,), verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))

        model = model.cuda()
        print(config.save)
        input_shape = (1, 3, config.image_height, config.image_width)
        latency = compute_latency(model, input_shape)
        logging.info("FPS:" + str(1000. / latency))
Example #2
0
def main():
    """Evaluate trained searched architectures on the Cityscapes val split.

    For each index in ``config.arch_idx``:

    - Load the architecture checkpoint (``arch_<idx>.pt`` or
      ``arch_<idx>_<epoch>.pt``) from ``config.load_path``.
    - Rebuild the ``Network`` from its alpha/beta/ratio tensors.
    - Keep the branch pair (``[2, 0]`` or ``[2, 1]``) with the higher
      ``objective_acc_lat`` score.
    - Save structure plots to ``config.save`` and log params/FLOPs.
    - Load trained weights from ``config.eval_path/weights<idx>.pt``.
    - Wrap the model in a ``SegEvaluator`` that saves predictions under
      ``config.save/predictions``.

    Finally ``infer`` runs over all models and each val mIoU is logged;
    arch 0 is labelled the teacher and every other arch a student.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # data loader ###########################
    # Only the evaluator below consumes this. No criterion or train loader
    # is built here, so training-only config values are not read.
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }

    # Model #######################################
    models = []
    evaluators = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Keep the branch pair with the better accuracy/latency objective.
        obj02 = objective_acc_lat(state["mIoU02"], state["latency02"])
        obj12 = objective_acc_lat(state["mIoU12"], state["latency12"])
        last = [2, 0] if obj02 > obj12 else [2, 1]
        model.build_structure(last)
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        # Overwrite the fresh init with trained weights. Only keys present in
        # the rebuilt model are copied; anything else in the file is ignored.
        partial = torch.load(
            os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
        model_state = model.state_dict()
        model_state.update(
            {k: v for k, v in partial.items() if k in model_state})
        model.load_state_dict(model_state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=os.path.join(
                                     config.save, 'predictions'),
                                 show_image=True,
                                 show_prediction=True)
        evaluators.append(evaluator)
        models.append(model)

    # Cityscapes ###########################################
    logging.info(config.load_path)
    logging.info(config.eval_path)
    logging.info(config.save)
    with torch.no_grad():
        # validation
        print("[validation...]")
        valid_mIoUs = infer(models, evaluators, logger=None)
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
            else:
                logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
Example #3
0
def main():
    """Train searched architectures on Cityscapes, optionally distilling.

    Model setup, for every index in ``config.arch_idx``:

    - Rebuild it from its architecture checkpoint in ``config.load_path``
      (alpha/beta/ratio tensors).
    - Reduce it to the branch pair (``[2, 0]`` or ``[2, 1]``) with the
      higher ``objective_acc_lat`` score.
    - Plot it into ``config.save`` and profile it.

    Arch 0 is the teacher; when several archs are given it is initialised
    from ``config.teacher_path``. An SGD optimizer is created for arch 1, or
    for the only arch when a single one is given.

    With ``config.is_eval``, the models (weights from ``config.eval_path``)
    are validated once and the process exits. Otherwise they are trained for
    ``config.nepochs`` epochs:

    - validated on epoch 1 and every 10th epoch (unless ``config.is_test``);
    - run through the test split every 10th epoch from epoch 250 on
      (when ``config.is_test``);
    - saved to ``config.save/weights<idx>.pt`` after every epoch.

    Raises:
        ValueError: if training is requested but no optimizer was created,
            i.e. ``config.arch_idx`` has several entries and none equals 1.
    """
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    # NOTE(review): KLDivLoss defaults to reduction='mean'. The PyTorch docs
    # recommend 'batchmean' for the true KL value. Left unchanged so the loss
    # scale that hyper-parameters were tuned for stays the same.
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    # Only the training list differs between modes: with config.is_test,
    # config.train_eval_source replaces config.train_source.
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': (config.train_eval_source
                         if config.is_test else config.train_source),
        'eval_source': config.eval_source,
        'test_source': config.test_source,
        'down_sampling': config.down_sampling
    }

    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    def load_partial_weights(model, weights_path):
        """Copy the entries of the state dict at ``weights_path`` into ``model``.

        Keys that ``model`` does not have are ignored. Parameters missing
        from the file keep their current values.
        """
        partial = torch.load(weights_path)
        model_state = model.state_dict()
        model_state.update(
            {k: v for k, v in partial.items() if k in model_state})
        model.load_state_dict(model_state)

    # Model #######################################
    base_lr = config.lr
    optimizer = None  # assigned below for the arch that gets trained
    models = []
    evaluators = []
    testers = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Keep the branch pair with the better accuracy/latency objective.
        obj02 = objective_acc_lat(state["mIoU02"], state["latency02"])
        obj12 = objective_acc_lat(state["mIoU12"], state["latency12"])
        last = [2, 0] if obj02 > obj12 else [2, 1]
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        # verbose=False matches the other profile() calls and keeps stdout clean.
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            # Teacher in a multi-arch run starts from pretrained weights.
            load_partial_weights(
                model,
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
        elif config.is_eval:
            load_partial_weights(
                model,
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=None,
                                 show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes,
                           config.image_mean,
                           config.image_std,
                           model,
                           config.eval_scale_array,
                           config.eval_flip,
                           0,
                           out_idx=0,
                           config=config,
                           verbose=False,
                           save_path=None,
                           show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        logger.close()  # flush pending TensorBoard events before exiting
        sys.exit(0)

    # Without an optimizer the first epoch would fail on an unbound name;
    # fail early with a message that says which config is wrong.
    if optimizer is None:
        raise ValueError(
            "no optimizer was created: config.arch_idx must contain 1 "
            "or have exactly one entry to train")

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx],
                                  epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx],
                                  epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx],
                                          epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx],
                                          epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        # Checkpoint every model at the end of each epoch. This also covers
        # validation epochs, so no separate save is needed above.
        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))

    logger.close()
Example #4
0
def main():
    """Record params (M) and FLOPs (G) in every searched-architecture JSON file.

    Parses the CLI args, copies the size/path overrides into ``config`` and
    seeds the RNGs. Then, for every ``Search/{1,2,3}paths/*.json`` file, it
    rebuilds the ``Network`` described there. Only rank 0 prints the
    network, profiles it on a 1x3x1024x2048 input, and writes ``flops``
    (GFLOPs) and ``params`` (millions) back into the same JSON file.
    """
    args, _ = _parse_args()

    if args.load_path:
        config.load_path = args.load_path

    config.batch_size = args.batch_size
    config.image_height = args.image_height
    config.image_width = args.image_width
    config.eval_height = args.eval_height
    config.eval_width = args.eval_width
    config.Fch = args.Fch
    config.dataset_path = args.data_path
    config.save = args.save

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model_files = (glob.glob("Search/1paths/*.json") +
                   glob.glob("Search/2paths/*.json") +
                   glob.glob("Search/3paths/*.json"))

    for model_file in model_files:
        with open(model_file, 'r') as f:
            model_dict = json.load(f)

        model = Network(model_dict["ops"],
                        model_dict["paths"],
                        model_dict["downs"],
                        model_dict["widths"],
                        model_dict["lasts"],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=(args.stem_head_width,
                                         args.stem_head_width))

        # Only rank 0 profiles and rewrites the file. Other ranks would only
        # rewrite the unchanged dict and race on the same path.
        if args.local_rank != 0:
            continue

        print("net: " + str(model))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        flops = flops / 1e9
        params = params / 1e6
        model_dict['flops'] = flops
        model_dict['params'] = params
        # print() does not %-format its arguments; format explicitly.
        print("params = %fMB, FLOPs = %fGB" % (params, flops))

        with open(model_file, 'w') as f:
            json.dump(model_dict, f, cls=NpEncoder)
Example #5
0
def main():
    """Evaluate one network, described by ``config.json_file``, on Cityscapes.

    Steps:

    - Initialise a single-process NCCL group.
    - Build the ``Network`` from the JSON description.
    - Load weights with ``load_pretrain`` from ``config.model_path``.
    - Run ``test`` on the test split when ``config.is_test``; otherwise run
      ``infer`` on the val split and log the mIoU.

    Rank 0 also creates ``config.save``, logs to stdout and ``log.txt``,
    writes TensorBoard scalars, and writes ``args_text`` to ``args.yaml``.
    """
    args, args_text = _parse_args()

    # dist init
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='tcp://127.0.0.1:26442',
                                         world_size=1,
                                         rank=0)
    config.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()

    # Logging is configured BEFORE the first logging call. Module-level
    # logging.info() installs a default WARNING-level handler when the root
    # logger has none; that would turn basicConfig() below into a no-op and
    # silently drop every INFO record.
    rank_msg = "rank: {} world_size: {}".format(args.local_rank,
                                                args.world_size)
    if args.local_rank == 0:
        create_exp_dir(config.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(config.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info(rank_msg)
        logging.info("args = %s", str(config))
    else:
        logger = None
        logging.info(rank_msg)

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # data loader ###########################
    # Only the training list differs between modes: with config.is_test,
    # config.train_eval_source replaces config.train_source.
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': (config.train_eval_source
                         if config.is_test else config.train_source),
        'eval_source': config.eval_source,
        'test_source': config.test_source,
        'down_sampling': config.down_sampling
    }

    with open(config.json_file, 'r') as f:
        model_dict = json.load(f)

    model = Network(model_dict["ops"],
                    model_dict["paths"],
                    model_dict["downs"],
                    model_dict["widths"],
                    model_dict["lasts"],
                    num_classes=config.num_classes,
                    layers=config.layers,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    stem_head_width=config.stem_head_width)

    if args.local_rank == 0:
        logging.info("net: " + str(model))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        with open(os.path.join(config.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    model = model.cuda()
    init_weight(model,
                nn.init.kaiming_normal_,
                torch.nn.BatchNorm2d,
                config.bn_eps,
                config.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')

    model = load_pretrain(model, config.model_path)

    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             out_idx=0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False,
                             show_prediction=False)
    tester = SegTester(Cityscapes(data_setting, 'test', None),
                       config.num_classes,
                       config.image_mean,
                       config.image_std,
                       model,
                       config.eval_scale_array,
                       config.eval_flip,
                       0,
                       out_idx=0,
                       config=config,
                       verbose=False,
                       save_path=None,
                       show_prediction=False)

    # Cityscapes ###########################################
    logging.info(config.model_path)
    logging.info(config.save)
    with torch.no_grad():
        if config.is_test:
            print("[test...]")
            test(0, model, tester, logger)
        else:
            print("[validation...]")
            valid_mIoU = infer(model, evaluator, logger)
            # Only rank 0 owns a SummaryWriter.
            if logger is not None:
                logger.add_scalar("mIoU/val", valid_mIoU, 0)
            logging.info("Model valid_mIoU %.3f" % (valid_mIoU))

    if logger is not None:
        logger.close()  # flush pending TensorBoard events