Beispiel #1
0
def load_model(model_dir):
    """Build a BASNet, load trained weights into it and prepare it for inference.

    Args:
        model_dir: Path of the checkpoint file holding the saved
            ``state_dict``. Despite the name it is handed directly to
            ``torch.load``, so it must point at a file, not a directory.

    Returns:
        The BASNet instance in evaluation mode, moved to the GPU when CUDA is
        available.
    """
    net = BASNet(3, 1)
    # Deserialize onto the CPU so that a checkpoint saved from a GPU can still
    # be loaded on a CPU-only machine; the model is moved to the GPU below
    # when one is available.
    state_dict = torch.load(model_dir, map_location="cpu")
    net.load_state_dict(state_dict)
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    return net
Beispiel #2
0
def main(
    image_dir='/media/markytools/New Volume/Courses/EE298CompVis/finalproject/datasets/DUTS/DUTS-TE/DUTS-TE-Image/',
    prediction_dir='./predictionout/',
    model_dir='./saved_models/basnet_bsi_dataaugandarchi/basnet_bsi_epoch_207_itr_272000_train_9.345837_tar_1.082428.pth',
):
    """Run BASNet over every ``*.jpg`` in a folder and save the predictions.

    Calling ``main()`` without arguments uses the same paths as before; the
    keyword arguments only make those paths overridable.

    Args:
        image_dir: Folder prefix (including the trailing separator) that is
            globbed for ``*.jpg`` input images.
        prediction_dir: Output location handed to ``save_output``; it is
            created if it does not exist yet.
        model_dir: Path of the checkpoint file holding the saved
            ``state_dict``.

    Returns:
        None. Each prediction is passed to ``save_output``.
    """
    # Function-scope import: the file's import header is not visible here.
    import os

    img_name_list = glob.glob(image_dir + '*.jpg')

    # --------- 1. dataloader ---------
    # batch_size=1 and shuffle=False keep batch i aligned with
    # img_name_list[i], which the loop below relies on.
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(256),
                                      ToTensorLab(flag=0)]),
        category="test")
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 2. model define ---------
    print("...load BASNet...")
    net = BASNet(3, 1)
    # Deserialize onto the CPU so a GPU-saved checkpoint also loads on a
    # CPU-only machine; the model is moved to the GPU right after.
    net.load_state_dict(torch.load(model_dir, map_location="cpu"))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 3. inference for each image ---------
    # No gradients are needed for inference; disabling autograd avoids
    # keeping the activations of every layer alive.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            img_path = img_name_list[i_test]
            print("inferencing:", os.path.basename(img_path))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            # The network returns a tuple of outputs; only the first one is
            # used for the saved prediction.
            d1 = net(inputs_test)[0]

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            save_output(img_path, pred, prediction_dir)
Beispiel #3
0
def load_BASNet():
    """Load BASNet into the module-level ``net`` once and reuse it afterwards.

    Reads the module-level ``model_dir`` (checkpoint path) and
    ``BASNet_loaded`` (flag). On the first successful call the ready-to-use
    network is stored in the global ``net`` and ``BASNet_loaded`` is set to
    ``True``; later calls only print a notice.

    Returns:
        None.
    """
    global net, BASNet_loaded

    if BASNet_loaded:
        print('BASNet is already loaded')
        return

    print("Loading BASNet...")
    model = BASNet(3, 1)
    # Deserialize onto the CPU so a GPU-saved checkpoint also loads on a
    # CPU-only machine; the model is moved to the GPU right after.
    model.load_state_dict(torch.load(model_dir, map_location="cpu"))
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    # Publish the network and flip the flag only after loading succeeded, so
    # a failed load never leaves an untrained model behind in ``net`` and a
    # successful one is not repeated on the next call.
    net = model
    BASNet_loaded = True
Beispiel #4
0
def run_prediction(files):
    """Run BASNet on the given image files and save one prediction per file.

    Reads the module-level ``model_dir`` (checkpoint path) and
    ``prediction_dir`` (output location passed to ``save_output``).

    Args:
        files: Indexable sequence of image paths. It is passed to
            ``SalObjDataset`` and indexed by batch number, so its order
            determines which path each prediction is saved under.

    Returns:
        None. Each prediction is passed to ``save_output``.
    """
    # Function-scope import: the file's import header is not visible here.
    import os

    img_name_list = files

    # --------- 1. dataloader ---------
    # batch_size=1 and shuffle=False keep batch i aligned with
    # img_name_list[i], which the loop below relies on.
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(256),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 2. model define ---------
    print("...load BASNet...")
    net = BASNet(3, 1)
    # Deserialize onto the CPU so a GPU-saved checkpoint also loads on a
    # CPU-only machine; the model is moved to the GPU right after.
    net.load_state_dict(torch.load(model_dir, map_location="cpu"))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 3. inference for each image ---------
    # No gradients are needed for inference; disabling autograd avoids
    # keeping the activations of every layer alive.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            img_path = img_name_list[i_test]
            print("inferencing:", os.path.basename(img_path))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            # The network returns a tuple of outputs; only the first one is
            # used for the saved prediction.
            d1 = net(inputs_test)[0]

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            save_output(img_path, pred, prediction_dir)
Beispiel #5
0
# Script-level inference setup. Relies on ``image_dir`` and ``model_dir``
# being defined earlier in the module.
img_name_list = glob.glob(image_dir + '*.jpg')

# --------- 1. dataloader ---------
# batch_size=1 and shuffle=False keep batch i aligned with img_name_list[i].
test_salobj_dataset = SalObjDataset(
    img_name_list=img_name_list,
    lbl_name_list=[],
    transform=transforms.Compose([RescaleT(256), ToTensorLab(flag=0)]))
test_salobj_dataloader = DataLoader(
    test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)

# --------- 2. model define ---------
print("...load BASNet...")
net = BASNet(3, 1)
# Deserialize onto the CPU so a GPU-saved checkpoint also loads on a CPU-only
# machine; the model is moved to the GPU right after.
net.load_state_dict(torch.load(model_dir, map_location="cpu"))
if torch.cuda.is_available():
    net.cuda()
net.eval()

# --------- 3. inference for each image ---------
# NOTE(review): this loop only prepares the input tensor; it never calls
# ``net`` or saves a result. The remainder of the loop body appears to be
# missing from this snippet -- confirm against the original script.
for i_test, data_test in enumerate(test_salobj_dataloader):

    print("inferencing:", img_name_list[i_test].split("/")[-1])

    inputs_test = data_test['image']
    inputs_test = inputs_test.type(torch.FloatTensor)

    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)
def train():
    """Train BASNet on the images under ``./train_data/DUTS-TR/``.

    Images are read from ``DUTS-TR-Image/*.jpg``; the mask for each image is
    expected at ``DUTS-TR-Mask/<same stem>.png``. The network is trained with
    Adam for 100 epochs at batch size 1, on the GPU when CUDA is available
    and on the CPU otherwise. Every 2000 iterations the ``state_dict`` is
    written to ``./saved_models/basnet_bsi/`` (created if missing) and the
    running loss averages are reset.

    Returns:
        None.

    Raises:
        FileNotFoundError: If the mask file for any training image is missing.
    """
    # Function-scope import: the file's import header is not visible here.
    import os

    data_dir = './train_data/'
    tra_image_dir = 'DUTS-TR/DUTS-TR-Image/'
    tra_label_dir = 'DUTS-TR/DUTS-TR-Mask/'

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = "./saved_models/basnet_bsi/"

    epoch_num = 100
    batch_size_train = 1
    save_every = 2000  # iterations between checkpoints

    # ------- 1. collect image / mask pairs --------
    # glob returns files in arbitrary order, so two independent globs cannot
    # be relied on to line up. Derive every mask path from its image path so
    # that entry i of both lists always refers to the same sample.
    tra_img_name_list = sorted(
        glob.glob(data_dir + tra_image_dir + '*' + image_ext))
    tra_lbl_name_list = [
        data_dir + tra_label_dir +
        os.path.splitext(os.path.basename(img_path))[0] + label_ext
        for img_path in tra_img_name_list
    ]

    # Fail before training starts rather than part-way through an epoch.
    missing_labels = [p for p in tra_lbl_name_list if not os.path.isfile(p)]
    if missing_labels:
        raise FileNotFoundError(
            "%d training mask(s) not found, e.g. %s" %
            (len(missing_labels), missing_labels[0]))

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    # ------- 2. dataloader --------
    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(256),
                                       RandomCrop(224),
                                       ToTensorLab(flag=0)
                                   ]))
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)

    # ------- 3. define model --------
    # Use the GPU when present instead of calling .cuda() unconditionally,
    # which raises on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BASNet(3, 1)
    net.to(device)

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    # The first checkpoint is only written after `save_every` iterations;
    # make sure the target folder exists so that work is not lost then.
    os.makedirs(model_dir, exist_ok=True)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0  # iterations since training started
    ite_num4val = 0  # iterations since the running losses were last reset
    running_loss = 0.0
    running_tar_loss = 0.0

    for epoch in range(epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num += 1
            ite_num4val += 1

            inputs = data['image'].type(torch.FloatTensor).to(device)
            labels = data['label'].type(torch.FloatTensor).to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7,
                                               labels)

            loss.backward()
            optimizer.step()

            # Accumulate plain Python floats rather than tensors.
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # drop the temporary outputs and losses before the next batch
            del d0, d1, d2, d3, d4, d5, d6, d7, loss2, loss

            print(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_every == 0:
                torch.save(
                    net.state_dict(),
                    model_dir + "basnet_bsi_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

    print('-------------Congratulations! Training Done!!!-------------')