Example #1
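This first snippet is the constructor of a background-removal wrapper class; the class statement and the FORMATS / MODEL_NAMES attributes it validates against are not part of the quoted code. A hypothetical declaration consistent with the checks below (the class name and tuple values are assumptions):

class BackgroundRemover:
    FORMATS = ('np', 'pil')
    MODEL_NAMES = ('u2net', 'u2netp')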
    def __init__(self, model_name='u2net', cuda_mode=True, output_format='np'):
        self.model_name = model_name
        self.cuda_mode = cuda_mode and torch.cuda.is_available()  # fall back to CPU if CUDA is unavailable
        self.trans = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
        self.output_format = output_format

        # Validate
        if output_format not in self.FORMATS:
            raise AssertionError('Invalid "output_format"', 'Use "np" or "pil"')
        if model_name not in self.MODEL_NAMES:
            raise AssertionError('Invalid "model_name"', 'Use "u2net" or "u2netp"')

        if model_name == 'u2net':
            print("Model: U2NET (173.6 MB)")
            self.net = U2NET(3, 1)  # 173.6 MB
        elif model_name == 'u2netp':
            print("Model: U2NetP (4.7 MB)")
            self.net = U2NETP(3, 1)  # 4.7 MB
        else:
            raise AssertionError('Invalid "model_name"', 'Use "u2net" or "u2netp"')

        # Load network
        model_file = os.path.join(os.path.dirname(__file__), 'saved_models', model_name + '.pth')
        print("model_file:", model_file)

        if self.cuda_mode:  # use the availability-checked flag so the CPU fallback actually applies
            print("CUDA mode")
            self.net.load_state_dict(torch.load(model_file))
            self.net.cuda()
        else:
            print("CPU mode")
            self.net.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))

        self.net.eval()
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp


    image_dir = './test_data/test_images/'
    prediction_dir = './test_data/' + model_name + '_results/'
    model_dir = './saved_models/'+ model_name + '/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split("/")[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)

        # normalization
        pred = d1[:,0,:,:]
        pred = normPRED(pred)

        # save results to test_results folder
        save_output(img_name_list[i_test],pred,prediction_dir)

        del d1,d2,d3,d4,d5,d6,d7
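Most of the examples here call the helpers normPRED and save_output without showing them. For reference, a minimal sketch consistent with the helpers in the original U-2-Net repository's u2net_test.py (path handling simplified; the forks quoted below vary slightly, e.g. Example #16 passes an extra output directory to save_output):

import os
import numpy as np
import torch
from PIL import Image
from skimage import io


def normPRED(d):
    # min-max normalize the predicted saliency map to [0, 1]
    ma = torch.max(d)
    mi = torch.min(d)
    return (d - mi) / (ma - mi)


def save_output(image_name, pred, d_dir):
    # convert the prediction to an 8-bit image, resize it back to the
    # original image resolution and save it as a PNG
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray((predict_np * 255).astype(np.uint8)).convert('RGB')
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
    name = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, name + '.png'))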
Example #3
def main():

    model_name = 'u2net'  # u2netp
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    cap = cv2.VideoCapture(0)

    import time

    while True:
        ret, frame = cap.read()
        if ret:
            t0 = time.time()
            img = cv2.resize(frame, (320, 320))
            img = img.transpose((2, 0, 1))
            img = img[None, ...] / 255.
            inputs_test = torch.from_numpy(img)
            inputs_test = inputs_test.type(torch.FloatTensor)

            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            pred = d1[:, 0, :, :]
            predict = normalize(pred)
            predict = predict.squeeze()
            predict_np = predict.cpu().data.numpy()

            h, w = frame.shape[:2]
            pred_resized = cv2.resize(predict_np, (w, h))

            img = (frame.astype(np.float32) * np.dstack(
                (pred_resized, pred_resized, pred_resized))).astype(np.uint8)

            cv2.imshow("out", img)
            cv2.waitKey(1)

            del d1, d2, d3, d4, d5, d6, d7
Example #4
def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netp

    image_dir = '../train2014'
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = '../models/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + os.sep + '*')

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    # if torch.cuda.is_available():
    #     net.cuda()
    net.eval()
    all_out = {}
    for i_test, data_test in tqdm(enumerate(test_salobj_dataloader)):
        sep_ = img_name_list[i_test].split(os.sep)[-1]
        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)
        #
        # if torch.cuda.is_available():
        #     inputs_test = Variable(inputs_test.cuda())
        # else:
        inputs_test = Variable(inputs_test)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # the network returns seven side outputs; keep the fused prediction d1
        pred = normPRED(d1[:, 0, :, :]).detach().cpu()
        all_out[sep_] = pred

    pickle.dump(all_out, open("../data/coco_train_u2net.pik", "wb"), protocol=2)
Example #5
def model(model_name='u2net'):

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))

    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    return net
Example #6
def model(model_name="u2net"):

    model_dir = os.path.join(os.getcwd(), "saved_models", model_name,
                             model_name + ".pth")

    if model_name == "u2net":
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == "u2netp":
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))

    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    return net
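Examples #5 and #6 only build and load the network. A minimal usage sketch under the same assumptions as Example #3 (plain resize to 320x320 and divide-by-255 preprocessing; 'input.jpg' is a placeholder path, and the dataset-based examples use ToTensorLab normalization instead):

import cv2
import torch

net = model('u2net')

frame = cv2.imread('input.jpg')
x = cv2.resize(frame, (320, 320)).transpose((2, 0, 1))[None, ...] / 255.0
x = torch.from_numpy(x).float()
if torch.cuda.is_available():
    x = x.cuda()

with torch.no_grad():
    d1, *rest = net(x)

# fused saliency map with values roughly in [0, 1]
mask = d1[:, 0, :, :].squeeze().cpu().numpy()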
Example #7
    def __init__(self, model_name):
        self.model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                                      model_name + '.pth')

        if model_name == 'u2net':
            print("...load U2NET---173.6 MB")
            net = U2NET(3, 1)
        elif model_name == 'u2netp':
            print("...load U2NEP---4.7 MB")
            net = U2NETP(3, 1)

        net.load_state_dict(
            torch.load(self.model_dir, map_location=torch.device('cpu')))
        if torch.cuda.is_available():
            net.cuda()
        net.eval()

        self.net = net
Example #8
def main():
    model_name='u2netp'

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    img_name_list = glob.glob(image_dir + os.sep + '*')

    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                        ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)

    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    for data_test in test_salobj_dataloader:
        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        dummy_input = inputs_test
        torch.onnx.export(net, dummy_input,"exported/onnx/{}.onnx".format(model_name),
        opset_version=10)
        
        break
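A quick way to sanity-check the file exported in Example #8 is to run it with onnxruntime; a minimal sketch (assumes onnxruntime is installed and reuses the export path from above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("exported/onnx/u2netp.onnx")
input_name = sess.get_inputs()[0].name

# dummy input with the layout the exporter saw: (N, C, H, W) at 320x320
x = np.random.rand(1, 3, 320, 320).astype(np.float32)
outputs = sess.run(None, {input_name: x})
print(len(outputs), outputs[0].shape)  # seven side outputs, the first being the fused map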
Example #9
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------

    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))

    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    for i_test in range(len(img_name_list)):
        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        # take one preprocessed sample from the dataset and add a batch dimension
        inputs_test = test_salobj_dataset[i_test]['image'].type(torch.FloatTensor).unsqueeze(0)
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)
        del d1, d2, d3, d4, d5, d6, d7
Example #10
def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2net

    # image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    image_dir = '/home/hha/dataset/circle/circle'
    prediction_dir = '/home/hha/dataset/circle/circle_pred'
    model_dir = '/home/hha/pytorch_code/U-2-Net-master/saved_models/u2netp/u2netp_bce_itr_2000_train_0.077763_tar_0.006976.pth'
    # model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')

    img_name_list = list(filter(lambda f: f.find('_mask') < 0, img_name_list))

    # print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([
            RescaleT(320),  # 320
            ToTensorLab(flag=0)
        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=0)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        # pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
Example #11
def main():

    # ------- 2. set the directory of training dataset --------

    model_name = 'u2net'  #'u2netp'

    data_dir = './train_data/'
    tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/'
    tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/'

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = './saved_models/' + model_name + '/'

    epoch_num = 100000
    batch_size_train = 12
    batch_size_val = 1
    train_num = 0
    val_num = 0

    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)

    tra_lbl_name_list = []
    for img_path in tra_img_name_list:
        img_name = img_path.split("/")[-1]

        aaa = img_name.split(".")
        bbb = aaa[0:-1]
        imidx = bbb[0]
        for i in range(1, len(bbb)):
            imidx = imidx + "." + bbb[i]

        tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(320),
                                       RandomCrop(288),
                                       ToTensorLab(flag=0)
                                   ]))
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)

    # ------- 3. define model --------
    # define the net
    if (model_name == 'u2net'):
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000  # save the model every 2000 iterations

    for epoch in range(0, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data['image'], data['label']

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(),
                                              requires_grad=False), Variable(
                                                  labels.cuda(),
                                                  requires_grad=False)
            else:
                inputs_v, labels_v = Variable(
                    inputs, requires_grad=False), Variable(labels,
                                                           requires_grad=False)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6,
                                               labels_v)

            loss.backward()
            optimizer.step()

            # print statistics (.item() instead of .data[0], which fails on 0-dim tensors in modern PyTorch)
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:

                torch.save(
                    net.state_dict(), model_dir + model_name +
                    "_bce_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0
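The training loop in Example #11 relies on muti_bce_loss_fusion, which is not shown. A sketch consistent with the helper in the original u2net_train.py: it computes BCE against the label for the fused output d0 and each of the six side outputs, returning the fused-output loss together with the total (logged above as tar and train loss respectively):

import torch.nn as nn

bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    # total loss sums the fused output and all six side outputs
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    return loss0, loss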
Example #12
def run_model():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'

    image_dir = os.path.join(os.getcwd(), 'static', 'uploads')
    prediction_dir = os.path.join(os.getcwd(), 'static', "final" + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    #print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    final_dict = {}
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        final_dict[img_name_list[i_test].split(os.sep)[-1]] = convert_img(
            img_name_list[i_test], pred, prediction_dir, img_name_list[i_test])

        del d1, d2, d3, d4, d5, d6, d7

    # returns dict mapping file name to [original image, nobg image, checker, whiteg, blackg]
    return final_dict
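convert_img in Example #12 is not shown; according to the final comment it builds several composites from the predicted mask. A minimal sketch of just the cut-out step it presumably performs (the function name and return value are assumptions, not the example's actual helper):

import numpy as np
from PIL import Image


def cutout_from_pred(image_path, pred):
    # use the normalized prediction as an alpha channel on the original image
    mask = (pred.squeeze().cpu().data.numpy() * 255).astype(np.uint8)
    img = Image.open(image_path).convert('RGB')
    alpha = Image.fromarray(mask).resize(img.size, Image.BILINEAR)
    img.putalpha(alpha)
    return img  # RGBA image with a transparent background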
Example #13
# build the training dataset (tra_img_name_list / tra_lbl_name_list are assumed to be defined earlier, as in Example #11)
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset,
                               batch_size=batch_size_train,
                               shuffle=True,  # shuffle=False
                               num_workers=0)

# ------- 3. define model --------
# define the net
# select the model
if(model_name=='u2net'):
    net = U2NET(3, 1)
elif(model_name=='u2netp'):
    net = U2NETP(3,1)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000  # save the model every 2000 iterations; this could easily be changed below to save once per epoch instead, but it is left as-is here
Example #14
def main():

    # --------- 1. get image path and name ---------
    #model_name='u2net'
    model_name = 'u2netp'

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        #print("test", inputs_test.shape)
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        #print("pred",pred.shape)
        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
Example #15
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  #u2netp

    image_dir = "/home/vybt/Downloads/U2_Net_Test"
    prediction_dir = "/home/vybt/Downloads/u-2--bps-net-prediction"
    model_dir = '/media/vybt/DATA/SmartFashion/deep-learning-projects/U-2-Net/saved_models/u2net/_bps_bce_itr_300000_train_0.107041_tar_0.011690.pth'

    img_name_list = glob.glob(image_dir + os.sep + '*')
    # print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 8)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 8)

    net.load_state_dict(torch.load(model_dir))

    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("Inference: ", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        # print("inputs test: {}".format(inputs_test.shape))
        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
Example #16
def main():
    name = 'test'
    # please input the body height (unit: metres)
    body_height = 1.63

    #"Input image" path
    image_dir = os.path.join(os.getcwd(), 'input')

    #"Output model" path
    outbody_filenames = './output/{}.obj'.format(name)

    #########################################################
    # this section runs image segmentation to remove the background and obtain silhouettes
    # --------- 1. get image path and name ---------

    model_name = 'u2net'  # u2net or u2netp

    # set the original silhouette images path
    prediction_dir1 = os.path.join(os.getcwd(), 'Silhouette' + os.sep)

    # set the path of the silhouette images after horizontal flip
    prediction_dir = os.path.join(os.getcwd(), 'test_data' + os.sep)

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------

    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. silhouette cutting model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)

        # normalization
        pred = d1[:,0,:,:]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test],pred,prediction_dir,prediction_dir1)

        del d1,d2,d3,d4,d5,d6,d7


    ###########################################################
    # this section reconstructs the 3D body model

    # -------- 5. get the silhouette images --------
    img_filenames = ['./test_data/front.png', './test_data/side.png']

    # img = cv2.imread(img_filenames[1])
    # cv2.flip(img,1)



    # -----------6.load input data---------
    sampling_num = 648
    data = np.zeros([2, 2, sampling_num])
    for i in np.arange(len(img_filenames)):
        img = img_filenames[i]
        im = getBinaryimage(img, 600)  # deal with white-black image simply
        sample_points = getSamplePoints(im, sampling_num, i)
        center_p = np.mean(sample_points, axis=0)
        sample_points = sample_points - center_p
        data[i, :, :] = sample_points.T

    data = repeat_data(data)

    #--------7 load CNN model----reconstruct 3d body shape
    print('==> beginning...')
    len_out = 22
    model_name = './Models/model.ckpt'
    ourModel = RegressionPCA(len_out)
    ourModel.load_state_dict(torch.load(model_name))
    ourModel.eval()

    #----------8 output results--------------
    save_obj(outbody_filenames, ourModel, body_height, data)
Example #17
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  #u2netp

    image_dir = './data/workbench/'
    prediction_dir = './data/workbench_out/'
    model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + '*')
    print(img_name_list)
    # TODO: consider data loader over sets of videos

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    batch_size = 1
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=3)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    from datetime import datetime
    a = datetime.now()
    total_inf = 0
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing frame:", i_test * batch_size)
        # print("dl:", datetime.now()-a)
        a = datetime.now()
        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        print("inf:", total_inf / (i_test + 1))
        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        total_inf += (datetime.now() - a).total_seconds()  # use total_seconds(); .microseconds only holds the sub-second component

        # save results to test_results folder
        # TODO: dynamically remember input sizes somehow, hardcoded for now
        for j in range(pred.shape[0]):
            save_output(img_name_list[batch_size * i_test + j], pred[j:j + 1],
                        prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
        a = datetime.now()
Example #18
def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir_0 = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir_0 = os.path.join(os.getcwd(), 'test_data',
                                    model_name + '_results' + os.sep)
    model_dir_0 = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               'saved_models', model_name + '.pth')

    parser = argparse.ArgumentParser()

    parser.add_argument("-m",
                        "--model_path",
                        default=model_dir_0,
                        help="pretrained model as .pth file")
    parser.add_argument("-i",
                        "--input_path",
                        default=image_dir_0,
                        help="folder with input images")
    parser.add_argument("-o",
                        "--output_path",
                        default=prediction_dir_0,
                        help="folder with output images")

    args = parser.parse_args()

    image_dir = args.input_path
    prediction_dir = args.output_path
    pretrained_model_path = args.model_path

    if not os.path.exists(pretrained_model_path):
        print(
            f"Could not find pretrained U^2 Net model at {pretrained_model_path}. "
            f"Please run CMake to download the default pretrained model or specify path to the pretrained model"
            f" as the -m argument to the u2net_test script.")
        return -1

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"Unsuppored model name: {model_name}")

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(pretrained_model_path))
        net.cuda()
    else:
        net.load_state_dict(
            torch.load(pretrained_model_path, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
    return 0
Example #19
def main():
    args = get_parameters()
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp
    error_file_link = args.errorFile
    img_name_list = []
    with open(args.input, 'r') as file:
        for line in file:
            line = line.strip()  # preprocess line
            img_name_list.append(line)
    prediction_dir = args.output_dir
    model_dir = './saved_models/' + model_name + '.pth'
    #print(img_name_list)
    print("Num of image paths in ", str(args.input), "is: ",
          len(img_name_list))

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))

    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1,
                                        collate_fn=my_collate)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):
        try:
            print("\r------In processing file {} with name {}--------".format(
                i_test + 1, img_name_list[i_test].split("/")[-1]),
                  end='')

            inputs_test = data_test['image']
            inputs_test = inputs_test.type(torch.FloatTensor)

            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = d1[:, 0, :, :]
            pred = normPRED(pred)

            # save results to test_results folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
        except Exception as error:
            print(error)
            with open(error_file_link, 'a+') as err_file:
                error_mess = img_name_list[i_test] + '*' + str(error) + '\n'
                err_file.write(error_mess)
            continue
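Example #19 passes a custom my_collate to the DataLoader, which is not shown. A common pattern for such a collate function, and an assumption about what this one does, is to drop samples that failed to load before delegating to PyTorch's default collation:

from torch.utils.data.dataloader import default_collate


def my_collate(batch):
    # filter out samples that came back as None so one unreadable image does not abort the batch
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)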
Example #20
def main():
    # ---------------------------------------------------------
    # Configurations
    # ---------------------------------------------------------

    heavy_augmentation = True  # False to use author's default implementation
    gan_training = False
    mixup_augmentation = False
    fullsize_training = False
    multiscale_training = False
    multi_gpu = True
    mixed_precision_training = True

    model_name = "u2net"  # "u2net", "u2netp", "u2net_heavy"
    se_type = None  # "csse", "sse", "cse", None; None to use author's default implementation
    # checkpoint = "saved_models/u2net/u2net.pth"
    checkpoint = None
    checkpoint_netD = None

    w_adv = 0.2
    w_vgg = 0.2

    train_dirs = [
        "../datasets/sky_segmentation_dataset/datasets/cvprw2020_sky_seg/train/"
    ]
    train_dirs_file_limit = [
        None,
    ]

    image_ext = '.jpg'
    label_ext = '.png'
    dataset_name = "cvprw2020_sky_seg"

    lr = 0.0003
    epoch_num = 500
    batch_size_train = 48
    # batch_size_val = 1
    workers = 16
    save_frq = 1000  # save the model every 1000 iterations

    save_debug_samples = False
    debug_samples_dir = "./debug/"

    # ---------------------------------------------------------

    model_dir = './saved_models/' + model_name + '/'
    os.makedirs(model_dir, exist_ok=True)

    writer = SummaryWriter()

    if fullsize_training:
        batch_size_train = 1
        multiscale_training = False

    # ---------------------------------------------------------
    # 1. Construct data input pipeline
    # ---------------------------------------------------------

    # Get dataset name
    dataset_name = dataset_name.replace(" ", "_")

    # Get training data
    assert len(train_dirs) == len(train_dirs_file_limit), \
        "Different train dirs and train dirs file limit length!"

    tra_img_name_list = []
    tra_lbl_name_list = []
    for d, flimit in zip(train_dirs, train_dirs_file_limit):
        img_files = glob.glob(d + '**/*' + image_ext, recursive=True)
        if flimit:
            img_files = np.random.choice(img_files, size=flimit, replace=False)

        print(f"directory: {d}, files: {len(img_files)}")

        for img_path in img_files:
            lbl_path = img_path.replace("/image/", "/alpha/") \
                .replace(image_ext, label_ext)

            if os.path.exists(img_path) and os.path.exists(lbl_path):
                assert os.path.splitext(
                    os.path.basename(img_path))[0] == os.path.splitext(
                        os.path.basename(lbl_path))[0], "Wrong filename."

                tra_img_name_list.append(img_path)
                tra_lbl_name_list.append(lbl_path)
            else:
                print(
                    f"Warning, dropping sample {img_path} because label file {lbl_path} not found!"
                )

    tra_img_name_list, tra_lbl_name_list = shuffle(tra_img_name_list,
                                                   tra_lbl_name_list)

    train_num = len(tra_img_name_list)
    # val_num = 0  # unused
    print(f"dataset name        : {dataset_name}")
    print(f"training samples    : {train_num}")

    # Construct data input pipeline
    if heavy_augmentation:
        transform = AlbuSampleTransformer(
            get_heavy_transform(
                fullsize_training=fullsize_training,
                transform_size=False if
                (fullsize_training or multiscale_training) else True))
    else:
        transform = transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
        ])

    # Create dataset and dataloader
    dataset_kwargs = dict(img_name_list=tra_img_name_list,
                          lbl_name_list=tra_lbl_name_list,
                          transform=transforms.Compose([
                              transform,
                          ] + ([
                              SaveDebugSamples(out_dir=debug_samples_dir),
                          ] if save_debug_samples else []) + ([
                              ToTensorLab(flag=0),
                          ] if not multiscale_training else [])))
    if mixup_augmentation:
        _dataset_cls = MixupAugSalObjDataset
    else:
        _dataset_cls = SalObjDataset

    salobj_dataset = _dataset_cls(**dataset_kwargs)
    salobj_dataloader = DataLoader(
        salobj_dataset,
        batch_size=batch_size_train,
        collate_fn=multi_scale_collater if multiscale_training else None,
        shuffle=True,
        pin_memory=True,
        num_workers=workers)

    # ---------------------------------------------------------
    # 2. Load model
    # ---------------------------------------------------------

    # Instantiate model
    if model_name == "u2net":
        net = U2NET(3, 1, se_type=se_type)
    elif model_name == "u2netp":
        net = U2NETP(3, 1, se_type=se_type)
    elif model_name == "u2net_heavy":
        net = u2net_heavy()
    elif model_name == "custom":
        net = CustomNet()
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    # Restore model weights from checkpoint
    if checkpoint:
        if not os.path.exists(checkpoint):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")

        try:
            print(f"Restoring from checkpoint: {checkpoint}")
            net.load_state_dict(torch.load(checkpoint, map_location="cpu"))
            print(" - [x] success")
        except:
            print(" - [!] error")

    if torch.cuda.is_available():
        net.cuda()

    if gan_training:
        netD = MultiScaleNLayerDiscriminator()

        if checkpoint_netD:
            if not os.path.exists(checkpoint_netD):
                raise FileNotFoundError(
                    f"Discriminator checkpoint file not found: {checkpoint_netD}"
                )

            try:
                print(
                    f"Restoring discriminator from checkpoint: {checkpoint_netD}"
                )
                netD.load_state_dict(
                    torch.load(checkpoint_netD, map_location="cpu"))
                print(" - [x] success")
            except:
                print(" - [!] error")

        if torch.cuda.is_available():
            netD.cuda()

        vgg19 = VGG19Features()
        vgg19.eval()
        if torch.cuda.is_available():
            vgg19 = vgg19.cuda()

    # ---------------------------------------------------------
    # 3. Define optimizer
    # ---------------------------------------------------------

    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    # optimizer = optim.SGD(net.parameters(), lr=lr)
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr/4, max_lr=lr,
    #                                         mode="triangular2",
    #                                         step_size_up=2 * len(salobj_dataloader))

    if gan_training:
        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.9))

    # ---------------------------------------------------------
    # 4. Initialize AMP and data parallel stuffs
    # ---------------------------------------------------------

    GOT_AMP = False
    if mixed_precision_training:
        try:
            print("Checking for Apex AMP support...")
            from apex import amp
            GOT_AMP = True
            print(" - [x] yes")
        except ImportError:
            print(" - [!] no")

    if GOT_AMP:
        amp.register_float_function(torch, 'sigmoid')
        net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

        if gan_training:
            netD, optimizerD = amp.initialize(netD, optimizerD, opt_level="O1")
            vgg19 = amp.initialize(vgg19, opt_level="O1")

    if torch.cuda.device_count() > 1 and multi_gpu:
        print(f"Multi-GPU training using {torch.cuda.device_count()} GPUs.")
        net = nn.DataParallel(net)

        if gan_training:
            netD = nn.DataParallel(netD)
            vgg19 = nn.DataParallel(vgg19)
    else:
        print(f"Training using {torch.cuda.device_count()} GPUs.")

    # ---------------------------------------------------------
    # 5. Training
    # ---------------------------------------------------------

    print("Start training...")

    ite_num = 0
    ite_num4val = 0
    running_loss = 0.0
    running_bce_loss = 0.0
    running_tar_loss = 0.0
    running_adv_loss = 0.0
    running_per_loss = 0.0
    running_fake_loss = 0.0
    running_real_loss = 0.0
    running_lossD = 0.0

    for epoch in tqdm(range(0, epoch_num), desc="All epochs"):
        net.train()
        if gan_training:
            netD.train()

        for i, data in enumerate(
                tqdm(salobj_dataloader, desc=f"Epoch #{epoch}")):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            image_key = "image"
            label_key = "label"
            inputs, labels = data[image_key], data[label_key]
            # tqdm.write(f"input tensor shape: {inputs.shape}")

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # Wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = \
                    Variable(inputs.cuda(), requires_grad=False), \
                    Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = \
                    Variable(inputs, requires_grad=False), \
                    Variable(labels, requires_grad=False)

            # # Zero the parameter gradients
            # optimizer.zero_grad()

            # Forward + backward + optimize

            d6 = 0
            if model_name == "custom":
                d0, d1, d2, d3, d4, d5 = net(inputs_v)
            else:
                d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)

            if gan_training:
                optimizerD.zero_grad()

                dis_fake = netD(inputs_v, d0.detach())
                dis_real = netD(inputs_v, labels_v)

                loss_fake = bce_with_logits_loss(dis_fake,
                                                 torch.zeros_like(dis_fake))
                loss_real = bce_with_logits_loss(dis_real,
                                                 torch.ones_like(dis_real))
                lossD = loss_fake + loss_real

                if GOT_AMP:
                    with amp.scale_loss(lossD, optimizerD) as scaled_loss:
                        scaled_loss.backward()
                else:
                    lossD.backward()

                optimizerD.step()

                writer.add_scalar("lossD/fake", loss_fake.item(), ite_num)
                writer.add_scalar("lossD/real", loss_real.item(), ite_num)
                writer.add_scalar("lossD/sum", lossD.item(), ite_num)
                running_fake_loss += loss_fake.item()
                running_real_loss += loss_real.item()
                running_lossD += lossD.item()

            # Zero the parameter gradients
            optimizer.zero_grad()

            if model_name == "custom":
                loss2, loss = multi_bce_loss_fusion5(d0, d1, d2, d3, d4, d5,
                                                     labels_v)
            else:
                loss2, loss = multi_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6,
                                                    labels_v)

            writer.add_scalar("lossG/bce", loss.item(), ite_num)
            running_bce_loss += loss.item()

            if gan_training:
                # Adversarial loss
                loss_adv = 0.0
                if w_adv:
                    dis_fake = netD(inputs_v, d0)
                    loss_adv = bce_with_logits_loss(dis_fake,
                                                    torch.ones_like(dis_fake))

                # Perceptual loss
                loss_per = 0.0
                if w_vgg:
                    vgg19_fm_pred = vgg19(inputs_v * d0)
                    vgg19_fm_label = vgg19(inputs_v * labels_v)
                    loss_per = mae_loss(vgg19_fm_pred, vgg19_fm_label)

                loss = loss + w_adv * loss_adv + w_vgg * loss_per

            if GOT_AMP:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            # scheduler.step()

            writer.add_scalar("lossG/sum", loss.item(), ite_num)
            writer.add_scalar("lossG/loss2", loss2.item(), ite_num)
            running_loss += loss.item()
            running_tar_loss += loss2.item()
            if gan_training:
                writer.add_scalar("lossG/adv", loss_adv.item(), ite_num)
                writer.add_scalar("lossG/perceptual", loss_per.item(), ite_num)
                running_adv_loss += loss_adv.item()
                running_per_loss += loss_per.item()

            if ite_num % 200 == 0:
                writer.add_images("inputs", inv_normalize(inputs_v), ite_num)
                writer.add_images("labels", labels_v, ite_num)
                writer.add_images("preds", d0, ite_num)

            # Delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
            if gan_training:
                del dis_fake, dis_real, loss_fake, loss_real, lossD, loss_adv, vgg19_fm_pred, vgg19_fm_label, loss_per

            # Print stats
            tqdm.write(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train G/sum: %3f, G/bce: %3f, G/bce_tar: %3f, G/adv: %3f, G/percept: %3f, D/fake: %3f, D/real: %3f, D/sum: %3f"
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_bce_loss / ite_num4val,
                   running_tar_loss / ite_num4val, running_adv_loss /
                   ite_num4val, running_per_loss / ite_num4val,
                   running_fake_loss / ite_num4val, running_real_loss /
                   ite_num4val, running_lossD / ite_num4val))

            if ite_num % save_frq == 0:
                # Save checkpoint
                torch.save(
                    net.module.state_dict() if hasattr(
                        net, "module") else net.state_dict(), model_dir +
                    model_name + (("_" + se_type) if se_type else "") +
                    ("_" + dataset_name) +
                    ("_mixup_aug" if mixup_augmentation else "") +
                    ("_heavy_aug" if heavy_augmentation else "") +
                    ("_fullsize" if fullsize_training else "") +
                    ("_multiscale" if multiscale_training else "") +
                    "_bce_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))

                if gan_training:
                    torch.save(
                        netD.module.state_dict() if hasattr(netD, "module")
                        else netD.state_dict(), model_dir + "netD_" +
                        model_name + (("_" + se_type) if se_type else "") +
                        ("_" + dataset_name) +
                        ("_mixup_aug" if mixup_augmentation else "") +
                        ("_heavy_aug" if heavy_augmentation else "") +
                        ("_fullsize" if fullsize_training else "") +
                        ("_multiscale" if multiscale_training else "") +
                        "itr_%d.pth" % (ite_num))

                # Reset stats
                running_loss = 0.0
                running_bce_loss = 0.0
                running_tar_loss = 0.0
                running_adv_loss = 0.0
                running_per_loss = 0.0
                running_fake_loss = 0.0
                running_real_loss = 0.0
                running_lossD = 0.0
                ite_num4val = 0

                net.train()  # resume train
                if gan_training:
                    netD.train()

    writer.close()
    print("Training completed successfully.")
Example #21
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp  # the model name must match the one used for training

    # image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    image_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/Image/'  # path to the test images
    # prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    prediction_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/pre/'  # path where results are saved
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    # print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    # net.load_state_dict(torch.load(model_dir))
    net.load_state_dict(
        torch.load(
            './saved_models/u2net/u2net_bce_itr_36000_train_0.091362_tar_0.003286.pth'
        ))  # load your own trained checkpoint
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        # print(d7.shape)
        # note: d1 is the fused prediction that combines d2..d7; see the network
        # definition if you want the details
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
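For reference, normPRED, which every inference example here calls, is the per-image min-max normalization from the original U-2-Net test script; a minimal sketch:

import torch

def normPRED(d):
    # rescale the predicted probability map to the [0, 1] range
    ma = torch.max(d)
    mi = torch.min(d)
    return (d - mi) / (ma - mi)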
Пример #22
0
def main():

    full_sized_testing = False
    model_name = 'u2net'
    # model_name = 'u2netp'

    # image_dir = './test_images/'
    # image_dir = "../detectron2_mask_prediction/predictions/image/"
    image_dir = "../datasets/DUTS-TE/image/"
    # prediction_dir = './test_data/' + model_name + '_results/'
    # model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'

    model_dir = "./saved_models/u2net/u2net.pth"
    # model_dir = "./u2net_mixed_person_n_portraits_heavy_aug_multiscale_bce_itr_8000_train_0.384033_tar_0.046964.pth"
    assert os.path.isfile(model_dir)
    prediction_dir = f"./predictions{'_fullsize' if full_sized_testing else ''}_{os.path.splitext(os.path.basename(model_dir))[0]}/"

    os.makedirs(prediction_dir, exist_ok=True)

    img_name_list = glob.glob(image_dir + '*')

    img_exts = [".jpg", ".jpeg", ".png", ".jfif"]
    img_name_list = list(
        filter(lambda p: os.path.splitext(p)[-1].lower() in img_exts,
               img_name_list))
    # print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose(([] if full_sized_testing else [
            RescaleT(320),
        ]) + [
            ToTensorLab(flag=0),
        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1, se_type=None)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    elif (model_name == 'custom'):
        net = CustomNet()
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split("/")[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d7 = 0  # placeholder so the del below also works when the custom net returns only six outputs
        if model_name == "custom":
            d1, d2, d3, d4, d5, d6 = net(inputs_test)
        else:
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
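In the original U-2-Net test script, save_output squeezes the prediction, converts it to an 8-bit image, resizes it back to the source image's resolution and writes a PNG into the prediction directory (some examples here pass extra arguments to modified variants). A simplified sketch of the common version, with the original's multi-dot filename handling omitted:

import os
import numpy as np
from PIL import Image
from skimage import io

def save_output(image_name, pred, d_dir):
    # pred: 1 x H x W tensor with values in [0, 1]
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray((predict_np * 255).astype(np.uint8)).convert('RGB')
    # resize the mask back to the source image's resolution
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
    name = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, name + '.png'))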
Пример #23
0
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netp u2net
    data_dir = '/data2/wangjiajie/datasets/scene_segment1023/u2data/'
    image_dir = os.path.join(data_dir, 'test_imgs')
    prediction_dir = os.path.join('./outputs/', model_name + '/')
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir, exist_ok=True)
    # tra_label_dir = 'test_lbls/'

    image_ext = '.jpg'
    # label_ext = '.jpg' # '.png'
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(f'test img numbers are: {len(img_name_list)}')

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=Compose([
                                            SmallestMaxSize(max_size=320),
                                        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    # net.load_state_dict(torch.load(model_dir))
    # This checkpoint was saved from an nn.DataParallel model, so every key carries
    # a "module." prefix; strip the first 7 characters before loading into a plain
    # model (a more defensive variant is sketched after this example).
    checkpoint = torch.load(model_dir)
    d = collections.OrderedDict()
    for key, value in checkpoint.items():
        tmp = key[7:]  # drop the leading "module."
        d[tmp] = value
    net.load_state_dict(d)
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization (note: the prediction is inverted before normalizing here)
        pred = 1.0 - d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
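A slightly more defensive version of the prefix stripping above removes "module." only when it is actually present, so the same code works for checkpoints saved with or without nn.DataParallel (str.removeprefix requires Python 3.9+; net and model_dir are the ones defined in the example above):

checkpoint = torch.load(model_dir, map_location='cpu')
state_dict = {k.removeprefix('module.'): v for k, v in checkpoint.items()}
net.load_state_dict(state_dict)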
Пример #24
0
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  #u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
    #os.path.join(os.getcwd(), 'saved_models', model_name, 'u2net_bce_itr_8000_train_1.003468_tar_0.108501' + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ],
                           p=1.0)
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms)
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        #pred = 1.0 - d1[:,0,:,:]#reversed
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
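This example replaces the repo's RescaleT/ToTensorLab pipeline with albumentations, so the SalObjDataset used here must have been adapted accordingly: albumentations transforms are called with keyword arguments and return a dict. The usual calling pattern, sketched with a placeholder image:

import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

aug = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC uint8 image
out = aug(image=image)        # albumentations takes keyword arguments and returns a dict
tensor = out['image']         # 3 x 512 x 512 float tensor produced by ToTensorV2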
Пример #25
0
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  #u2netp
    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    # image_dir = os.path.join('/nfs/project/huxiaoliang/data/white_or_not/white_bg_image')
    # prediction_dir = os.path.join('/nfs/project/huxiaoliang/data/white_or_not/white_bg_image_pred'+ os.sep)
    model_dir = os.path.join('/nfs/private/modelfiles/u2net-saved_models',
                             model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    error = []
    for i_test, data_test in enumerate(test_salobj_dataloader):
        try:
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image']
            inputs_test = inputs_test.type(torch.FloatTensor)

            if torch.cuda.is_available():
                inputs_test = Variable(inputs_test.cuda())
            else:
                inputs_test = Variable(inputs_test)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = d1[:, 0, :, :]
            pred = normPRED(pred)

            # save results to test_results folder
            if not os.path.exists(prediction_dir):
                os.makedirs(prediction_dir, exist_ok=True)
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
        except Exception as ex:
            traceback.print_exc()
            error.append(img_name_list[i_test].split(os.sep)[-1])
    print('Images that failed inference:', error)
Пример #26
0
                    saturation=(0.9, 1.1),
                    hue=0.05),
        ToTensorLab(flag=0)
    ]))
salobj_dataloader = DataLoader(salobj_dataset,
                               batch_size=batch_size_train,
                               shuffle=True,
                               num_workers=1)

# ------- 3. define model --------
# define the net
if (model_name == 'u2net'):
    net = U2NET(4, 1)
    net.load_state_dict(torch.load(model_dir), strict=False)
elif (model_name == 'u2netp'):
    net = U2NETP(4, 1)
    net.load_state_dict(torch.load(model_dir), strict=False)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(),
                       lr=0.001,
                       betas=(0.9, 0.999),
                       eps=1e-08,
                       weight_decay=0)  # lr = 0.001

# ------- 5. training process --------
print("---start training...")
Пример #27
0
def main():

    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # fixed as u2netp

    image_dir = os.path.join(
        os.getcwd(), 'input'
    )  # 'input' directory, populated with images before running the script
    prediction_dir = os.path.join(
        os.getcwd(), 'output/'
    )  # 'output' directory, populated with the predictions
    model_dir = os.path.join(os.getcwd(), model_name +
                             '.pth')  # path to u2netp pretrained weights

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    net = U2NETP(3, 1)
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(
            torch.load(model_dir, map_location=torch.device('cpu')))

    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
Пример #28
0
def main():

    # --------- 1. get image path and name ---------
    model_name = 'rgbd_u2net'  #u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    # model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '_s6457_2021-02-18_12_36_53u2net_6457_bce_itr_38000_train_0.107382_tar_0.009159.pth')
    model_dir = 'saved_models/rgbd_u2net/rgbd_u2net_s30857_2021-04-06_03_52_47/rgbd_u2net_30857_bce_itr_220000_train_0.089360_tar_0.008674.pth'
    # path_files = Path('/pool/2021-03-31_22-11-41')
    path_files = Path('/pool/2021-03-31_22-32-41')
    img_name_list1 = sorted([str(x) for x in path_files.rglob('**/rgb/*.png')])
    img_name_list1 = [
        str(x) for x in img_name_list1
        if '2021-03-31' in str(x) or '2021-04-01' in str(x)
    ]

    # path_files2 = Path('/dataset')
    # img_name_list2 = sorted([str(x) for x in path_files2.rglob('**/rgb/*.png') if not Path(str(x).replace('rgb','annotation')).exists()])
    # img_name_list2 = [str(x) for x in img_name_list2 if '2021-03-10' in str(x)]

    img_name_list = img_name_list1  #+ img_name_list2
    print('img len', len(img_name_list))
    depth_name_list = [
        x.replace('rgb', 'aligned_depth') for x in img_name_list
    ]
    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = RGBD_SalObjDataset(img_name_list=img_name_list,
                                             depth_name_list=depth_name_list,
                                             lbl_name_list=[],
                                             transform=transforms.Compose([
                                                 RGBD_RescaleT(320),
                                                 RGBD_ToTensorLab(flag=0)
                                             ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'rgbd_u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(4, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    net.load_state_dict(torch.load(model_dir))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(tqdm(test_salobj_dataloader)):

        # print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        depth = data_test['depth']
        # append depth as a fourth channel: result is N x 4 x H x W
        inputs_test = torch.cat((inputs_test, torch.unsqueeze(depth, dim=1)),
                                dim=1)

        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir, i_test)

        del d1, d2, d3, d4, d5, d6, d7
Пример #29
0
def main(model_name, img_dir, retrain, weight, model_dir):
    model_name = 'u2net'  # 'u2netp'; note: this and model_dir below override the function arguments
    #data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
    # tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
    # tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
    tra_image_dir = os.path.join(img_dir, 'origin')
    tra_label_dir = os.path.join(img_dir, 'mask')
    # train_image_dir = os.path.join('')


    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

    epoch_start = 0
    epoch_num = 500
    batch_size_train = 20
    batch_size_val = 1
    train_num = 4000
    val_num = 500

    # tra_img_name_list = glob.glob(tra_image_dir + '*' + image_ext)
    tra_img_name_list = os.listdir(tra_image_dir)
    for i,item in enumerate(tra_img_name_list):
        tra_img_name_list[i] = os.path.join(tra_image_dir, item)


    tra_lbl_name_list = os.listdir(tra_label_dir)
    for i,item in enumerate(tra_lbl_name_list):
        tra_lbl_name_list[i] = os.path.join(tra_label_dir, item)

    print(tra_img_name_list)
    # for img_path in tra_img_name_list:
    # 	img_name = img_path.split(os.sep)[-1]

    # 	aaa = img_name.split(".")
    # 	bbb = aaa[0:-1]
    # 	imidx = bbb[0]
    # 	for i in range(1,len(bbb)):
    # 		imidx = imidx + "." + bbb[i]

    # 	tra_lbl_name_list.append(tra_label_dir + imidx + label_ext)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=4)

    # ------- 3. define model --------
    # define the net
    if(model_name=='u2net'):
        net = U2NET(3, 1)

    elif(model_name=='u2netp'):
        net = U2NETP(3,1)

    if torch.cuda.is_available():
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")

    if retrain:
        # resume from a previous checkpoint; note that the training loop below
        # still starts counting epochs from 0
        checkpoint = torch.load(weight)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        # loss = checkpoint['loss']

    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000 # save the model every 2000 iterations



    for epoch in range(0, epoch_num):
        net.train()
    
        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1

            ite_num4val = ite_num4val + 1
            # print(data)
            inputs, labels = data['image'], data['label']

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
                                                                                            requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            # y zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

            loss.backward()
            optimizer.step()

            # # print statistics
            running_loss += loss.data.item()
            running_tar_loss += loss2.data.item()

            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))

            # if ite_num % save_frq == 0:
            #     # torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
            #     # torch.save({
            #     # 'epoch': epoch,
            #     # 'model_state_dict': net.state_dict(),
            #     # 'optimizer_state_dict': optimizer.state_dict(),
            #     # 'loss': loss,
            #     # }, model_dir + model_name + epoch)

            #     running_loss = 0.0
            #     running_tar_loss = 0.0
            #     net.train()  # resume train
            #     ite_num4val = 0
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            os.path.join(model_dir, model_name + str(epoch)))
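muti_bce_loss_fusion above is the deep-supervision loss from the original U-2-Net training script: binary cross-entropy is computed between every side output and the same label, the seven terms are summed for backpropagation, and the loss on the fused output d0 is returned separately for logging. A sketch (using reduction='mean' instead of the deprecated size_average flag):

import torch.nn as nn

bce_loss = nn.BCELoss(reduction='mean')

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    losses = [bce_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]
    loss = sum(losses)          # total deep-supervision loss used for backprop
    return losses[0], loss      # loss on the fused output d0, logged separately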
Пример #30
0
def main():

    # --------- 1. get image path and name ---------
    model_name='u2netp'#u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    # print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)
    
    if torch.cuda.is_available():
        net.load_state_dict(torch.load("saved_models/u2netp/u2netp_bce_itr_10000_train_1.384799_tar_0.185377.pth"))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location=torch.device("cpu")))

    net.eval()

    inference_time = 0.0
    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        # print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        start = time.time()
        with torch.no_grad():
            d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
        inference_time += (time.time() - start)

        # normalization
        pred = d1[:,0,:,:]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test],pred,prediction_dir)

        del d1,d2,d3,d4,d5,d6,d7

    print(f"Predicted {len(img_name_list)} images in {inference_time:.2f}s")
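One caveat about the timing above: CUDA kernels launch asynchronously, so reading time.time() right after the forward call can under-report GPU inference time. A drop-in variant of the timed block (reusing net, inputs_test and inference_time from this example) that synchronizes before reading the clock:

        if torch.cuda.is_available():
            torch.cuda.synchronize()      # finish pending work (e.g. the host-to-device copy)
        start = time.time()
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
        if torch.cuda.is_available():
            torch.cuda.synchronize()      # wait until the forward pass has actually finished
        inference_time += time.time() - start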