# Exemplo n.º 1
# 0
def main():
    """Train the Siamese change-detection network with contrastive mask losses.

    Relies on module-level globals not visible in this chunk: `cfg` (paths and
    hyper-parameters), `trans`/`dates`/`Data` (transforms, dataset, DataLoader),
    `util`, `ls` (loss module), the `resume` flag, `check_dir`, `validate`, and
    the per-layer learning-rate helper functions. Requires CUDA; written
    against legacy PyTorch (`Variable`, `.data[0]`).
    """

    #########  configs ###########
    best_metric = 0  # best validation score so far; gates the model_best.pth copy
    ######  load datasets ########
    train_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    val_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    train_data = dates.Dataset(cfg.TRAIN_DATA_PATH,
                               cfg.TRAIN_LABEL_PATH,
                               cfg.TRAIN_TXT_PATH,
                               'train',
                               transform=True,
                               transform_med=train_transform_det)
    train_loader = Data.DataLoader(train_data,
                                   batch_size=cfg.BATCH_SIZE,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)
    val_data = dates.Dataset(cfg.VAL_DATA_PATH,
                             cfg.VAL_LABEL_PATH,
                             cfg.VAL_TXT_PATH,
                             'val',
                             transform=True,
                             transform_med=val_transform_det)
    val_loader = Data.DataLoader(val_data,
                                 batch_size=cfg.BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
    ######  build  models ########
    base_seg_model = 'deeplab'  # hard-coded choice; the 'else' branch is dead here
    if base_seg_model == 'deeplab':
        import model.siameseNet.deeplab_v2 as models
        pretrain_deeplab_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                             'deeplab_v2_voc12.pth')
        model = models.SiameseNet(norm_flag='l2')
        if resume:  # `resume` is a module-level flag not visible in this chunk
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            deeplab_pretrain_model = torch.load(pretrain_deeplab_path)
            model.init_parameters_from_deeplab(deeplab_pretrain_model)
    else:
        import model.siameseNet.fcn32s_tiny as models
        pretrain_vgg_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                         'vgg16_from_caffe.pth')
        model = models.SiameseNet(distance_flag='softmax')
        if resume:
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            vgg_pretrain_model = util.load_pretrain_model(pretrain_vgg_path)
            model.init_parameters(vgg_pretrain_model)

    model = model.cuda()
    MaskLoss = ls.ConstractiveMaskLoss()
    ab_test_dir = os.path.join(cfg.SAVE_PRED_PATH, 'contrastive_loss')
    check_dir(ab_test_dir)
    save_change_map_dir = os.path.join(ab_test_dir, 'changemaps/')
    save_valid_dir = os.path.join(ab_test_dir, 'valid_imgs')
    save_roc_dir = os.path.join(ab_test_dir, 'roc')
    check_dir(save_change_map_dir), check_dir(save_valid_dir), check_dir(
        save_roc_dir)
    #########
    ######### optimizer ##########
    ######## how to set different learning rate for differernt layers #########
    # Four parameter groups: base/2x LR (backbone weights/biases, biases get no
    # weight decay), 10x/20x LR (new head weights/biases) -- DeepLab convention.
    optimizer = torch.optim.SGD(
        [{
            'params': set_base_learning_rate_for_multi_layer(model),
            'lr': cfg.INIT_LEARNING_RATE
        }, {
            'params': set_2x_learning_rate_for_multi_layer(model),
            'lr': 2 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }, {
            'params': set_10x_learning_rate_for_multi_layer(model),
            'lr': 10 * cfg.INIT_LEARNING_RATE
        }, {
            'params': set_20x_learning_rate_for_multi_layer(model),
            'lr': 20 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }],
        lr=cfg.INIT_LEARNING_RATE,
        momentum=cfg.MOMENTUM,
        weight_decay=cfg.DECAY)
    ######## iter img_label pairs ###########
    loss_total = 0  # running sum of batch losses (accumulated but never reported)
    for epoch in range(100):
        for batch_idx, batch in enumerate(train_loader):
            step = epoch * len(train_loader) + batch_idx
            util.adjust_learning_rate(cfg.INIT_LEARNING_RATE, optimizer, step)
            model.train()
            img1_idx, img2_idx, label_idx, filename, height, width = batch
            img1, img2, label = Variable(img1_idx.cuda()), Variable(
                img2_idx.cuda()), Variable(label_idx.cuda())
            # Siamese forward: paired (t0, t1) features at three depths.
            out_conv5, out_fc, out_embedding = model(img1, img2)
            out_conv5_t0, out_conv5_t1 = out_conv5
            out_fc_t0, out_fc_t1 = out_fc
            out_embedding_t0, out_embedding_t1 = out_embedding
            # Resize the change label to each feature map's spatial size.
            label_rz_conv5 = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_conv5_t0.data.cpu().numpy().shape[2:]).cuda())
            label_rz_fc = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_fc_t0.data.cpu().numpy().shape[2:]).cuda())
            label_rz_embedding = Variable(
                util.resize_label(
                    label.data.cpu().numpy(),
                    size=out_embedding_t0.data.cpu().numpy().shape[2:]).cuda())
            # Contrastive mask loss at each depth; total is their plain sum.
            contractive_loss_conv5 = MaskLoss(out_conv5_t0, out_conv5_t1,
                                              label_rz_conv5)
            contractive_loss_fc = MaskLoss(out_fc_t0, out_fc_t1, label_rz_fc)
            contractive_loss_embedding = MaskLoss(out_embedding_t0,
                                                  out_embedding_t1,
                                                  label_rz_embedding)
            loss = contractive_loss_conv5 + contractive_loss_fc + contractive_loss_embedding
            loss_total += loss.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch_idx) % 20 == 0:
                # NOTE(review): format reads "Epoch [%d/%d]" but the values are
                # (epoch, batch_idx); `.data[0]` is legacy PyTorch (<0.4).
                print(
                    "Epoch [%d/%d] Loss: %.4f Mask_Loss_conv5: %.4f Mask_Loss_fc: %.4f "
                    "Mask_Loss_embedding: %.4f" %
                    (epoch, batch_idx, loss.data[0],
                     contractive_loss_conv5.data[0],
                     contractive_loss_fc.data[0],
                     contractive_loss_embedding.data[0]))
            if (batch_idx) % 1000 == 0:
                # Periodic mid-epoch validation; on improvement, save the epoch
                # checkpoint and copy it to model_best.pth.
                model.eval()
                current_metric = validate(model, val_loader, epoch,
                                          save_change_map_dir, save_roc_dir)
                if current_metric > best_metric:
                    torch.save({'state_dict': model.state_dict()},
                               os.path.join(ab_test_dir,
                                            'model' + str(epoch) + '.pth'))
                    shutil.copy(
                        os.path.join(ab_test_dir,
                                     'model' + str(epoch) + '.pth'),
                        os.path.join(ab_test_dir, 'model_best.pth'))
                    best_metric = current_metric
        # End-of-epoch validation with the same best-checkpoint bookkeeping.
        current_metric = validate(model, val_loader, epoch,
                                  save_change_map_dir, save_roc_dir)
        if current_metric > best_metric:
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(ab_test_dir,
                                    'model' + str(epoch) + '.pth'))
            shutil.copy(
                os.path.join(ab_test_dir, 'model' + str(epoch) + '.pth'),
                os.path.join(ab_test_dir, 'model_best.pth'))
            best_metric = current_metric
        if epoch % 5 == 0:
            # Unconditional snapshot every 5 epochs regardless of the metric.
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(ab_test_dir,
                                    'model' + str(epoch) + '.pth'))
# Exemplo n.º 2
# 0
def train(config):
    """Build an ensemble of feature-extraction streams and train a fusion head.

    NOTE(review): the original function was work-in-progress pseudo-code with
    outright syntax errors (`features.(...)`, `bla bla` placeholder lines),
    several undefined names (`features`, `handles`, `handle_dims`, `epochs`,
    `data_loader`), and an `eval()` over config-provided strings. This
    revision fixes those concrete defects while keeping the genuinely
    unfinished parts explicitly marked as TODO.

    Expected `config` keys (inferred from usage -- confirm against callers):
    "merge_method", "training_stage", "models" (list of per-stream dicts),
    "data_loader" (list), "stream_weights", "train_op".
    """
    merge_method = config["merge_method"]
    training_stage = config["training_stage"]  # currently unused by this body

    model_list = []
    # Feature maps captured by forward hooks during a forward pass. Reset per
    # batch; each hooked module appends one detached tensor per forward.
    features = []

    def hook(module, inp, outp):
        # Fixed from the invalid `features.(outp.clone().detach())`: append a
        # detached copy of the hooked module's output for later consumption.
        features.append(outp.clone().detach())

    for idx, model_config in enumerate(config["models"]):

        stream = {}

        # Load the model architecture.
        model_name = model_config["model_name"]
        num_classes = model_config["num_classes"]
        if model_name.startswith("efficientnet-b"):
            model = EfficientNet.from_pretrained(model_name,
                                                 num_classes=num_classes)
        else:
            model = timm.create_model(model_name,
                                      num_classes=num_classes,
                                      pretrained=True)

        # Load the trained weights for this stream.
        ckpt_path = model_config['checkpoint_path']
        key_ckpt = model_config['key_checkpoint']
        model = load_pretrain_model(model, ckpt_path, key_ckpt)

        # Register hooks to extract intermediate feature(s).
        feature_names = model_config['stream_feature_names']
        feature_dims = model_config['stream_feature_dims']

        handles = []      # forward-hook handles (keep to allow later removal)
        handle_dims = []  # declared feature dim for each registered hook

        for fn, fd in zip(feature_names, feature_dims):
            # Resolve the dotted attribute path (e.g. "layer4.2.conv3")
            # without eval(): eval of config-provided strings is unsafe.
            target = model
            for attr in fn.split('.'):
                target = getattr(target, attr)
            ftr = target.register_forward_hook(hook)

            handles.append(ftr)
            print(ftr)
            handle_dims.append(fd)

        # Transform for image preprocessing.
        image_size = model_config['image_size']
        resize_mode = model_config['resize_mode']
        # TODO: build the preprocessing transform from image_size/resize_mode.

        # Dataloader.
        data_config = config['data_loader'][idx]
        data_loader = None  # TODO: build the per-stream DataLoader from data_config.

        stream['model'] = model
        stream['hook'] = handles
        stream['dataloader'] = data_loader
        stream['feature_dim'] = handle_dims

        model_list.append(stream)

    sw_config = config['stream_weights']

    # Input width of the fusion network = sum of all registered feature dims.
    # (Original added the whole `feature_dim` LIST to an int -- a TypeError.)
    input_tensor_shape = 0
    for stream in model_list:
        for ft_dim in stream['feature_dim']:
            input_tensor_shape += ft_dim

    # Stream (fusion) network.
    stream_net = build_stream_net(sw_config, input_tensor_shape)

    # Optimizer & loss.
    tr_config = config['train_op']
    optimizer = configure_optimizer(tr_config)
    loss_function = configure_loss(tr_config)

    global_step = 0

    # Training.
    # `epochs` was undefined in the original; read it from the train config.
    epochs = range(tr_config.get("epochs", 1))
    for epoch in epochs:

        # NOTE(review): the original iterated a single undefined `data_loader`;
        # presumably the streams share one loader (or must be iterated in
        # lockstep). Using the first stream's loader -- TODO confirm.
        for data, label in model_list[0]['dataloader']:

            inp = []
            features.clear()  # drop features captured for the previous batch

            for i in range(len(data)):
                # Fixed from `model[i](...)`: forward through stream i's model.
                y = model_list[i]['model'](data[i])

                mid = features[-1]  # most recently hooked feature for stream i
                inp.append(mid)

            # Merge the per-stream features into one fusion input.
            if merge_method == "concat":
                inp = torch.cat(inp, dim=1)  # TODO confirm the merge axis
            elif merge_method == "add":
                merged = inp[0]
                for extra in inp[1:]:
                    merged = merged + extra
                inp = merged
            else:
                pass  # fall through with the raw list -- TODO confirm intent

            out = stream_net(inp)

            optimizer.zero_grad()

            loss = loss_function(out, label)

            loss.backward()

            optimizer.step()

            global_step += 1

    # NOTE(review): `internal` below is defined but never called (dead code
    # duplicated from the contrastive-training script); kept verbatim so no
    # behavior -- or intent -- is lost.
    def internal():
        #########  configs ###########
        best_metric = 0
        ######  load datasets ########
        train_transform_det = trans.Compose([
            trans.Scale(cfg.TRANSFROM_SCALES),
        ])
        val_transform_det = trans.Compose([
            trans.Scale(cfg.TRANSFROM_SCALES),
        ])
        train_data = dates.Dataset(cfg.TRAIN_DATA_PATH,
                                   cfg.TRAIN_LABEL_PATH,
                                   cfg.TRAIN_TXT_PATH,
                                   'train',
                                   transform=True,
                                   transform_med=train_transform_det)
        train_loader = Data.DataLoader(train_data,
                                       batch_size=cfg.BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=4,
                                       pin_memory=True)
        val_data = dates.Dataset(cfg.VAL_DATA_PATH,
                                 cfg.VAL_LABEL_PATH,
                                 cfg.VAL_TXT_PATH,
                                 'val',
                                 transform=True,
                                 transform_med=val_transform_det)
        val_loader = Data.DataLoader(val_data,
                                     batch_size=cfg.BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=4,
                                     pin_memory=True)
        ######  build  models ########
        base_seg_model = 'deeplab'
        if base_seg_model == 'deeplab':
            import model.siameseNet.deeplab_v2 as models
            pretrain_deeplab_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                                 'deeplab_v2_voc12.pth')
            model = models.SiameseNet(norm_flag='l2')
            if resume:
                checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
                model.load_state_dict(checkpoint['state_dict'])
                print('resume success')
            else:
                deeplab_pretrain_model = torch.load(pretrain_deeplab_path)
                model.init_parameters_from_deeplab(deeplab_pretrain_model)
        else:
            import model.siameseNet.fcn32s_tiny as models
            pretrain_vgg_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                             'vgg16_from_caffe.pth')
            model = models.SiameseNet(distance_flag='softmax')
            if resume:
                checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
                model.load_state_dict(checkpoint['state_dict'])
                print('resume success')
            else:
                vgg_pretrain_model = util.load_pretrain_model(
                    pretrain_vgg_path)
                model.init_parameters(vgg_pretrain_model)

        model = model.cuda()
        MaskLoss = ls.ConstractiveMaskLoss()
        ab_test_dir = os.path.join(cfg.SAVE_PRED_PATH, 'contrastive_loss')
        check_dir(ab_test_dir)
        save_change_map_dir = os.path.join(ab_test_dir, 'changemaps/')
        save_valid_dir = os.path.join(ab_test_dir, 'valid_imgs')
        save_roc_dir = os.path.join(ab_test_dir, 'roc')
        check_dir(save_change_map_dir), check_dir(save_valid_dir), check_dir(
            save_roc_dir)
        #########
        ######### optimizer ##########
        ######## how to set different learning rate for differernt layers #########
        optimizer = torch.optim.SGD(
            [{
                'params': set_base_learning_rate_for_multi_layer(model),
                'lr': cfg.INIT_LEARNING_RATE
            }, {
                'params': set_2x_learning_rate_for_multi_layer(model),
                'lr': 2 * cfg.INIT_LEARNING_RATE,
                'weight_decay': 0
            }, {
                'params': set_10x_learning_rate_for_multi_layer(model),
                'lr': 10 * cfg.INIT_LEARNING_RATE
            }, {
                'params': set_20x_learning_rate_for_multi_layer(model),
                'lr': 20 * cfg.INIT_LEARNING_RATE,
                'weight_decay': 0
            }],
            lr=cfg.INIT_LEARNING_RATE,
            momentum=cfg.MOMENTUM,
            weight_decay=cfg.DECAY)
        ######## iter img_label pairs ###########
        loss_total = 0
# Exemplo n.º 4
# 0
def main():
    """Train the fusion Siamese change-detection network.

    Unlike the pure-contrastive variant, this model also emits a segmentation
    prediction: total loss = cross-entropy seg loss + weighted contrastive
    mask losses at the conv5 and fc depths. Relies on module-level globals not
    visible in this chunk: `cfg`, `trans`, `dates`, `Data`, `util`, `rz`,
    `ls`, the `resume` flag, `check_dir`, `validate`, and the per-layer
    learning-rate helpers. Requires CUDA; written against legacy PyTorch
    (`Variable`, `.data[0]`).
    """

    #########  configs ###########
    best_metric = 0  # best validation score so far; gates the model_best.pth copy
    ######  load datasets ########
    train_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    val_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    train_data = dates.Dataset(cfg.TRAIN_DATA_PATH,
                               cfg.TRAIN_LABEL_PATH,
                               cfg.TRAIN_TXT_PATH,
                               'train',
                               transform=True,
                               transform_med=train_transform_det)
    train_loader = Data.DataLoader(train_data,
                                   batch_size=cfg.BATCH_SIZE,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)
    val_data = dates.Dataset(cfg.VAL_DATA_PATH,
                             cfg.VAL_LABEL_PATH,
                             cfg.VAL_TXT_PATH,
                             'val',
                             transform=True,
                             transform_med=val_transform_det)
    val_loader = Data.DataLoader(val_data,
                                 batch_size=cfg.BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
    ######  build  models ########
    ### set transition=True gain better performance ####
    base_seg_model = 'deeplab'  # hard-coded choice; the 'else' branch is dead here
    if base_seg_model == 'deeplab':
        import model.siameseNet.deeplab_v2_fusion as models
        pretrain_deeplab_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                             'deeplab_v2_voc12.pth')
        model = models.SiameseNet(class_number=2, norm_flag='exp')
        if resume:  # `resume` is a module-level flag not visible in this chunk
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            deeplab_pretrain_model = torch.load(pretrain_deeplab_path)
            #deeplab_pretrain_model = util.load_deeplab_pretrain_model(pretrain_deeplab_path)
            model.init_parameters_from_deeplab(deeplab_pretrain_model)
    else:
        import model.later_fusion.fcn32s as models
        pretrain_vgg_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                         'vgg16_from_caffe.pth')
        model = models.fcn32s(class_number=2, transition=True)
        if resume:
            checkpoint = torch.load(cfg.best_ckpt_dir)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            vgg_pretrain_model = util.load_pretrain_model(pretrain_vgg_path)
            model.init_parameters(vgg_pretrain_model)

    model = model.cuda()
    MaskLoss = ls.ConstractiveMaskLoss(thresh_flag=True)
    save_training_weights_dir = os.path.join(cfg.SAVE_PRED_PATH,
                                             'weights_visual')
    check_dir(save_training_weights_dir)
    save_spatial_att_dir = os.path.join(cfg.SAVE_PRED_PATH,
                                        'various_spatial_att/')
    check_dir(save_spatial_att_dir)
    save_valid_dir = os.path.join(cfg.SAVE_PRED_PATH, 'various_valid_imgs')
    check_dir(save_valid_dir)
    save_roc_dir = os.path.join(cfg.SAVE_PRED_PATH, 'ROC')
    check_dir(save_roc_dir)

    #########
    ######### optimizer ##########
    ######## how to set different learning rate for differern layer #########
    # Four parameter groups: base/2x LR (backbone weights/biases, biases get no
    # weight decay), 10x/20x LR (new head weights/biases) -- DeepLab convention.
    optimizer = torch.optim.SGD(
        [{
            'params': set_base_learning_rate_for_multi_layer(model),
            'lr': cfg.INIT_LEARNING_RATE
        }, {
            'params': set_2x_learning_rate_for_multi_layer(model),
            'lr': 2 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }, {
            'params': set_10x_learning_rate_for_multi_layer(model),
            'lr': 10 * cfg.INIT_LEARNING_RATE
        }, {
            'params': set_20x_learning_rate_for_multi_layer(model),
            'lr': 20 * cfg.INIT_LEARNING_RATE,
            'weight_decay': 0
        }],
        lr=cfg.INIT_LEARNING_RATE,
        momentum=cfg.MOMENTUM,
        weight_decay=cfg.DECAY)

    ######## iter img_label pairs ###########

    loss_total = 0  # running sum of batch losses (accumulated but never reported)
    for epoch in range(100):
        for batch_idx, batch in enumerate(train_loader):

            step = epoch * len(train_loader) + batch_idx
            util.adjust_learning_rate(cfg.INIT_LEARNING_RATE, optimizer, step)
            model.train()
            img1_idx, img2_idx, label_idx, filename, height, width = batch
            img1, img2, label = Variable(img1_idx.cuda()), Variable(
                img2_idx.cuda()), Variable(label_idx.cuda())
            # Fusion forward: segmentation logits plus paired (t0, t1)
            # features at the conv5 and fc depths.
            seg_pred, out_conv5, out_fc = model(img1, img2)
            out_conv5_t0, out_conv5_t1 = out_conv5
            out_fc_t0, out_fc_t1 = out_fc
            # Resize the change label to each feature map's spatial size.
            label_rz_conv5 = rz.resize_label(
                label.data.cpu().numpy(),
                size=out_conv5_t0.data.cpu().numpy().shape[2:])
            label_rz_fc = rz.resize_label(
                label.data.cpu().numpy(),
                size=out_fc_t0.data.cpu().numpy().shape[2:])
            label_rz_conv5 = Variable(label_rz_conv5.cuda())
            label_rz_fc = Variable(label_rz_fc.cuda())
            seg_loss = util.cross_entropy2d(seg_pred,
                                            label,
                                            size_average=False)
            constractive_loss_conv5 = MaskLoss(out_conv5_t0, out_conv5_t1,
                                               label_rz_conv5)
            constractive_loss_fc = MaskLoss(out_fc_t0, out_fc_t1, label_rz_fc)
            #constractive_loss = MaskLoss(out_conv5_t0,out_conv5_t1,label_rz_conv5) + \
            #MaskLoss(out_fc_t0,out_fc_t1,label_rz_fc)
            # Weighted combination: seg CE plus per-depth contrastive terms.
            loss = seg_loss + cfg.LOSS_PARAM_CONV * constractive_loss_conv5 + cfg.LOSS_PARAM_FC * constractive_loss_fc
            loss_total += loss.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch_idx) % 20 == 0:
                #print("Epoch [%d/%d] Loss: %.4f" % (epoch, batch_idx, loss.data[0]))
                # NOTE(review): format reads "Epoch [%d/%d]" but the values are
                # (epoch, batch_idx); `.data[0]` is legacy PyTorch (<0.4).
                print(
                    "Epoch [%d/%d] Loss: %.4f Seg_Loss: %.4f Mask_Loss_conv5: %.4f Mask_Loss_fc: %.4f"
                    % (epoch, batch_idx, loss.data[0], seg_loss.data[0],
                       constractive_loss_conv5.data[0],
                       constractive_loss_fc.data[0]))
            if (batch_idx) % 1000 == 0:
                # Periodic mid-epoch validation; on improvement, save the epoch
                # checkpoint and copy it to model_best.pth.
                model.eval()
                current_metric = validate(model, val_loader, epoch,
                                          save_valid_dir, save_spatial_att_dir,
                                          save_roc_dir)
                if current_metric > best_metric:
                    torch.save({'state_dict': model.state_dict()},
                               os.path.join(cfg.SAVE_CKPT_PATH,
                                            'model' + str(epoch) + '.pth'))
                    shutil.copy(
                        os.path.join(cfg.SAVE_CKPT_PATH,
                                     'model' + str(epoch) + '.pth'),
                        os.path.join(cfg.SAVE_CKPT_PATH, 'model_best.pth'))
                    best_metric = current_metric
        # End-of-epoch validation with the same best-checkpoint bookkeeping.
        current_metric = validate(model, val_loader, epoch, save_valid_dir,
                                  save_spatial_att_dir, save_roc_dir)
        if current_metric > best_metric:
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(cfg.SAVE_CKPT_PATH,
                                    'model' + str(epoch) + '.pth'))
            shutil.copy(
                os.path.join(cfg.SAVE_CKPT_PATH,
                             'model' + str(epoch) + '.pth'),
                os.path.join(cfg.SAVE_CKPT_PATH, 'model_best.pth'))
            best_metric = current_metric
        if epoch % 5 == 0:
            # Unconditional snapshot every 5 epochs regardless of the metric.
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(cfg.SAVE_CKPT_PATH,
                                    'model' + str(epoch) + '.pth'))
# Exemplo n.º 5
# 0
def test_main():
    """Evaluate a trained Siamese change-detection model on the TEST split.

    Loads (or resumes) the model, builds the test-split loader, and runs one
    `validate` pass. Relies on module-level globals not visible in this chunk:
    `cfg`, `trans`, `dates`, `Data`, `util`, the `resume` flag, `check_dir`,
    and `validate`. Requires CUDA.

    Fixes over the original: the first two `val_data` constructions (train and
    val splits) were dead stores immediately overwritten by the TEST-split
    dataset, the hard-coded absolute `TEST_TXT_PATH` was never read, and
    `MaskLoss`/`loss_total` were unused in this evaluation-only routine --
    all removed.
    """

    #########  configs ###########

    val_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])

    # Evaluation uses the TEST split (loaded in 'val' mode).
    val_data = dates.Dataset(cfg.TEST_DATA_PATH,
                             cfg.TEST_LABEL_PATH,
                             cfg.TEST_TXT_PATH,
                             'val',
                             transform=True,
                             transform_med=val_transform_det)

    val_loader = Data.DataLoader(val_data,
                                 batch_size=cfg.BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
    t1 = datetime.now()
    print("Model loading start at " + str(t1))
    ######  build  models ########
    base_seg_model = 'deeplab'  # hard-coded choice; the 'else' branch is dead here
    if base_seg_model == 'deeplab':
        import model.siameseNet.deeplab_v2 as models
        pretrain_deeplab_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                             'deeplab_v2_voc12.pth')
        model = models.SiameseNet(norm_flag='l2')
        if resume:  # `resume` is a module-level flag not visible in this chunk
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            deeplab_pretrain_model = torch.load(pretrain_deeplab_path)
            model.init_parameters_from_deeplab(deeplab_pretrain_model)
    else:
        import model.siameseNet.fcn32s_tiny as models
        pretrain_vgg_path = os.path.join(cfg.PRETRAIN_MODEL_PATH,
                                         'vgg16_from_caffe.pth')
        model = models.SiameseNet(distance_flag='softmax')
        if resume:
            checkpoint = torch.load(cfg.TRAINED_BEST_PERFORMANCE_CKPT)
            model.load_state_dict(checkpoint['state_dict'])
            print('resume success')
        else:
            vgg_pretrain_model = util.load_pretrain_model(pretrain_vgg_path)
            model.init_parameters(vgg_pretrain_model)

    model = model.cuda()
    print("Model loaded in " + str(datetime.now() - t1))
    # Output directories for the validation artifacts.
    ab_test_dir = os.path.join(cfg.SAVE_PRED_PATH, 'contrastive_loss')
    check_dir(ab_test_dir)
    save_change_map_dir = os.path.join(ab_test_dir, 'changemaps/')
    save_valid_dir = os.path.join(ab_test_dir, 'valid_imgs')
    save_roc_dir = os.path.join(ab_test_dir, 'roc')
    check_dir(save_change_map_dir)
    check_dir(save_valid_dir)
    check_dir(save_roc_dir)

    # Single evaluation pass -- no optimizer or loss needed here.
    for epoch in range(1):
        current_metric = validate(model, val_loader, epoch,
                                  save_change_map_dir, save_roc_dir)