def save_result(self, opt, epoch, iteration):
    """Save a visualization grid for the current training mode.

    Collects the tensors relevant to ``opt.train_mode`` and writes them
    as a single image named ``<epoch>_<iteration>.jpg`` under
    ``self.vis_path``.

    Args:
        opt: options object; only ``opt.train_mode`` is read here
            (one of 'gmm', 'parsing', 'appearance', 'face').
        epoch: current epoch number, used in the output filename.
        iteration: current iteration number, used in the output filename.

    Raises:
        ValueError: if ``opt.train_mode`` is not a known mode (the
            original independent ``if`` chain would instead fail later
            with ``NameError: images``).
    """
    if opt.train_mode == 'gmm':
        images = [self.cloth_image, self.warped_cloth.detach(), self.im_c]
    elif opt.train_mode == 'parsing':
        # Color-decode the per-pixel argmax of the predicted parsing.
        fake_t_vis = pose_utils.decode_labels(
            torch.argmax(self.fake_t, dim=1,
                         keepdim=True).permute(0, 2, 3, 1).contiguous())
        images = [
            self.source_parse_vis, self.target_parse_vis,
            self.target_pose_img, self.cloth_parse, fake_t_vis
        ]
    elif opt.train_mode == 'appearance':
        images = [
            self.image_without_cloth, self.warped_cloth,
            self.warped_cloth_parse, self.target_image, self.cloth_image,
            self.generated_parsing_vis,
            self.fake_t.detach()
        ]
    elif opt.train_mode == 'face':
        images = [
            self.generated_image.detach(),
            self.refined_image.detach(), self.source_image,
            self.target_image, self.real_t,
            self.fake_t.detach()
        ]
    else:
        raise ValueError('unknown train_mode: %s' % opt.train_mode)

    pose_utils.save_img(
        images,
        os.path.join(self.vis_path,
                     str(epoch) + '_' + str(iteration) + '.jpg'))
# Example #2
def forward(opt, paths, gpu_ids, refine_path):
    """Run the full try-on pipeline over the demo dataset and save results.

    Pipeline per batch: GMM warps the cloth, the parsing generator (P)
    predicts the target parsing, the appearance generator (A) renders the
    try-on image, and the face generator (F) refines the face region.
    A grid of intermediate and final images is written per batch.

    Args:
        opt: options namespace (network sizes, dataloader settings, ...).
        paths: checkpoint paths in the order
            [gmm, parsing generator, appearance generator, face generator].
        gpu_ids: unused here; placement is handled via DataParallel/.cuda().
        refine_path: directory where result grids ('<i>.jpg') are written.
    """
    cudnn.enabled = True
    cudnn.benchmark = True
    opt.output_nc = 3
    opt.warp_cloth = False

    # Geometric matching module producing the warped cloth.
    gmm = GMM(opt)
    gmm = torch.nn.DataParallel(gmm).cuda()

    # 'batch'
    generator_parsing = Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing,
                                 opt.ndf, opt.netG_parsing, opt.norm,
                                 not opt.no_dropout, opt.init_type,
                                 opt.init_gain, opt.gpu_ids)

    generator_app_cpvton = Define_G(opt.input_nc_G_app,
                                    opt.output_nc_app,
                                    opt.ndf,
                                    opt.netG_app,
                                    opt.norm,
                                    not opt.no_dropout,
                                    opt.init_type,
                                    opt.init_gain,
                                    opt.gpu_ids,
                                    with_tanh=False)

    generator_face = Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf,
                              opt.netG_face, opt.norm, not opt.no_dropout,
                              opt.init_type, opt.init_gain, opt.gpu_ids)

    models = [gmm, generator_parsing, generator_app_cpvton, generator_face]
    for model, path in zip(models, paths):
        load_model(model, path)
    print('==>loaded model')

    # Normalization transforms keyed by channel count ('3' = RGB, '1' = gray).
    # NOTE(review): the torch-0.4 branch reuses the 3-channel transform for
    # 1-channel inputs, presumably to work around an old torchvision
    # limitation -- confirm before removing the version check.
    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]

        augment['1'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])  # change to [C, H, W]

    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(val_dataset,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers,
                                batch_size=opt.batch_size_v,
                                pin_memory=True)

    # Parsing channels 5/6/7 are dropped before the parsing generator.
    # Hoisted out of the batch loop: this index tensor is loop-invariant.
    keep_channels = [c for c in range(20) if c not in (5, 6, 7)]
    keep_index = torch.tensor(keep_channels).cuda()

    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            # -- warped cloth ------------------------------------------
            warped_cloth = warped_image(gmm, result)
            if opt.warp_cloth:
                # Export-only mode: save the warped cloth and skip synthesis.
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth',
                                                 warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                # De-normalize from [-1, 1] to [0, 1] before saving.
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d' % i)
                continue
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float(
            ).cuda()
            source_image = result['source_image'].float().cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            # Single assignment (the original assigned cloth_parse twice).
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            # -- filter parsing, add cloth information -----------------
            real_s_ = torch.index_select(source_parse, 1, keep_index)
            input_parse = torch.cat(
                (real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            # -- P: parsing generator ----------------------------------
            generate_parse = generator_parsing(input_parse)  # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            # One-hot encode the argmax parsing over the 20 classes.
            generate_parse_argmax = torch.argmax(generate_parse,
                                                 dim=1,
                                                 keepdim=True).float()
            one_hot = [generate_parse_argmax == c for c in range(20)]
            generate_parse_argmax = torch.cat(one_hot, dim=1).float()

            # -- A: appearance generator -------------------------------
            image_without_cloth = create_part(source_image, source_parse,
                                              'image_without_cloth', False)
            input_app = torch.cat(
                (image_without_cloth, warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1)
            # torch.tanh/torch.sigmoid replace the deprecated
            # torch.nn.functional.tanh/sigmoid.
            p_rendered = torch.tanh(p_rendered)
            m_composite = torch.sigmoid(m_composite)
            # Composite: mask selects warped cloth, rest comes from rendering.
            p_tryon = warped_cloth * m_composite + \
                p_rendered * (1 - m_composite)
            refine_img = p_tryon

            # -- F: face refinement ------------------------------------
            generate_face = create_part(refine_img, generate_parse_argmax,
                                        'face', False)
            generate_img_without_face = refine_img - generate_face
            source_face = create_part(source_image, source_parse, 'face',
                                      False)
            input_face = torch.cat((source_face, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face',
                                    False)
            refine_img = generate_img_without_face + fake_face

            # -- parsing visualization ---------------------------------
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse,
                                                  dim=1,
                                                  keepdim=True).permute(
                                                      0, 2, 3, 1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(
                    generate_parse_vis)

            # -- save result grid --------------------------------------
            images = [
                source_image, cloth_image, target_pose_img, warped_cloth,
                source_parse_vis, generate_parse_vis, p_tryon, refine_img
            ]
            pose_utils.save_img(images,
                                os.path.join(refine_path, '%d.jpg' % i))

    torch.cuda.empty_cache()
# Example #3
    def save_result(self, test_data_loader, opt, epoch, iteration):
        """Run inference on every test batch and save a visualization grid.

        For each batch: set the inputs, run ``self.forward`` under
        ``torch.no_grad()``, then save the tensors relevant to
        ``opt.train_mode`` as ``<image_name>_<epoch>_<iteration>.jpg``
        under ``self.vis_path``.

        Args:
            test_data_loader: iterable of test batches (dict-like, must
                contain 'source_image_name').
            opt: options object; ``opt.train_mode`` selects the layout.
            epoch: current epoch number, used in the output filename.
            iteration: current iteration number, used in the output filename.

        Raises:
            ValueError: if ``opt.train_mode`` is not a known mode (the
                original independent ``if`` chain would instead fail later
                with ``NameError: images``).
        """
        for index, test_data in enumerate(test_data_loader):
            # Inference only: no gradients needed for set_input/forward.
            with torch.no_grad():
                # Save each output under its source image's base name.
                img_name = test_data['source_image_name'][0].split('.')[0]
                self.set_input(opt, test_data)
                self.forward(opt)

            ######################################
            # Part 1 GMM Results
            ######################################
            if opt.train_mode == 'gmm':
                images = [
                    self.source_image, self.cloth_image, self.im_c,
                    self.warped_cloth_predict.detach()
                ]

            ######################################
            # Part 2 PARSING Results
            ######################################
            elif opt.train_mode == 'parsing':
                # Color-decode both the prediction and the transformed
                # source parsing for side-by-side comparison.
                fake_t_vis = pose_utils.decode_labels(
                    torch.argmax(self.fake_t, dim=1,
                                 keepdim=True).permute(0, 2, 3,
                                                       1).contiguous())
                test_me = pose_utils.decode_labels(
                    torch.argmax(self.source_parse_tformed,
                                 dim=1,
                                 keepdim=True).permute(0, 2, 3,
                                                       1).contiguous())
                images = [test_me, self.target_parse_vis, fake_t_vis]

            ######################################
            # Part 3 APPEARANCE Results
            ######################################
            elif opt.train_mode == 'appearance':
                images = [
                    self.image_without_cloth,
                    self.warped_cloth.detach(), self.warped_cloth_parse,
                    self.target_image, self.cloth_image,
                    self.generated_parsing_vis,
                    self.fake_t.detach()
                ]

            ######################################
            # Part 4 FACE Results
            ######################################
            elif opt.train_mode == 'face':
                images = [
                    self.generated_image.detach(),
                    self.refined_image.detach(), self.source_image,
                    self.target_image, self.real_t,
                    self.fake_t.detach()
                ]
            else:
                raise ValueError('unknown train_mode: %s' % opt.train_mode)

            pose_utils.save_img(
                images,
                os.path.join(
                    self.vis_path,
                    str(img_name) + '_' + str(epoch) + '_' + str(iteration) +
                    '.jpg'))
# Example #4
def upload():
    """Flask endpoint: accept a cloth image upload and return a try-on result.

    Steps: validate the upload, resize it to 192x256, build a cloth mask
    with the Holistically-Nested Edge Detector (``net``), then run the
    GMM / parsing / appearance / face pipeline over the demo dataset and
    save the result grids.

    Relies on module-level globals: ``APP_ROOT``, ``request``, ``net``,
    ``opt``, ``gmm``, ``generator_parsing``, ``generator_app_cpvton``,
    ``generator_face`` and ``refine_path``.

    Returns:
        The generated image response, or an error page with HTTP 400 when
        the uploaded file type is not supported.
    """
    target = os.path.join(APP_ROOT)

    # Create image directory if not found.
    if not os.path.isdir(target):
        os.mkdir(target)

    # Retrieve file from html file-picker (renamed so the local does not
    # shadow this view function's own name).
    uploaded = request.files.getlist("file")[0]
    print("File name: {}".format(uploaded.filename))
    filename = uploaded.filename

    # File support verification: only common raster formats are accepted.
    ext = os.path.splitext(filename)[1]
    if ext in (".jpg", ".jpeg", ".png", ".bmp"):
        print("File accepted")
    else:
        return render_template(
            "error.html", message="The selected file is not supported"), 400

    # Save to a fixed temp path; the user-supplied filename is never used
    # as a filesystem path (avoids path traversal).
    uploaded.save("static/images/temp.jpg")

    im = Image.open("static/images/temp.jpg")
    if im.mode in ("RGBA", "P"):
        im = im.convert("RGB")
    new_width = 192
    new_height = 256
    # Image.LANCZOS replaces the deprecated Image.ANTIALIAS alias.
    im = im.resize((new_width, new_height), Image.LANCZOS)
    im.save("dataset/cloth_image/dress.jpg")

    # Load the input image and grab its dimensions.
    image = cv2.imread("dataset/cloth_image/dress.jpg")
    (H, W) = image.shape[:2]

    # Construct a blob out of the input image for the Holistically-Nested
    # Edge Detector.
    blob = cv2.dnn.blobFromImage(image, scalefactor=1.0, size=(W, H),
                                 mean=(104.00698793, 116.66876762,
                                       122.67891434),
                                 swapRB=False, crop=False)

    # Set the blob as the input to the network and perform a forward pass
    # to compute the edges (used as the cloth mask).
    print("[INFO] performing holistically-nested edge detection...")
    net.setInput(blob)
    hed = net.forward()
    hed = cv2.resize(hed[0, 0], (W, H))
    hed = (255 * hed).astype("uint8")
    cv2.imwrite("dataset/cloth_mask/dress_mask.png", hed)

    # Normalization transforms keyed by channel count ('3' = RGB, '1' = gray).
    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]

        augment['1'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])  # change to [C, H, W]

    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(val_dataset,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers,
                                batch_size=opt.batch_size_v,
                                pin_memory=True)

    # Parsing channels 5/6/7 are dropped before the parsing generator.
    # Hoisted out of the batch loop: this index tensor is loop-invariant.
    keep_channels = [c for c in range(20) if c not in (5, 6, 7)]
    keep_index = torch.tensor(keep_channels).cuda()

    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            # -- warped cloth ------------------------------------------
            warped_cloth = warped_image(gmm, result)
            if opt.warp_cloth:
                # Export-only mode: save the warped cloth and skip synthesis.
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth',
                                                 warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                # De-normalize from [-1, 1] to [0, 1] before saving.
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d' % i)
                continue
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float(
            ).cuda()
            source_image = result['source_image'].float().cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            # Single assignment (the original assigned cloth_parse twice).
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            # -- filter parsing, add cloth information -----------------
            real_s_ = torch.index_select(source_parse, 1, keep_index)
            input_parse = torch.cat(
                (real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            # -- P: parsing generator ----------------------------------
            generate_parse = generator_parsing(input_parse)  # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            # One-hot encode the argmax parsing over the 20 classes.
            generate_parse_argmax = torch.argmax(generate_parse,
                                                 dim=1,
                                                 keepdim=True).float()
            one_hot = [generate_parse_argmax == c for c in range(20)]
            generate_parse_argmax = torch.cat(one_hot, dim=1).float()

            # -- A: appearance generator -------------------------------
            image_without_cloth = create_part(source_image, source_parse,
                                              'image_without_cloth', False)
            input_app = torch.cat(
                (image_without_cloth, warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1)
            # torch.tanh/torch.sigmoid replace the deprecated
            # torch.nn.functional.tanh/sigmoid.
            p_rendered = torch.tanh(p_rendered)
            m_composite = torch.sigmoid(m_composite)
            p_tryon = warped_cloth * m_composite + \
                p_rendered * (1 - m_composite)
            refine_img = p_tryon

            # -- F: face refinement ------------------------------------
            generate_face = create_part(refine_img, generate_parse_argmax,
                                        'face', False)
            source_face_new = create_part(source_image, source_parse, 'face',
                                          False)
            input_face = torch.cat((source_face_new, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face',
                                    False)
            generate_img_without_face = refine_img - generate_face
            refine_img = fake_face + generate_img_without_face

            # -- parsing visualization ---------------------------------
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse,
                                                  dim=1,
                                                  keepdim=True).permute(
                                                      0, 2, 3, 1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(
                    generate_parse_vis)

            # -- save results ------------------------------------------
            images = [source_image, cloth_image, refine_img]
            pose_utils.save_img(images,
                                os.path.join(refine_path, '%d.jpg' % i))

    torch.cuda.empty_cache()

    return send_image('0.jpg')
# Example #5
def upload():
    """Flask endpoint: accept a cloth image upload and return a try-on result.

    This variant builds the cloth mask via the remove.bg service instead
    of an edge detector: the upload's background is removed, the alpha
    channel is thresholded into a binary mask, and both image and mask
    are resized to 192x256 before the GMM / parsing / appearance / face
    pipeline runs over the demo dataset.

    Relies on module-level globals: ``APP_ROOT``, ``request``, ``opt``,
    ``gmm``, ``generator_parsing``, ``generator_app_cpvton``,
    ``generator_face`` and ``refine_path``.

    Returns:
        The generated image response, or an error page with HTTP 400 when
        the uploaded file type is not supported.
    """
    target = os.path.join(APP_ROOT)

    # Create image directory if not found.
    if not os.path.isdir(target):
        os.mkdir(target)

    # Retrieve file from html file-picker (renamed so the local does not
    # shadow this view function's own name).
    uploaded = request.files.getlist("file")[0]
    print("File name: {}".format(uploaded.filename))
    filename = uploaded.filename

    # File support verification: only common raster formats are accepted.
    ext = os.path.splitext(filename)[1]
    if ext in (".jpg", ".jpeg", ".png", ".bmp"):
        print("File accepted")
    else:
        return render_template(
            "error.html", message="The selected file is not supported"), 400

    # Save to a fixed temp path; the user-supplied filename is never used
    # as a filesystem path (avoids path traversal).
    uploaded.save("static/images/temp.jpg")
    im = Image.open("static/images/temp.jpg")
    im.save("temp.jpg")
    # SECURITY(review): the remove.bg API key is hardcoded in the command
    # line below -- it is visible in `ps` and in version control. Move it
    # to configuration / an environment variable and rotate the key.
    subprocess.call(
        shlex.split('removebg --api-key  YeEiA6Sxr7ej1aznnERxguPc temp.jpg'))

    # Threshold the alpha channel of the background-removed image into a
    # binary cloth mask.
    im = cv2.imread("temp-removebg.png", cv2.IMREAD_UNCHANGED)
    _, mask = cv2.threshold(im[:, :, 3], 0, 255, cv2.THRESH_BINARY)
    cv2.imwrite("temp_mask.png", mask)

    im = Image.open("temp.jpg")
    if im.mode in ("RGBA", "P"):
        im = im.convert("RGB")
    new_width = 192
    new_height = 256
    # Image.LANCZOS replaces the deprecated Image.ANTIALIAS alias.
    im = im.resize((new_width, new_height), Image.LANCZOS)
    im.save("dataset/cloth_image/dress.jpg")

    # Resize the mask to the same 192x256 working resolution.
    im = Image.open("temp_mask.png")
    if im.mode in ("RGBA", "P"):
        im = im.convert("RGB")
    im = im.resize((new_width, new_height), Image.LANCZOS)
    im.save("dataset/cloth_mask/dress_mask.png")

    # Normalization transforms keyed by channel count ('3' = RGB, '1' = gray).
    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]

        augment['1'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])  # change to [C, H, W]

    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(val_dataset,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers,
                                batch_size=opt.batch_size_v,
                                pin_memory=True)

    # Parsing channels 5/6/7 are dropped before the parsing generator.
    # Hoisted out of the batch loop: this index tensor is loop-invariant.
    keep_channels = [c for c in range(20) if c not in (5, 6, 7)]
    keep_index = torch.tensor(keep_channels).cuda()

    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            # -- warped cloth ------------------------------------------
            warped_cloth = warped_image(gmm, result)
            if opt.warp_cloth:
                # Export-only mode: save the warped cloth and skip synthesis.
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth',
                                                 warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                # De-normalize from [-1, 1] to [0, 1] before saving.
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d' % i)
                continue
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float(
            ).cuda()
            source_image = result['source_image'].float().cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            # Single assignment (the original assigned cloth_parse twice).
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            # -- filter parsing, add cloth information -----------------
            real_s_ = torch.index_select(source_parse, 1, keep_index)
            input_parse = torch.cat(
                (real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            # -- P: parsing generator ----------------------------------
            generate_parse = generator_parsing(input_parse)  # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            # One-hot encode the argmax parsing over the 20 classes.
            generate_parse_argmax = torch.argmax(generate_parse,
                                                 dim=1,
                                                 keepdim=True).float()
            one_hot = [generate_parse_argmax == c for c in range(20)]
            generate_parse_argmax = torch.cat(one_hot, dim=1).float()

            # -- A: appearance generator -------------------------------
            image_without_cloth = create_part(source_image, source_parse,
                                              'image_without_cloth', False)
            input_app = torch.cat(
                (image_without_cloth, warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1)
            # torch.tanh/torch.sigmoid replace the deprecated
            # torch.nn.functional.tanh/sigmoid.
            p_rendered = torch.tanh(p_rendered)
            m_composite = torch.sigmoid(m_composite)
            p_tryon = warped_cloth * m_composite + p_rendered * (1 -
                                                                 m_composite)
            refine_img = p_tryon

            # -- F: face refinement ------------------------------------
            generate_face = create_part(refine_img, generate_parse_argmax,
                                        'face', False)
            source_face_new = create_part(source_image, source_parse, 'face',
                                          False)
            input_face = torch.cat((source_face_new, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face',
                                    False)
            generate_img_without_face = refine_img - generate_face

            # BUGFIX: the original added `source_face` (an unrefined crop of
            # the source image) back in, discarding the freshly computed
            # `fake_face`. Every other pipeline variant in this project
            # composites `fake_face`, so do the same here.
            refine_img = fake_face + generate_img_without_face

            # -- parsing visualization ---------------------------------
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse,
                                                  dim=1,
                                                  keepdim=True).permute(
                                                      0, 2, 3, 1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(
                    generate_parse_vis)

            # -- save results ------------------------------------------
            images = [source_image, cloth_image, refine_img]
            pose_utils.save_img(images,
                                os.path.join(refine_path, '%d.jpg' % i))

    torch.cuda.empty_cache()

    return send_image('0.jpg')