Ejemplo n.º 1
0
def main():
    """Create portrait drawings for every image in ``your_portrait_im``.

    Each image is read with OpenCV, passed through ``detect_single_face`` (with
    the Haar cascade loaded below) and ``crop_face``, then through ``inference``
    with the ``u2net_portrait`` network. The result is scaled by 255, cast to
    uint8 and written as ``<stem>.png`` into ``your_portrait_results``.
    Unreadable files are skipped.
    """
    # get the image path list for inference
    image_dir = Path('./test_data/test_portrait_images/your_portrait_im/')
    image_paths = list(image_dir.glob('*'))
    print("Number of images: ", len(image_paths))

    # indicate the output directory (parents=True also creates missing parents)
    out_dir = Path('./test_data/test_portrait_images/your_portrait_results')
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load the cascade face detection model
    face_cascade = cv2.CascadeClassifier('./saved_models/face_detection_cv2/haarcascade_frontalface_default.xml')
    # u2net_portrait path
    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    # load u2net_portrait model; mapping onto the CPU first lets a checkpoint
    # saved on a GPU load on CPU-only machines too
    net = U2NET(3, 1)
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # do the inference one-by-one
    for i in trange(len(image_paths)):
        image_path = image_paths[i]
        # cv2.imread returns None (it does not raise) for unreadable/non-image files
        img = cv2.imread(str(image_path))
        if img is None:
            print("Skipping unreadable image:", image_path)
            continue
        # Flip OpenCV's BGR to RGB; ascontiguousarray turns the negative-stride
        # view into a normal array that OpenCV functions accept.
        # NOTE(review): another portrait script in this file passes the BGR image
        # to the same helpers unchanged — confirm which channel order
        # detect_single_face / inference expect.
        img = np.ascontiguousarray(img[..., ::-1])
        face = detect_single_face(face_cascade, img)
        im_face = crop_face(img, face)
        im_portrait = inference(net, im_face)

        # save the output; str() because older OpenCV bindings reject Path objects
        out_path = out_dir / (image_path.stem + '.png')
        cv2.imwrite(str(out_path), (im_portrait * 255).astype(np.uint8))
Ejemplo n.º 2
0
def main():
    """Generate portrait drawings for all images in ``portrait_im`` via a DataLoader.

    Images are rescaled with ``RescaleT(512)`` and converted with
    ``ToTensorLab(flag=0)``. Each one is run through ``u2net_portrait``. The value
    ``1 - d1`` (first channel) is normalized with ``normPRED`` and saved with
    ``save_output`` into ``portrait_results``.
    """
    # --------- 1. get image path and name ---------
    image_dir = './test_data/test_portrait_images/portrait_im'
    prediction_dir = './test_data/test_portrait_images/portrait_results'
    # makedirs also creates missing parent folders and tolerates an existing one
    os.makedirs(prediction_dir, exist_ok=True)

    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    img_name_list = glob.glob(image_dir + '/*')
    print("Number of images: ", len(img_name_list))

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(512),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    # Inference only: no_grad avoids building and holding the autograd graph.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # the portrait output is inverted (1 - d1) before normalization
            pred = 1.0 - d1[:, 0, :, :]
            pred = normPRED(pred)

            # save results to the prediction folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 3
0
 def __init__(self, checkpoint_path: str):
     """Build a ``U2NET(3, 1)``, load ``checkpoint_path`` into it and keep it on ``self.net``.

     The state dict is mapped onto the CPU first so a checkpoint saved on a GPU
     also loads on a CPU-only machine. The network is then moved to CUDA when
     available and switched to eval mode.
     """
     net = U2NET(3, 1)
     net.load_state_dict(_torch.load(checkpoint_path, map_location=_torch.device('cpu')))
     if _torch.cuda.is_available():
         net.cuda()
     net.eval()
     self.net = net
Ejemplo n.º 4
0
    def __init__(self, model_name='u2net', cuda_mode=True, output_format='np'):
        """Create the ``model_name`` network and load its weights.

        Weights are read from ``saved_models/<model_name>.pth`` next to this module.

        Args:
            model_name: ``'u2net'`` (173.6 MB) or ``'u2netp'`` (4.7 MB); must be in
                ``self.MODEL_NAMES``.
            cuda_mode: use the GPU when ``True`` *and* CUDA is available;
                otherwise the weights are mapped to and kept on the CPU.
            output_format: ``'np'`` or ``'pil'``; must be in ``self.FORMATS``.

        Raises:
            AssertionError: if ``output_format`` or ``model_name`` is not supported.
        """
        self.model_name = model_name
        self.cuda_mode = cuda_mode and torch.cuda.is_available()  # Fallback to CPU mode, if cuda is not available
        self.trans = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
        self.output_format = output_format

        # Validate
        if output_format not in self.FORMATS:
            raise AssertionError('Invalid "output_format"', 'Use "np" or "pil"')
        if model_name not in self.MODEL_NAMES:
            raise AssertionError('Invalid "model_name"', 'Use "u2net" or "u2netp"')

        if model_name == 'u2net':
            print("Model: U2NET (173.6 MB)")
            self.net = U2NET(3, 1)  # 173.6 MB
        elif model_name == 'u2netp':
            print("Model: U2NetP (4.7 MB)")
            self.net = U2NETP(3, 1)  # 4.7 MB
        else:
            raise AssertionError('Invalid "model_name"', 'Use "u2net" or "u2netp"')

        # Load network
        model_file = os.path.join(os.path.dirname(__file__), 'saved_models', model_name + '.pth')
        print("model_file:", model_file)

        # Branch on the resolved self.cuda_mode (not the raw argument) so the
        # CPU fallback above actually takes effect on machines without CUDA.
        if self.cuda_mode:
            print("CUDA mode")
            self.net.load_state_dict(torch.load(model_file))
            self.net.cuda()
        else:
            print("CPU mode")
            self.net.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))

        self.net.eval()
Ejemplo n.º 5
0
def main():
    """Run U^2-Net saliency detection on every file in ``./save_images/images/``.

    Each image is rescaled with ``RescaleT(320)``, converted with
    ``ToTensorLab(flag=0)`` and passed through the network. The normalized first
    channel of ``d1`` is saved with ``save_output`` into
    ``./save_images/<model_name>_result/``.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir = './save_images/images/'
    prediction_dir = './save_images/' + model_name + '_result/'
    model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + '*')
    print(img_name_list)

    # prediction_dir is handed to save_output as its output folder; create it
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            # basename handles the platform's separators (not only '/')
            print("inferencing:", os.path.basename(img_name_list[i_test]))

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 6
0
 def __init__(self, model_dir, image_size):
     """Store ``int(image_size)`` and load ``U2NET(3, 1)`` weights from ``model_dir``.

     With CUDA available the checkpoint is loaded with torch.load's default
     placement and the network is moved to the GPU. Otherwise the weights are
     mapped onto the CPU. The network ends up in eval mode on ``self.net``.
     """
     print("Loading U-2-Net...")
     self.image_size = int(image_size)
     self.net = U2NET(3, 1)
     on_gpu = torch.cuda.is_available()
     # map_location=None is torch.load's default (keep the saved placement)
     placement = None if on_gpu else torch.device('cpu')
     state = torch.load(model_dir, map_location=placement)
     self.net.load_state_dict(state)
     if on_gpu:
         self.net.cuda()
     self.net.eval()
Ejemplo n.º 7
0
def main():
    """Show live U^2-Net foreground masking of the default webcam (device 0).

    Each frame is resized to 320x320, transposed to channels-first, scaled to
    [0, 1] and run through the network. The first channel of ``d1`` goes through
    ``normalize``, is resized back to the frame size and multiplied into the
    frame. Press ``q`` in the preview window to quit.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    model_name = 'u2net'  # u2netp
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    cap = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # no frame (camera closed/unavailable): stop instead of spinning forever
                break

            # NOTE(review): only scales to [0, 1] and keeps OpenCV's BGR order,
            # unlike the ToTensorLab dataset pipeline — confirm it matches training.
            img = cv2.resize(frame, (320, 320))
            img = img.transpose((2, 0, 1))  # HWC -> CHW
            img = img[None, ...] / 255.     # add batch dim, scale to [0, 1]
            inputs_test = torch.from_numpy(img).type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            with torch.no_grad():
                d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            pred = d1[:, 0, :, :]
            predict = normalize(pred)
            predict = predict.squeeze()
            predict_np = predict.cpu().data.numpy()
            del d1, d2, d3, d4, d5, d6, d7

            h, w = frame.shape[:2]
            pred_resized = cv2.resize(predict_np, (w, h))

            img = (frame.astype(np.float32) * np.dstack(
                (pred_resized, pred_resized, pred_resized))).astype(np.uint8)

            cv2.imshow("out", img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # always free the camera and close the preview window
        cap.release()
        cv2.destroyAllWindows()
Ejemplo n.º 8
0
def main(colored=False, imagepath=''):
    """Segment the single image at ``imagepath`` with U^2-Net and save the result.

    Args:
        colored: forwarded unchanged to ``save_output``.
        imagepath: path of the image to process.

    Returns:
        Whatever ``save_output`` returns for the prediction, or ``None`` if the
        dataloader yields no batch.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    prediction_dir = './test_data/' + model_name + '_results/'
    model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'
    # prediction_dir is handed to save_output as its output folder; create it
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=[imagepath],
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference ---------
    # The dataset holds exactly one image, so the result of the first batch is returned.
    with torch.no_grad():
        for data_test in test_salobj_dataloader:
            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            del d1, d2, d3, d4, d5, d6, d7
            return save_output(imagepath, pred, prediction_dir, colored=colored)
    return None
Ejemplo n.º 9
0
def main():
    """Run the human-segmentation U^2-Net over ``test_data/test_human_images``.

    Weights are read from ``saved_models/u2net_human_seg/u2net_human_seg.pth``.
    Each image is rescaled with ``RescaleT(320)`` / ``ToTensorLab(flag=0)``. The
    normalized first channel of ``d1`` is saved with ``save_output`` into
    ``test_data/test_human_images_results``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'
    cwd = Path(os.getcwd())
    image_dir = cwd / 'test_data' / 'test_human_images'
    prediction_dir = cwd / 'test_data' / 'test_human_images_results'
    prediction_dir.mkdir(parents=True, exist_ok=True)
    model_dir = cwd / 'saved_models' / (model_name + '_human_seg') / (model_name + '_human_seg.pth')

    img_name_list = list(image_dir.glob('*'))
    print("Images in test:", len(img_name_list))

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load_state_dict returns a missing/unexpected-keys report, not the module,
    # so the device move has to be a separate call on the network itself.
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            image_path = img_name_list[i_test]
            print("inferencing:", image_path.name)

            # float32 on the same device as the network
            inputs_test = data_test['image'].to(device=device, dtype=torch.float32)
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            # save results to the prediction folder
            save_output(image_path, pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 10
0
def main():
    """Run U^2-NetP over ``../train2014`` on the CPU and pickle every saliency map.

    The result is a dict mapping each image's file name to the normalized first
    channel of the network's first output (``d1``). It is written to
    ``../data/coco_train_u2net.pik`` with pickle protocol 2. Every map is held in
    memory until that final dump.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netp

    image_dir = '../train2014'
    model_dir = '../models/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + os.sep + '*')

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # This script runs on the CPU only, so map the checkpoint there in case it
    # was saved from a GPU.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    net.eval()

    all_out = {}
    # no_grad: the stored predictions must not keep autograd graphs alive
    with torch.no_grad():
        for i_test, data_test in tqdm(enumerate(test_salobj_dataloader),
                                      total=len(img_name_list)):
            file_name = img_name_list[i_test].split(os.sep)[-1]
            inputs_test = data_test['image'].type(torch.FloatTensor)
            # the network returns a tuple of side outputs; keep the first (d1)
            d1 = net(inputs_test)[0]
            all_out[file_name] = normPRED(d1[:, 0, :, :])

    with open("../data/coco_train_u2net.pik", "wb") as out_file:
        pickle.dump(all_out, out_file, protocol=2)
Ejemplo n.º 11
0
    def __init__(self, in_ch: int, out_ch: int, lr: float,
                 pytorch_pretrained_model: str):
        """Set up a ``U2NET`` with a BCE-with-logits loss and a 2-class IoU metric.

        Args:
            in_ch: input channels passed to ``U2NET``.
            out_ch: output channels passed to ``U2NET``.
            lr: learning rate, stored on ``self.lr``.
            pytorch_pretrained_model: path of pretrained weights, stored on
                ``self.pretrained_path`` (not loaded here).
        """
        super().__init__()

        self.save_hyperparameters()

        self.model = U2NET(in_ch, out_ch)
        self.lr = lr

        # reduction='mean' is the non-deprecated equivalent of size_average=True.
        # NOTE(review): BCEWithLogitsLoss applies its own sigmoid — confirm the
        # model returns raw logits, not already sigmoid-activated outputs.
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='mean')

        self.pretrained_path = pytorch_pretrained_model

        # Validation metrics
        self.iou = JaccardIndex(num_classes=2)
Ejemplo n.º 12
0
    def __init__(self, model_name):
        """Load ``<cwd>/saved_models/<model_name>/<model_name>.pth`` into ``self.net``.

        Args:
            model_name: ``'u2net'`` (173.6 MB) or ``'u2netp'`` (4.7 MB).

        Raises:
            ValueError: for any other ``model_name``.
        """
        self.model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                                      model_name + '.pth')

        if model_name == 'u2net':
            print("...load U2NET---173.6 MB")
            net = U2NET(3, 1)
        elif model_name == 'u2netp':
            print("...load U2NEP---4.7 MB")
            net = U2NETP(3, 1)
        else:
            # fail clearly instead of hitting an unbound `net` below
            raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)

        # Map onto the CPU first so GPU-saved weights load anywhere; move to CUDA after.
        net.load_state_dict(
            torch.load(self.model_dir, map_location=torch.device('cpu')))
        if torch.cuda.is_available():
            net.cuda()
        net.eval()

        self.net = net
Ejemplo n.º 13
0
def model(model_name='u2net'):
    """Build the requested U^2-Net variant and load its weights for inference.

    Weights are read from ``<cwd>/saved_models/<model_name>/<model_name>.pth``.

    Args:
        model_name: ``'u2net'`` (173.6 MB) or ``'u2netp'`` (4.7 MB).

    Returns:
        The network in eval mode, moved to CUDA when available.

    Raises:
        ValueError: if ``model_name`` is not one of the two supported names.
    """
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))

    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    return net
Ejemplo n.º 14
0
def model(model_name="u2net"):
    """Build the requested U^2-Net variant and load its weights for inference.

    Weights are read from ``<cwd>/saved_models/<model_name>/<model_name>.pth``.

    Args:
        model_name: ``"u2net"`` (173.6 MB) or ``"u2netp"`` (4.7 MB).

    Returns:
        The network in eval mode, moved to CUDA when available.

    Raises:
        ValueError: if ``model_name`` is not one of the two supported names.
    """
    model_dir = os.path.join(os.getcwd(), "saved_models", model_name,
                             model_name + ".pth")

    if model_name == "u2net":
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == "u2netp":
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device("cpu")))

    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    return net
Ejemplo n.º 15
0
def main():
    """Portrait inference over ``portrait_im`` using the generator / ``net.load`` API.

    Every image in ``./test_data/test_portrait_images/portrait_im`` is fed to
    ``U2NET`` through ``sal_generator`` (batch size 1, ``RescaleT(512)`` +
    ``ToTensorLab(flag=0)``). The value ``1 - d1`` (first channel) is normalized
    with ``normPRED`` and handed to ``save_output`` with ``portrait_results`` as
    the output folder.
    """
    src_folder = './test_data/test_portrait_images/portrait_im'
    dst_folder = './test_data/test_portrait_images/portrait_results'
    if not os.path.exists(dst_folder):
        os.mkdir(dst_folder)
    weights_path = './saved_models/u2net_portrait/u2net_portrait.pth'
    paths = glob.glob(src_folder + '/*')
    print("Number of images: ", len(paths))

    # --------- data generator ---------
    pipeline = transforms.Compose([RescaleT(512), ToTensorLab(flag=0)])
    batches = sal_generator(batch_size=1,
                            img_name_list=paths,
                            lbl_name_list=[],
                            transform=pipeline)

    # --------- model ---------
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
    net.load(weights_path)

    # --------- per-image inference ---------
    for idx, batch in enumerate(batches):
        current = paths[idx]
        print("inferencing:", current.split(os.sep)[-1])

        d1, d2, d3, d4, d5, d6, d7 = net(batch, steps=1)

        # invert the first output channel, then rescale it with normPRED
        inverted = 1.0 - d1[:, 0, :, :]
        save_output(current, normPRED(inverted), dst_folder)

        del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 16
0
def remove_bg(images):
    """Cut out the salient foreground of each image and save it as a 4-channel file.

    For every path in ``images`` the U2NET mask is written with ``save_output``
    into the module-level ``output_dir``. That file is then re-read as grayscale,
    used as the alpha channel of the original image, and overwritten with the
    result. Weights come from the module-level ``model_dir`` and are loaded on
    the CPU.

    Args:
        images: list of image file paths.

    Returns:
        The file names returned by ``save_output``, in input order.
    """
    dataset = SalObjDataset(img_name_list=images,
                            lbl_name_list=[],
                            transform=transforms.Compose(
                                [RescaleT(320),
                                 ToTensorLab(flag=0)]))
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    net = U2NET(3, 1)
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    outputs = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs = data['image'].type(torch.FloatTensor)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs)

            pred = normPRED(d1[:, 0, :, :])
            del d1, d2, d3, d4, d5, d6, d7

            filename = save_output(images[i], pred, output_dir)
            outputs.append(filename)

            img = cv2.imread(images[i])
            mask = cv2.imread(filename, 0)  # 0 -> read as single-channel grayscale

            # imread yields BGR; appending alpha gives BGRA, which imwrite expects.
            # (Same constant value as the COLOR_RGB2RGBA used previously.)
            # NOTE(review): assumes save_output writes the mask at the original
            # image resolution — confirm, otherwise this assignment fails.
            rgba = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
            rgba[:, :, 3] = mask

            cv2.imwrite(filename, rgba)

    return outputs
Ejemplo n.º 17
0
def main():
    """Create portrait drawings for every image in ``your_portrait_im``.

    Each image is read with OpenCV, passed through ``detect_single_face`` (with
    the Haar cascade loaded below) and ``crop_face``, then through ``inference``
    with ``u2net_portrait``. The result, scaled by 255 to uint8, is written as
    ``<name>.png`` into ``your_portrait_results``. Unreadable files are skipped.
    """
    # get the image path list for inference
    im_list = glob('./test_data/test_portrait_images/your_portrait_im/*')
    print("Number of images: ", len(im_list))
    # indicate the output directory (also creates missing parents)
    out_dir = './test_data/test_portrait_images/your_portrait_results'
    os.makedirs(out_dir, exist_ok=True)

    # Load the cascade face detection model
    face_cascade = cv2.CascadeClassifier(
        './saved_models/face_detection_cv2/haarcascade_frontalface_default.xml'
    )
    # u2net_portrait path
    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    # load u2net_portrait model; mapping onto the CPU works with or without CUDA
    net = U2NET(3, 1)
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # do the inference one-by-one
    for i, im_path in enumerate(im_list):
        print("--------------------------")
        print("inferencing ", i, "/", len(im_list), im_path)

        # load each image; cv2.imread returns None for unreadable files
        img = cv2.imread(im_path)
        if img is None:
            print("Skipping unreadable image:", im_path)
            continue
        face = detect_single_face(face_cascade, img)
        im_face = crop_face(img, face)
        im_portrait = inference(net, im_face)

        # Strip the real extension (any length, e.g. '.jpeg') and use the
        # platform's path separator instead of slicing off 4 chars after '/'.
        stem = os.path.splitext(os.path.basename(im_path))[0]
        cv2.imwrite(os.path.join(out_dir, stem + '.png'),
                    (im_portrait * 255).astype(np.uint8))
Ejemplo n.º 18
0
def main():
    """Export the U^2-Net(P) model to ONNX (opset 10) as ``exported/onnx/<model_name>.onnx``.

    The first image of ``test_data/test_images``, preprocessed with
    ``RescaleT(320)`` + ``ToTensorLab(flag=0)``, is used as the tracing input.
    Nothing is exported if that folder has no images.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    model_name = 'u2netp'

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    img_name_list = glob.glob(image_dir + os.sep + '*')

    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)

    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # Only one preprocessed sample is needed as the example input for tracing.
    first_batch = next(iter(test_salobj_dataloader), None)
    if first_batch is None:
        print("No images found in", image_dir, "- nothing to export")
        return

    dummy_input = first_batch['image'].type(torch.FloatTensor)
    if use_cuda:
        dummy_input = dummy_input.cuda()

    # torch.onnx.export does not create the target folder itself
    export_dir = os.path.join('exported', 'onnx')
    os.makedirs(export_dir, exist_ok=True)
    torch.onnx.export(net, dummy_input,
                      os.path.join(export_dir, "{}.onnx".format(model_name)),
                      opset_version=10)
Ejemplo n.º 19
0
def main():
    """Save the face crop of every image in ``./facein`` as ``./faceout/<name>.png``.

    The ``u2net_portrait`` model is loaded, but the portrait ``inference`` step is
    commented out, so each output is the raw result of ``crop_face`` on the face
    found by ``detect_single_face``. Unreadable files are skipped.
    """
    # get the image path list for inference
    im_list = glob('./facein/*')
    print("Number of images: ", len(im_list))
    # indicate the output directory (also creates missing parents)
    out_dir = './faceout'
    os.makedirs(out_dir, exist_ok=True)

    # Load the cascade face detection model
    face_cascade = cv2.CascadeClassifier(
        './model/haarcascade_frontalface_default.xml')
    # u2net_portrait path
    model_dir = './model/u2net_portrait.pth'

    # load u2net_portrait model
    # NOTE(review): `net` is unused while the inference lines below stay commented out.
    net = U2NET(3, 1)
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    print('loaded model')
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # do the inference one-by-one
    for i, im_path in enumerate(im_list):
        print("--------------------------")
        print("inferencing ", i, "/", len(im_list), im_path)

        # load each image; cv2.imread returns None for unreadable files
        img = cv2.imread(im_path)
        if img is None:
            print("Skipping unreadable image:", im_path)
            continue
        face = detect_single_face(face_cascade, img)
        im_portrait = crop_face(img, face)
        # To save portrait drawings instead of raw crops:
        # im_face = crop_face(img, face)
        # im_portrait = inference(net, im_face)

        # Strip the real extension (any length) instead of slicing off 4 chars.
        stem = os.path.splitext(os.path.basename(im_path))[0]
        cv2.imwrite(os.path.join(out_dir, stem + '.png'), im_portrait)
Ejemplo n.º 20
0
def main():
    """Keras-style U^2-Net inference over ``test_data/test_images``.

    For each image index the network's ``predict`` is called and the normalized
    first channel of ``d1`` is saved with ``save_output`` into
    ``test_data/<model_name>_results/``.

    NOTE(review): this port looks unfinished. ``model_dir`` is computed, but no
    weights are loaded into ``net``. ``net.predict`` also receives the whole
    dataset on every iteration. Confirm both against the model class's API.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')  # NOTE(review): currently unused

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))

    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)

    # create the output folder once, before the loop
    os.makedirs(prediction_dir, exist_ok=True)

    # iterate over image indices (an int itself is not iterable)
    for i_test in range(len(img_name_list)):
        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        d1, d2, d3, d4, d5, d6, d7 = net.predict(test_salobj_dataset, steps=1)
        pred = normPRED(d1[:, 0, :, :])
        save_output(img_name_list[i_test], pred, prediction_dir)
        del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 21
0
def main():
    """Run the RGB-D U^2-Net (4 input channels) over PNG frames in ``rgb`` folders.

    RGB frames are collected recursively from ``rgb/*.png`` under
    ``/pool/2021-03-31_22-32-41``. Each frame's depth path is derived by string
    replacement, and the depth map is appended as a 4th input channel. The
    normalized first channel of ``d1`` is saved with ``save_output`` (which also
    receives the frame index) into ``test_data/rgbd_u2net_results/``.

    Raises:
        ValueError: if ``model_name`` is neither ``'rgbd_u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    model_name = 'rgbd_u2net'  # u2netp

    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    # model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '_s6457_2021-02-18_12_36_53u2net_6457_bce_itr_38000_train_0.107382_tar_0.009159.pth')
    model_dir = 'saved_models/rgbd_u2net/rgbd_u2net_s30857_2021-04-06_03_52_47/rgbd_u2net_30857_bce_itr_220000_train_0.089360_tar_0.008674.pth'
    # path_files = Path('/pool/2021-03-31_22-11-41')
    path_files = Path('/pool/2021-03-31_22-32-41')
    img_name_list1 = sorted(str(x) for x in path_files.rglob('**/rgb/*.png'))
    img_name_list1 = [
        x for x in img_name_list1
        if '2021-03-31' in x or '2021-04-01' in x
    ]

    # path_files2 = Path('/dataset')
    # img_name_list2 = sorted([str(x) for x in path_files2.rglob('**/rgb/*.png') if not Path(str(x).replace('rgb','annotation')).exists()])
    # img_name_list2 = [str(x) for x in img_name_list2 if '2021-03-10' in str(x)]

    img_name_list = img_name_list1  # + img_name_list2
    print('img len', len(img_name_list))
    # NOTE(review): str.replace rewrites *every* 'rgb' substring in the path
    # (file names and other folders too), not only the 'rgb' directory —
    # confirm the depth files are named to match.
    depth_name_list = [
        x.replace('rgb', 'aligned_depth') for x in img_name_list
    ]

    # --------- 2. dataloader ---------
    test_salobj_dataset = RGBD_SalObjDataset(img_name_list=img_name_list,
                                             depth_name_list=depth_name_list,
                                             lbl_name_list=[],
                                             transform=transforms.Compose([
                                                 RGBD_RescaleT(320),
                                                 RGBD_ToTensorLab(flag=0)
                                             ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'rgbd_u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(4, 1)
    elif model_name == 'u2netp':
        # NOTE(review): a 3-channel net cannot take the 4-channel RGB-D input built below.
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # create the output folder once, before the loop
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(tqdm(test_salobj_dataloader)):
            inputs_test = data_test['image']
            depth = data_test['depth']
            # Append depth as an extra channel along dim 1 for U2NET(4, 1);
            # presumably image is N x 3 x H x W and depth N x H x W -> N x 4 x H x W.
            inputs_test = torch.cat((inputs_test, torch.unsqueeze(depth, dim=1)),
                                    dim=1)
            inputs_test = inputs_test.type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            save_output(img_name_list[i_test], pred, prediction_dir, i_test)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 22
0
# ------- training data: rescale to 320, random 288 crop, tensor conversion --------
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset,
                               batch_size=batch_size_train,
                               shuffle=True,  # shuffle=False
                               num_workers=0)

# ------- 3. define model --------
# define the net
# choose the model
if model_name == 'u2net':
    net = U2NET(3, 1)
elif model_name == 'u2netp':
    net = U2NETP(3, 1)
else:
    # fail clearly instead of a NameError on `net` further down
    raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
Ejemplo n.º 23
0
def main():
    """Run U^2-Net(P) saliency detection on every image in ``test_data/test_images``.

    Each image is rescaled with ``RescaleT(320)``, converted with
    ``ToTensorLab(flag=0)`` and passed through the network. The normalized first
    channel of ``d1`` is saved with ``save_output`` into
    ``test_data/<model_name>_results/``.

    Raises:
        ValueError: if ``model_name`` is neither ``'u2net'`` nor ``'u2netp'``.
    """
    # --------- 1. get image path and name ---------
    # model_name = 'u2net'
    model_name = 'u2netp'

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError("Unknown model_name %r; use 'u2net' or 'u2netp'" % model_name)
    # Map onto the CPU first so a GPU-saved checkpoint also loads without CUDA.
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # create the output folder once, before the loop
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization
            pred = normPRED(d1[:, 0, :, :])

            # save results to the prediction folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 24
0
def main():
    """Run U2NET/U2NETP saliency inference on every image in test_data/test_images.

    Each prediction (channel 0 of the first side output ``d1``) is normalized
    with ``normPRED`` and written by ``save_output`` into
    ``test_data/<model_name>_results``. An image whose inference raises is
    reported with a traceback and collected in a list that is printed at the
    end, so one bad image does not abort the whole run.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # or 'u2netp'
    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data',
                                  model_name + '_results' + os.sep)
    model_dir = os.path.join('/nfs/private/modelfiles/u2net-saved_models',
                             model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")
    # Load onto the CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts; the weights are moved to the GPU right after when one exists.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # Create the output directory once rather than checking per image.
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    error = []
    # Pure inference: skip autograd bookkeeping to save (GPU) memory.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            # batch_size=1 and shuffle=False, so batch i_test is img_name_list[i_test].
            img_name = img_name_list[i_test]
            try:
                print("inferencing:", img_name.split(os.sep)[-1])

                inputs_test = data_test['image'].type(torch.FloatTensor)
                if use_cuda:
                    inputs_test = inputs_test.cuda()

                d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

                # normalization of channel 0 of the first side output
                pred = normPRED(d1[:, 0, :, :])
                save_output(img_name, pred, prediction_dir)

                del d1, d2, d3, d4, d5, d6, d7
            except Exception:
                # Deliberate best-effort: report the failure and continue.
                traceback.print_exc()
                error.append(img_name.split(os.sep)[-1])
    # '异常数据' means "abnormal data": the images that failed above.
    print('异常数据:', error)
Ejemplo n.º 25
0
def main():
    """Run an 8-output-channel U2NET/U2NETP checkpoint over every image in ``image_dir``.

    The network is built with 8 output channels. The whole first side output
    ``d1`` (all channels) is normalized with ``normPRED`` and passed to
    ``save_output``, which writes into ``prediction_dir``.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # or 'u2netp'

    image_dir = "/home/vybt/Downloads/U2_Net_Test"
    prediction_dir = "/home/vybt/Downloads/u-2--bps-net-prediction"
    model_dir = '/media/vybt/DATA/SmartFashion/deep-learning-projects/U-2-Net/saved_models/u2net/_bps_bce_itr_300000_train_0.107041_tar_0.011690.pth'

    img_name_list = glob.glob(image_dir + os.sep + '*')

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320),
                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 8)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 8)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")

    # Load onto the CPU first so GPU-saved checkpoints also load on CPU-only
    # hosts; the model is moved to the GPU afterwards when available.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # Create the output directory once, before the loop.
    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            # batch_size=1 and shuffle=False keep batches aligned with img_name_list.
            print("Inference: ", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization over the full multi-channel output
            pred = normPRED(d1)

            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 26
0
def main():
    """Run U2NET/U2NETP scene-segmentation inference on ``<data_dir>/test_imgs``.

    The checkpoint may come from a model wrapped in ``nn.DataParallel``, whose
    state-dict keys carry a ``'module.'`` prefix. That prefix is stripped where
    present. The saved map is ``1 - d1`` (channel 0), i.e. the inverted
    network output, normalized with ``normPRED``. Results go to
    ``./outputs/<model_name>/``.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2netp u2net
    data_dir = '/data2/wangjiajie/datasets/scene_segment1023/u2data/'
    image_dir = os.path.join(data_dir, 'test_imgs')
    prediction_dir = os.path.join('./outputs/', model_name + '/')
    os.makedirs(prediction_dir, exist_ok=True)

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name,
                             model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(f'test img numbers are: {len(img_name_list)}')

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=Compose([
                                            SmallestMaxSize(max_size=320),
                                        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")

    # Load on CPU so GPU-saved checkpoints also work on CPU-only hosts.
    checkpoint = torch.load(model_dir, map_location='cpu')
    # nn.DataParallel checkpoints prefix every key with 'module.'. Strip it
    # only where present, instead of blindly dropping 7 characters, so that
    # plain (non-DataParallel) checkpoints load too.
    prefix = 'module.'
    state_dict = collections.OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint.items())
    net.load_state_dict(state_dict)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # invert channel 0 of the first side output, then normalize
            pred = normPRED(1.0 - d1[:, 0, :, :])

            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 27
0
def main():
    """Segment body silhouettes with U-2-Net, then regress a 3D body mesh.

    Stage 1 runs U2NET/U2NETP on every image in ``./input`` and hands each
    normalized prediction to ``save_output`` together with the ``Silhouette``
    and ``test_data`` directories. Stage 2 reads ``./test_data/front.png`` and
    ``./test_data/side.png``, samples centred contour points from each, feeds
    them to ``RegressionPCA`` loaded from ``./Models/model.ckpt`` and writes
    the body via ``save_obj`` to ``./output/<name>.obj``.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    name = 'test'
    # subject height in metres
    body_height = 1.63

    # "Input image" path
    image_dir = os.path.join(os.getcwd(), 'input')

    # "Output model" path
    outbody_filenames = './output/{}.obj'.format(name)

    #########################################################
    # Image segmentation: remove the background to obtain silhouettes.
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2net or u2netp

    # original silhouette images path
    prediction_dir1 = os.path.join(os.getcwd(), 'Silhouette' + os.sep)
    # path of the silhouette images after horizontal flip
    prediction_dir = os.path.join(os.getcwd(), 'test_data' + os.sep)

    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. silhouette cutting model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")
    # Load on CPU so GPU-saved checkpoints also work on CPU-only hosts.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    # Make sure both directories handed to save_output exist up front.
    os.makedirs(prediction_dir, exist_ok=True)
    os.makedirs(prediction_dir1, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization of channel 0 of the first side output
            pred = normPRED(d1[:, 0, :, :])

            save_output(img_name_list[i_test], pred, prediction_dir, prediction_dir1)

            del d1, d2, d3, d4, d5, d6, d7

    ###########################################################
    # 3D body reconstruction.

    # -------- 5. get the silhouette images --------
    img_filenames = ['./test_data/front.png', './test_data/side.png']

    # ----------- 6. load input data ---------
    sampling_num = 648
    data = np.zeros([len(img_filenames), 2, sampling_num])
    for i, img_path in enumerate(img_filenames):
        im = getBinaryimage(img_path, 600)  # binarize the silhouette image
        sample_points = getSamplePoints(im, sampling_num, i)
        # centre the sampled points on their mean before stacking
        data[i, :, :] = (sample_points - np.mean(sample_points, axis=0)).T

    data = repeat_data(data)

    # -------- 7. load CNN model, reconstruct 3d body shape --------
    print('==> begining...')
    len_out = 22
    ckpt_path = './Models/model.ckpt'
    ourModel = RegressionPCA(len_out)
    # ourModel stays on the CPU, so load its weights onto the CPU as well.
    ourModel.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    ourModel.eval()

    # ---------- 8. output results --------------
    os.makedirs(os.path.dirname(outbody_filenames), exist_ok=True)
    save_obj(outbody_filenames, ourModel, body_height, data)
Ejemplo n.º 28
0
def main(model_name, img_dir, retrain, weight, model_dir):
    """Train, or resume training, U2NET/U2NETP on image/mask pairs under ``img_dir``.

    Args:
        model_name: 'u2net' or 'u2netp'. A falsy value falls back to 'u2net'.
        img_dir: directory with an ``origin`` sub-folder of input images and a
            ``mask`` sub-folder of masks. Both listings are sorted by file
            stem and paired by position, so every image is assumed to have a
            mask with the same stem.
        retrain: when truthy, resume from the checkpoint at ``weight``.
        weight: checkpoint path holding 'model_state_dict',
            'optimizer_state_dict' and 'epoch' (the format this function saves).
        model_dir: directory for per-epoch checkpoints. It is created if
            missing. A falsy value falls back to ``./saved_models/<model_name>/``.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    # Honour the caller's arguments (they used to be silently overwritten).
    model_name = model_name or 'u2net'
    model_dir = model_dir or os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

    tra_image_dir = os.path.join(img_dir, 'origin')
    tra_label_dir = os.path.join(img_dir, 'mask')

    epoch_num = 500
    batch_size_train = 20

    # SalObjDataset pairs images and labels by list index, and os.listdir()
    # order is arbitrary. Sort both by stem so image i matches mask i.
    def _stem(file_name):
        return os.path.splitext(file_name)[0]

    tra_img_name_list = [os.path.join(tra_image_dir, item)
                         for item in sorted(os.listdir(tra_image_dir), key=_stem)]
    tra_lbl_name_list = [os.path.join(tra_label_dir, item)
                         for item in sorted(os.listdir(tra_label_dir), key=_stem)]

    print(tra_img_name_list)

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=4)

    # ------- 3. define model --------
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")

    epoch_start = 0
    if retrain:
        # Load on CPU. load_state_dict copies into the (possibly CUDA)
        # parameters, and the optimizer casts its state to each param's device.
        checkpoint = torch.load(weight, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # 'epoch' is the index of the last completed epoch. Continue after it
        # instead of restarting from 0.
        epoch_start = checkpoint['epoch'] + 1

    os.makedirs(model_dir, exist_ok=True)

    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0

    for epoch in range(epoch_start, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num += 1
            ite_num4val += 1

            inputs = data['image'].type(torch.FloatTensor)
            labels = data['label'].type(torch.FloatTensor)
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels)

            loss.backward()
            optimizer.step()

            # running statistics (averaged over all iterations since start)
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # drop temporary outputs and losses to free memory early
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
                running_loss / ite_num4val, running_tar_loss / ite_num4val))

        # one resumable checkpoint per epoch
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(model_dir, model_name + str(epoch))
        )
Ejemplo n.º 29
0
def main():
    """Run a trained U2NETP/U2NET checkpoint over the images in ``image_dir``.

    Files whose path contains ``'_mask'`` are skipped. Channel 0 of the first
    side output ``d1`` is passed to ``save_output`` as-is (no ``normPRED``),
    and results are written into ``prediction_dir``.

    Raises:
        ValueError: if ``model_name`` is neither 'u2net' nor 'u2netp'.
    """
    # --------- 1. get image path and name ---------
    model_name = 'u2netp'  # u2net

    image_dir = '/home/hha/dataset/circle/circle'
    prediction_dir = '/home/hha/dataset/circle/circle_pred'
    model_dir = '/home/hha/pytorch_code/U-2-Net-master/saved_models/u2netp/u2netp.pthu2netp_bce_itr_2000_train_0.077763_tar_0.006976.pth'

    # skip ground-truth masks stored alongside the images
    img_name_list = [f for f in glob.glob(image_dir + os.sep + '*')
                     if f.find('_mask') < 0]

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose([
            RescaleT(320),
            ToTensorLab(flag=0)
        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=0)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")
    # Load on CPU so GPU-saved checkpoints also work on CPU-only hosts.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # raw channel-0 output; normPRED is intentionally not applied here
            pred = d1[:, 0, :, :]

            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7
Ejemplo n.º 30
0
def main():
    """Train U2NET/U2NETP on the augmented DUTS-TR set under ``./train_data/``.

    Each ``*.jpg`` in ``im_aug`` is paired with the ``.png`` of the same stem
    in ``gt_aug``. Every ``save_frq`` iterations the state dict is written to
    ``./saved_models/<model_name>/`` and the running loss averages are reset.
    """
    # ------- 2. set the directory of training dataset --------
    model_name = 'u2net'  # 'u2netp'

    data_dir = './train_data/'
    tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/'
    tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/'

    image_ext = '.jpg'
    label_ext = '.png'

    model_dir = './saved_models/' + model_name + '/'

    epoch_num = 100000
    batch_size_train = 12

    tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)

    # label path = label dir + image file name with its last extension swapped
    tra_lbl_name_list = [
        data_dir + tra_label_dir
        + os.path.splitext(os.path.basename(img_path))[0] + label_ext
        for img_path in tra_img_name_list
    ]

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(img_name_list=tra_img_name_list,
                                   lbl_name_list=tra_lbl_name_list,
                                   transform=transforms.Compose([
                                       RescaleT(320),
                                       RandomCrop(288),
                                       ToTensorLab(flag=0)
                                   ]))
    salobj_dataloader = DataLoader(salobj_dataset,
                                   batch_size=batch_size_train,
                                   shuffle=True,
                                   num_workers=1)

    # ------- 3. define model --------
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    # torch.save does not create missing directories.
    os.makedirs(model_dir, exist_ok=True)

    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000  # save the model every 2000 iterations

    for epoch in range(0, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num += 1
            ite_num4val += 1

            inputs = data['image'].type(torch.FloatTensor)
            labels = data['label'].type(torch.FloatTensor)
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6,
                                               labels)

            loss.backward()
            optimizer.step()

            # Losses are 0-dim tensors: loss.data[0] raises IndexError on
            # modern PyTorch, so use .item() to get the Python float.
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # drop temporary outputs and losses to free memory early
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f "
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:
                torch.save(
                    net.state_dict(), model_dir + model_name +
                    "_bce_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0