# Example #1
0
def load_model(model_dir):
    """Build a ``BASNet(3, 1)`` network, load its weights and return it in eval mode.

    Args:
        model_dir: Path of the saved ``state_dict`` file passed to ``torch.load``
            (despite the name, a file path rather than a directory).

    Returns:
        The network in eval mode, moved to the default CUDA device when CUDA is
        available.
    """
    use_cuda = torch.cuda.is_available()
    net = BASNet(3, 1)
    # Without map_location, a checkpoint saved from GPU tensors cannot be
    # loaded on a CPU-only machine; map to CPU in that case.
    state_dict = torch.load(model_dir, map_location=None if use_cuda else 'cpu')
    net.load_state_dict(state_dict)
    if use_cuda:
        net.cuda()
    net.eval()
    return net
# Example #2
0
def main(image_dir='/media/markytools/New Volume/Courses/EE298CompVis/finalproject/datasets/DUTS/DUTS-TE/DUTS-TE-Image/',
         prediction_dir='./predictionout/',
         model_dir='./saved_models/basnet_bsi_dataaugandarchi/basnet_bsi_epoch_207_itr_272000_train_9.345837_tar_1.082428.pth'):
    """Predict a saliency map for every ``*.jpg`` in *image_dir* and save it.

    Args:
        image_dir: Directory containing the input images.
        prediction_dir: Output directory handed to ``save_output``; created if
            it does not exist yet.
        model_dir: Path of the BASNet ``state_dict`` checkpoint file to load.
    """
    import os  # local import: the module's import block is not visible here

    img_name_list = glob.glob(os.path.join(image_dir, '*.jpg'))
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(256),
                                      ToTensorLab(flag=0)]),
        category="test")
    # shuffle=False keeps batch i aligned with img_name_list[i], which is used
    # below to name each output.
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    print("...load BASNet...")
    use_cuda = torch.cuda.is_available()
    net = BASNet(3, 1)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(model_dir, map_location=None if use_cuda else 'cpu'))
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    # Inference only: no_grad avoids building and holding the autograd graph.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            img_path = img_name_list[i_test]
            print("inferencing:", os.path.basename(img_path))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            # This network variant returns 15 tensors (d1..d8 followed by seven
            # *_struct outputs); only the first one is normalised and saved.
            outputs = net(inputs_test)
            pred = normPRED(outputs[0][:, 0, :, :])

            # save results to the prediction folder
            save_output(img_path, pred, prediction_dir)

            # Release every output, including the struct ones.
            del outputs, pred
# Example #3
0
def load_BASNet():
    """Load BASNet weights from the module-level ``model_dir`` into the global ``net`` once.

    On the first successful call this sets the globals ``net`` (the model, in
    eval mode, on CUDA when available) and ``BASNet_loaded = True``. Later calls
    only print a message and do not reload the weights.
    """
    global net, BASNet_loaded

    if BASNet_loaded:
        print('BASNet is already loaded')
        return

    print("Loading BASNet...")
    use_cuda = torch.cuda.is_available()
    model = BASNet(3, 1)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_dir, map_location=None if use_cuda else 'cpu'))
    if use_cuda:
        model.cuda()
    model.eval()

    # Publish the model and mark it loaded only after loading succeeded; the
    # flag was previously never set, so every call reloaded the weights.
    net = model
    BASNet_loaded = True
def run_prediction(files):
    """Run BASNet saliency prediction on every image path in *files*.

    The model is rebuilt from the module-level ``model_dir`` checkpoint on every
    call. Each normalised prediction is written with ``save_output`` into the
    module-level ``prediction_dir``, which is created if missing.

    Args:
        files: Sequence of image file paths.
    """
    import os  # local import: the module's import block is not visible here

    img_name_list = files
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(256),
                                      ToTensorLab(flag=0)]))
    # shuffle=False keeps batch i aligned with img_name_list[i].
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    print("...load BASNet...")
    use_cuda = torch.cuda.is_available()
    net = BASNet(3, 1)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(model_dir, map_location=None if use_cuda else 'cpu'))
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    # Inference only: no_grad avoids building and holding the autograd graph.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            img_path = img_name_list[i_test]
            print("inferencing:", os.path.basename(img_path))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7, d8 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            # save results to the prediction folder
            save_output(img_path, pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7, d8
# Example #5
0
# Top-level inference script. Assumes `image_dir` and `model_dir` are defined
# earlier in the module (not visible here) — TODO confirm.
img_name_list = glob.glob(image_dir + '*.jpg')

# --------- 2. dataloader ---------
# 1. dataload
test_salobj_dataset = SalObjDataset(
    img_name_list=img_name_list,
    lbl_name_list=[],
    transform=transforms.Compose([RescaleT(256), ToTensorLab(flag=0)]))
# shuffle=False keeps batch i aligned with img_name_list[i].
test_salobj_dataloader = DataLoader(
    test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)

# --------- 3. model define ---------
print("...load BASNet...")
net = BASNet(3, 1)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(
    model_dir, map_location=None if torch.cuda.is_available() else 'cpu'))
if torch.cuda.is_available():
    net.cuda()
net.eval()

# --------- 4. inference for each image ---------
# Inference only: no_grad avoids building the autograd graph for each forward pass.
with torch.no_grad():
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split("/")[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7, d8 = net(inputs_test)
# Example #6
0
def train():
    """Train the structure-aware BASNet variant on DUTS-TR, evaluating on DUTS-TE.

    Dataset locations, hyper-parameters and the checkpoint directory are
    hard-coded below. After every epoch the 15 per-output training losses are
    plotted to Visdom. Whenever at least ``test_increments`` iterations have
    passed since the previous evaluation, the model is evaluated on DUTS-TE
    (per-output losses, MAE, max F-measure, relaxed F-measures) and a checkpoint
    is written to ``model_dir``. That checkpoint is also copied into a ``best*/``
    sub-folder for every metric that improved.
    """
    if os.name == 'nt':
        data_dir = 'C:/Users/marky/Documents/Courses/saliency/datasets/DUTS/'
    else:
        data_dir = os.getenv(
            "HOME") + '/Documents/Courses/EE298-CV/finalproj/datasets/DUTS/'
    tra_image_dir = 'DUTS-TR/DUTS-TR-Image/'
    tra_label_dir = 'DUTS-TR/DUTS-TR-Mask/'
    test_image_dir = 'DUTS-TE/DUTS-TE-Image/'
    test_label_dir = 'DUTS-TE/DUTS-TE-Mask/'

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = "./saved_models/basnet_bsi_aug/"
    resume_train = False
    resume_model_path = model_dir + "basnet_bsi_epoch_81_itr_106839_train_1.511335_tar_0.098392.pth"
    last_epoch = 1
    epoch_num = 100000
    batch_size_train = 8
    # Must stay 1: the evaluation loop looks up image/mask paths by batch index.
    batch_size_val = 1
    enableInpaintAug = False
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # (Visdom key, title) of the 15 loss series, in the same order as the
    # AverageMeter arguments muti_bce_loss_fusion receives: loss0..loss7 and
    # then structloss1..structloss7.
    loss_series = ([('loss%d' % k, 'Main Loss %d' % k) for k in range(8)] +
                   [('structloss%d' % k, 'Struct Loss %d' % k) for k in range(1, 8)])

    def label_paths(img_paths, label_dir):
        """Map each image path to its mask: same name up to the last '.', in label_dir."""
        # os.path.basename (not split("/")) so the backslash-separated paths
        # glob returns on Windows are handled as well.
        return [data_dir + label_dir + os.path.splitext(os.path.basename(p))[0] + label_ext
                for p in img_paths]

    # ------- 5. training process --------
    print("---start training...")
    test_increments = 6250
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0  # training iterations since the running losses were reset
    next_test = 0  # evaluate after the first epoch
    visdom_tab_title = "StructArchWithoutStructImgs(WithHFlip)"

    tra_pattern = data_dir + tra_image_dir + '*' + image_ext
    tra_img_name_list = glob.glob(tra_pattern)
    print("data_dir + tra_image_dir + '*' + image_ext: ", tra_pattern)
    test_pattern = data_dir + test_image_dir + '*' + image_ext
    test_img_name_list = glob.glob(test_pattern)
    print("data_dir + test_image_dir + '*' + image_ext: ", test_pattern)

    tra_lbl_name_list = label_paths(tra_img_name_list, tra_label_dir)
    test_lbl_name_list = label_paths(test_img_name_list, test_label_dir)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    print("---")
    print("test images: ", len(test_img_name_list))
    print("test labels: ", len(test_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)
    test_num = len(test_img_name_list)
    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(256),
                                       RandomCrop(224),
                                       ToTensorLab(flag=0)
                                   ]),
                                   category="train",
                                   enableInpaintAug=enableInpaintAug)
    # No RandomCrop at test time: each prediction is resized back to the full
    # image and compared pixel-wise with the full ground-truth mask.
    salobj_dataset_test = SalObjDataset(img_name_list=test_img_name_list,
                                        lbl_name_list=test_lbl_name_list,
                                        transform=transforms.Compose([
                                            RescaleT(256),
                                            ToTensorLab(flag=0)
                                        ]),
                                        category="test",
                                        enableInpaintAug=enableInpaintAug)
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)
    # shuffle=False: batch i must be test_img_name_list[i] for the metric lookup.
    salobj_dataloader_test = DataLoader(salobj_dataset_test,
                                        batch_size=batch_size_val,
                                        shuffle=False,
                                        num_workers=1)

    # ------- 3. define model --------
    net = BASNet(3, 1)
    if resume_train:
        net.load_state_dict(torch.load(resume_model_path, map_location=device))
    net.to(device)

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    plotter = VisdomLinePlotter(env_name=visdom_tab_title)

    best_ave_mae = 100000
    best_max_fmeasure = 0
    best_relaxed_fmeasure = 0
    best_ave_maxf = 0
    best_own_RelaxedFmeasure = 0
    for epoch in range(last_epoch - 1, epoch_num):
        ### Train network
        train_meters = [AverageMeter() for _ in loss_series]
        net.train()
        for i, data in enumerate(salobj_dataloader):
            ite_num += 1
            ite_num4val += 1
            inputs_v = data['image'].type(torch.FloatTensor).to(device)
            labels_v = data['label'].type(torch.FloatTensor).to(device)

            optimizer.zero_grad()

            # forward + backward + optimize; the net returns 15 tensors
            # (d0..d7 followed by d1_struct..d7_struct).
            outputs = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(*outputs, labels_v, *train_meters)
            loss.backward()
            optimizer.step()

            # .item() accumulates plain floats instead of device tensors.
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del outputs, loss2, loss

            print(
                "[train epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))
        for (key, title), meter in zip(loss_series, train_meters):
            plotter.plot(key, 'train', title, epoch + 1, float(meter.avg))

        ### Validate model
        print("---Evaluate model---")
        if ite_num < next_test:
            continue  # evaluate at most once every test_increments iterations
        next_test = ite_num + test_increments
        net.eval()

        test_meters = [AverageMeter() for _ in loss_series]
        average_mae = AverageMeter()
        average_maxf = AverageMeter()
        average_relaxedf = AverageMeter()
        average_own_RelaxedFMeasure = AverageMeter()
        max_epoch_fmeasure = 0
        test_running_loss = 0.0
        test_running_tar_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(salobj_dataloader_test):
                inputs_v = data['image'].type(torch.FloatTensor).to(device)
                labels_v = data['label'].type(torch.FloatTensor).to(device)
                outputs = net(inputs_v)

                pred = normPRED(outputs[0][:, 0, :, :]).squeeze()
                predict_np = pred.cpu().numpy()
                im = Image.fromarray(predict_np * 255).convert('RGB')
                img_name = test_img_name_list[i]
                image = cv2.imread(img_name)
                # Resize the prediction to the original image size and convert
                # to a single-channel 2D array comparable with the GT mask.
                imo = im.resize((image.shape[1], image.shape[0]),
                                resample=Image.BILINEAR).convert("L")
                resizedImg_np = np.array(imo)
                gt_img = np.array(Image.open(test_lbl_name_list[i]).convert("L"))

                ### Compute metrics
                average_mae.update(getMAE(gt_img, resizedImg_np), 1)
                precision, recall = getPRCurve(gt_img, resizedImg_np)
                result_maxfmeasure = getMaxFMeasure(precision, recall).mean()
                average_maxf.update(result_maxfmeasure, 1)
                if result_maxfmeasure > max_epoch_fmeasure:
                    max_epoch_fmeasure = result_maxfmeasure
                average_relaxedf.update(getRelaxedFMeasure(gt_img, resizedImg_np), 1)
                average_own_RelaxedFMeasure.update(
                    own_RelaxedFMeasure(gt_img, resizedImg_np), 1)

                loss2, loss = muti_bce_loss_fusion(*outputs, labels_v, *test_meters)
                test_running_loss += loss.item()
                test_running_tar_loss += loss2.item()
                del outputs, loss2, loss
                print(
                    "[test epoch: %3d/%3d, batch: %5d/%5d, ite: %d] test loss: %3f, tar: %3f "
                    % (epoch + 1, epoch_num, (i + 1) * batch_size_val,
                       test_num, ite_num, test_running_loss / (i + 1),
                       test_running_tar_loss / (i + 1)))

        # Build the checkpoint name (also reused for the best-model copies)
        # BEFORE resetting the running losses, so the copies carry real values.
        n_avg = max(ite_num4val, 1)  # guard: no training batches since last reset
        stem = "basnet_bsi_epoch_%d_itr_%d_train_%3f_tar_%3f" % (
            epoch + 1, ite_num, running_loss / n_avg, running_tar_loss / n_avg)
        os.makedirs(model_dir, exist_ok=True)
        model_name = model_dir + stem + ".pth"
        torch.save(net.state_dict(), model_name)
        running_loss = 0.0
        running_tar_loss = 0.0
        ite_num4val = 0
        net.train()  # resume train

        # (sub-folder, filename tag, new best value) for every improved metric.
        improved = []
        if average_mae.avg < best_ave_mae:
            best_ave_mae = average_mae.avg
            improved.append(("bestMAE", "mae", best_ave_mae))
        if max_epoch_fmeasure > best_max_fmeasure:
            best_max_fmeasure = max_epoch_fmeasure
            improved.append(("bestEpochMaxF", "maxfmeas", best_max_fmeasure))
        if average_maxf.avg > best_ave_maxf:
            best_ave_maxf = average_maxf.avg
            improved.append(("bestAveMaxF", "avemfmeas", best_ave_maxf))
        if average_relaxedf.avg > best_relaxed_fmeasure:
            best_relaxed_fmeasure = average_relaxedf.avg
            improved.append(("bestAveRelaxF", "averelaxfmeas", best_relaxed_fmeasure))
        if average_own_RelaxedFMeasure.avg > best_own_RelaxedFmeasure:
            best_own_RelaxedFmeasure = average_own_RelaxedFMeasure.avg
            improved.append(("bestOwnRelaxedF", "averelaxfmeas", best_own_RelaxedFmeasure))
        for subdir, tag, value in improved:
            best_dir = model_dir + subdir
            os.makedirs(best_dir, exist_ok=True)
            copyfile(model_name, "%s/%s_%s_%3f.pth" % (best_dir, stem, tag, value))

        for (key, title), meter in zip(loss_series, test_meters):
            plotter.plot(key, 'test', title, epoch + 1, float(meter.avg))
        plotter.plot('mae', 'test', 'Average Epoch MAE', epoch + 1,
                     float(average_mae.avg))
        plotter.plot('max_maxf', 'test', 'Max Max Epoch F-Measure',
                     epoch + 1, float(max_epoch_fmeasure))
        plotter.plot('ave_maxf', 'test', 'Average Max F-Measure',
                     epoch + 1, float(average_maxf.avg))
        plotter.plot('ave_relaxedf', 'test', 'Average Relaxed F-Measure',
                     epoch + 1, float(average_relaxedf.avg))
        plotter.plot('own_RelaxedFMeasure', 'test',
                     'Own Average Relaxed F-Measure', epoch + 1,
                     float(average_own_RelaxedFMeasure.avg))
    print('-------------Congratulations! Training Done!!!-------------')
def train():
    """Evaluate a pretrained BASNet checkpoint on DUTS-TE and print summary metrics.

    Despite the name (and the log messages), no optimisation is performed. The
    weights at ``resume_model_path`` are loaded and every test image is run
    through the network once. Then the average MAE, the best per-image max
    F-measure, the average max F-measure and the average relaxed F-measure are
    printed. The 8 per-output test losses are accumulated into AverageMeters but
    not reported.
    """
    import os  # local import: the module's import block is not visible here

    data_dir = '/media/markytools/New Volume/Courses/EE298CompVis/finalproject/datasets/'
    test_image_dir = 'DUTS/DUTS-TE/DUTS-TE-Image/'
    test_label_dir = 'DUTS/DUTS-TE/DUTS-TE-Mask/'
    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = "../saved_models/"
    resume_train = True
    resume_model_path = model_dir + "basnet-original.pth"
    # Must stay 1: the metric loop looks up image/mask paths by batch index.
    batch_size_val = 1
    enableInpaintAug = False
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print("---start training...")

    test_pattern = data_dir + test_image_dir + '*' + image_ext
    test_img_name_list = glob.glob(test_pattern)
    print("data_dir + test_image_dir + '*' + image_ext: ", test_pattern)

    # Mask path = same file name up to the last '.', in the mask folder.
    test_lbl_name_list = [
        data_dir + test_label_dir + os.path.splitext(os.path.basename(p))[0] + label_ext
        for p in test_img_name_list
    ]

    print("---")
    print("test images: ", len(test_img_name_list))
    print("test labels: ", len(test_lbl_name_list))
    print("---")

    # Sanity check: map each mask path back to an image path and report misses.
    img_name_set = set(test_img_name_list)
    for test_lbl in test_lbl_name_list:
        test_jpg = test_lbl.replace("png", "jpg").replace("Mask", "Image")
        if test_jpg not in img_name_set:
            print("test_lbl not in label: ", test_lbl)

    # No RandomCrop: predictions are resized back to the full image and compared
    # pixel-wise with the full mask.
    salobj_dataset_test = SalObjDataset(
        img_name_list=test_img_name_list,
        lbl_name_list=test_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(256),
            ToTensorLab(flag=0)]),
        category="test",
        enableInpaintAug=enableInpaintAug)
    # shuffle=False: batch i must correspond to test_img_name_list[i].
    salobj_dataloader_test = DataLoader(salobj_dataset_test, batch_size=batch_size_val,
                                        shuffle=False, num_workers=1)

    # ------- define model --------
    net = BASNet(3, 1)
    if resume_train:
        net.load_state_dict(torch.load(resume_model_path, map_location=device))
    net.to(device)

    # One meter per network output (d0..d7), as muti_bce_loss_fusion expects.
    test_meters = [AverageMeter() for _ in range(8)]
    average_mae = AverageMeter()
    average_maxf = AverageMeter()
    average_relaxedf = AverageMeter()

    ### Validate model
    print("---Evaluate model---")
    net.eval()
    max_epoch_fmeasure = 0
    with torch.no_grad():
        for i, data in enumerate(salobj_dataloader_test):
            inputs_v = data['image'].type(torch.FloatTensor).to(device)
            labels_v = data['label'].type(torch.FloatTensor).to(device)
            outputs = net(inputs_v)  # d0..d7

            pred = normPRED(outputs[0][:, 0, :, :]).squeeze()
            predict_np = pred.cpu().numpy()
            im = Image.fromarray(predict_np * 255).convert('RGB')
            image = cv2.imread(test_img_name_list[i])
            # Resize to the original image size, then to a 1-channel 2D array.
            imo = im.resize((image.shape[1], image.shape[0]),
                            resample=Image.BILINEAR).convert("L")
            resizedImg_np = np.array(imo)
            gt_img = np.array(Image.open(test_lbl_name_list[i]).convert("L"))

            ### Compute metrics
            average_mae.update(getMAE(gt_img, resizedImg_np), 1)
            precision, recall = getPRCurve(gt_img, resizedImg_np)
            result_maxfmeasure = getMaxFMeasure(precision, recall).mean()
            average_maxf.update(result_maxfmeasure, 1)
            if result_maxfmeasure > max_epoch_fmeasure:
                max_epoch_fmeasure = result_maxfmeasure
            average_relaxedf.update(getRelaxedFMeasure(gt_img, resizedImg_np), 1)

            loss2, loss = muti_bce_loss_fusion(*outputs, labels_v, *test_meters)
            del outputs, loss2, loss

    print("Average Epoch MAE: ", average_mae.avg)
    print("Max Max Epoch F-Measure: ", max_epoch_fmeasure)
    print("Average Max F-Measure: ", average_maxf.avg)
    print("Average Relaxed F-Measure: ", average_relaxedf.avg)

    print('-------------Congratulations! Training Done!!!-------------')