Example #1
    def forward(self,
                pred,
                target,
                target_parse,
                masksampled,
                gram,
                nearest,
                use_l1=True):

        # after reverse() the weights are [1, 1/4, 1/8, 1/16, 1/32]:
        # earlier (lower-level) VGG slices contribute most to the loss
        weight = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        weight.reverse()
        loss = 0
        # print(self.slice[0](pred).shape)
        if gram:
            loss_conv12 = self.lossmse(self.gram(self.slice[0](pred)),
                                       self.gram(self.slice[0](target)))
        elif nearest:
            loss_conv12 = self.nnloss(self.slice[0](pred),
                                      self.slice[0](target))
        else:
            loss_conv12 = self.loss(self.slice[0](pred), self.slice[0](target))
            # reference, predicted = self.loss(self.norm(self.slice[0](pred)), self.norm(self.slice[0](target)))
            # abs = torch.abs(reference - predicted)
            # # sum along channels
            # norms = torch.sum(abs, dim=1)
            # # min over neighbourhood
            # loss,_ = torch.min(norms, dim=-1)
            # # loss = torch.sum(loss)/self.batch_size
            # loss_conv12 = torch.mean(loss)

        for i in range(5):
            if not masksampled:
                if gram:
                    gram_pred = self.gram(self.slice[i](pred))
                    gram_target = self.gram(self.slice[i](target))
                else:
                    gram_pred = self.slice[i](pred)
                    gram_target = self.slice[i](target)
                if use_l1:
                    loss = loss + weight[i] * self.loss(gram_pred, gram_target)
                else:
                    loss = loss + weight[i] * self.lossmse(
                        gram_pred, gram_target)
            else:
                # mask both prediction and target down to the cloth region
                pred = create_part(pred, target_parse, 'cloth')
                target = create_part(target, target_parse, 'cloth')
                if gram:
                    gram_pred = self.gram(self.slice[i](pred))
                    gram_target = self.gram(self.slice[i](target))
                else:
                    gram_pred = self.slice[i](pred)
                    gram_target = self.slice[i](target)
                if use_l1:
                    loss = loss + weight[i] * self.loss(gram_pred, gram_target)
                else:
                    loss = loss + weight[i] * self.lossmse(
                        gram_pred, gram_target)
        return loss, loss_conv12
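
The loss above relies on a self.gram helper that the snippet does not show. A minimal sketch of a Gatys-style Gram matrix, which is what such style losses conventionally compute (the repo's exact normalization is an assumption):

import torch

def gram_matrix(feat):
    # feat: (B, C, H, W) feature map from one VGG slice
    b, c, h, w = feat.size()
    f = feat.view(b, c, h * w)
    # (B, C, C) channel-correlation matrix, normalized by layer size
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)
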
Example #2
    def forward(self, opt):
        self.t4 = time()

        if self.train_mode == 'gmm':
            self.grid, self.theta = self.gmm_model(self.agnostic, self.cloth_image)
            self.warped_cloth_predict = F.grid_sample(self.cloth_image, self.grid)

        if opt.train_mode == 'parsing':
            self.fake_t = F.softmax(self.generator_parsing(self.input_parsing), dim=1)
            self.real_t = self.target_parse
        
        if opt.train_mode == 'appearance':
            generated_inter = self.generator_appearance(self.input_appearance)
            p_rendered, m_composite = torch.split(generated_inter, 3, 1) 
            p_rendered = F.tanh(p_rendered)
            self.m_composite = F.sigmoid(m_composite)
            p_tryon = self.warped_cloth * self.m_composite + p_rendered * (1 - self.m_composite)
            self.fake_t = p_tryon
            self.real_t = self.target_image

            if opt.joint_all:

                generate_face = create_part(self.fake_t, self.generated_parsing_argmax, 'face', False)
                generate_image_without_face = self.fake_t - generate_face

                real_s_face = create_part(self.source_image, self.source_parse, 'face', False)
                real_t_face = create_part(self.target_image, self.generated_parsing_argmax, 'face', False)
                input = torch.cat((real_s_face, generate_face), dim=1)

                fake_t_face = self.generator_face(input)
                ###residual learning
                r"""attention
                """
                # fake_t_face = create_part(fake_t_face, self.generated_parsing, 'face', False)
                # fake_t_face = generate_face + fake_t_face
                fake_t_face = create_part(fake_t_face, self.generated_parsing_argmax, 'face', False)
                ### fake image
                self.fake_t = generate_image_without_face + fake_t_face

        if opt.train_mode == 'face':
            self.fake_t = self.generator_face(self.input_face)
            
            if opt.face_residual:
                self.fake_t = create_part(self.fake_t, self.generated_parsing_face, 'face', False)
                self.fake_t = self.target_face_fake + self.fake_t
            
            self.fake_t = create_part(self.fake_t, self.generated_parsing_face, 'face', False)
            self.refined_image = self.generated_image_without_face + self.fake_t
            self.real_t = create_part(self.target_image, self.generated_parsing_face, 'face', False)

        self.t5 = time()
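
create_part appears throughout these examples to keep only the pixels of one named body part. A hypothetical minimal version, assuming a 20-channel LIP-style one-hot parsing: the cloth channels 5/6/7 match the index filtering in the snippets, while the face channel set (and the omitted fourth flag the snippets pass) is an assumption:

import torch

# channel groups; only 'cloth' (5/6/7) is confirmed by the code above
PART_CHANNELS = {
    'cloth': [5, 6, 7],
    'face': [1, 2, 4, 13],  # hat, hair, sunglasses, face (assumed)
}

def create_part_sketch(image, parse, part):
    # image: (B, 3, H, W); parse: (B, 20, H, W) one-hot parsing map
    mask = parse[:, PART_CHANNELS[part]].sum(dim=1, keepdim=True)
    return image * mask  # zero out everything outside the part
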
Example #3
def forward(opt, paths, gpu_ids, refine_path):
    cudnn.enabled = True
    cudnn.benchmark = True
    opt.output_nc = 3
    opt.warp_cloth = False

    gmm = GMM(opt)
    gmm = torch.nn.DataParallel(gmm).cuda()

    # 'batch'
    generator_parsing = Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing,
                                 opt.ndf, opt.netG_parsing, opt.norm,
                                 not opt.no_dropout, opt.init_type,
                                 opt.init_gain, opt.gpu_ids)

    generator_app_cpvton = Define_G(opt.input_nc_G_app,
                                    opt.output_nc_app,
                                    opt.ndf,
                                    opt.netG_app,
                                    opt.norm,
                                    not opt.no_dropout,
                                    opt.init_type,
                                    opt.init_gain,
                                    opt.gpu_ids,
                                    with_tanh=False)

    generator_face = Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf,
                              opt.netG_face, opt.norm, not opt.no_dropout,
                              opt.init_type, opt.init_gain, opt.gpu_ids)

    models = [gmm, generator_parsing, generator_app_cpvton, generator_face]
    for model, path in zip(models, paths):
        load_model(model, path)
    print('==>loaded model')

    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]

        augment['1'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])  # change to [C, H, W]

    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(val_dataset,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers,
                                batch_size=opt.batch_size_v,
                                pin_memory=True)

    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            'warped cloth'
            warped_cloth = warped_image(gmm, result)
            if opt.warp_cloth:
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth',
                                                 warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d' % i)
                continue
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float(
            ).cuda()
            source_image = result['source_image'].float().cuda()
            cloth_parse = result['cloth_parse'].cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            "filter add cloth infomation"
            real_s = source_parse
            index = [
                x for x in list(range(20)) if x != 5 and x != 6 and x != 7
            ]
            real_s_ = torch.index_select(real_s, 1, torch.tensor(index).cuda())
            input_parse = torch.cat(
                (real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            'P'
            generate_parse = generator_parsing(input_parse)  # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            generate_parse_argmax = torch.argmax(generate_parse,
                                                 dim=1,
                                                 keepdim=True).float()
            res = []
            for index in range(20):
                res.append(generate_parse_argmax == index)
            generate_parse_argmax = torch.cat(res, dim=1).float()

            "A"
            image_without_cloth = create_part(source_image, source_parse,
                                              'image_without_cloth', False)
            input_app = torch.cat(
                (image_without_cloth, warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1)
            p_rendered = F.tanh(p_rendered)
            m_composite = F.sigmoid(m_composite)
            p_tryon = warped_cloth * m_composite + \
                p_rendered * (1 - m_composite)
            refine_img = p_tryon

            "F"
            generate_face = create_part(refine_img, generate_parse_argmax,
                                        'face', False)
            generate_img_without_face = refine_img - generate_face
            source_face = create_part(source_image, source_parse, 'face',
                                      False)
            input_face = torch.cat((source_face, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face',
                                    False)
            refine_img = generate_img_without_face + fake_face

            "generate parse vis"
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse,
                                                  dim=1,
                                                  keepdim=True).permute(
                                                      0, 2, 3, 1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(
                    generate_parse_vis)
            "save results"
            images = [
                source_image, cloth_image, target_pose_img, warped_cloth,
                source_parse_vis, generate_parse_vis, p_tryon, refine_img
            ]
            pose_utils.save_img(images,
                                os.path.join(refine_path, '%d.jpg') % (i))

    torch.cuda.empty_cache()
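
load_model is called above but not defined in the snippet. A plausible sketch, assuming a plain state-dict checkpoint and a possibly DataParallel-wrapped module:

import torch

def load_model(model, path):
    state_dict = torch.load(path, map_location='cpu')
    # unwrap DataParallel so parameter names line up either way
    target = model.module if hasattr(model, 'module') else model
    target.load_state_dict(state_dict)
    model.eval()
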
Example #4
    def set_input(self, opt, result):

        self.t2 = time()

        self.source_pose_embedding = result['source_pose_embedding'].float(
        ).cuda()
        self.target_pose_embedding = result['target_pose_embedding'].float(
        ).cuda()
        self.source_image = result['source_image'].float().cuda()
        self.target_image = result['target_image'].float().cuda()
        self.source_parse = result['source_parse'].float().cuda()
        self.target_parse = result['target_parse'].float().cuda()
        self.cloth_image = result['cloth_image'].float().cuda()
        self.cloth_parse = result['cloth_parse'].float().cuda()
        self.warped_cloth = result['warped_cloth_image'].float().cuda(
        )  # preprocess warped image from gmm model
        self.target_parse_cloth = result['target_parse_cloth'].float().cuda()
        self.target_pose_img = result['target_pose_img'].float().cuda()
        self.image_without_cloth = create_part(self.source_image,
                                               self.source_parse,
                                               'image_without_cloth', False)

        self.im_c = result['im_c'].float().cuda()  # target warped cloth

        index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]
        real_s_ = torch.index_select(self.source_parse, 1,
                                     torch.tensor(index).cuda())
        self.input_parsing = torch.cat(
            (real_s_, self.target_pose_embedding, self.cloth_parse), 1).cuda()

        if opt.train_mode == 'gmm':
            self.im_h = result['im_h'].float().cuda()
            self.source_parse_shape = result['source_parse_shape'].float(
            ).cuda()
            self.agnostic = torch.cat((self.source_parse_shape, self.im_h,
                                       self.target_pose_embedding),
                                      dim=1)

        elif opt.train_mode == 'parsing':
            self.real_s = self.input_parsing
            self.source_parse_vis = result['source_parse_vis'].float().cuda()
            self.target_parse_vis = result['target_parse_vis'].float().cuda()

        elif opt.train_mode == 'appearance':

            if opt.joint_all:
                self.generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                with torch.no_grad():
                    self.generated_parsing = F.softmax(
                        self.generator_parsing(self.input_parsing), 1)
            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()

            "attention please"
            generated_parsing_ = torch.argmax(self.generated_parsing,
                                              1,
                                              keepdim=True)
            self.generated_parsing_argmax = torch.Tensor()

            for _ in range(20):
                self.generated_parsing_argmax = torch.cat([
                    self.generated_parsing_argmax.float().cuda(),
                    (generated_parsing_ == _).float()
                ],
                                                          dim=1)
            self.warped_cloth_parse = (
                (generated_parsing_ == 5) + (generated_parsing_ == 6) +
                (generated_parsing_ == 7)).float().cuda()

            if opt.save_time:
                self.generated_parsing_vis = torch.Tensor([0]).expand_as(
                    self.target_image)
            else:
                # decode labels cost much time
                _generated_parsing = torch.argmax(self.generated_parsing,
                                                  1,
                                                  keepdim=True)
                _generated_parsing = _generated_parsing.permute(
                    0, 2, 3, 1).contiguous().int()
                self.generated_parsing_vis = pose_utils.decode_labels(
                    _generated_parsing)  #array

            self.real_s = self.source_image

        elif opt.train_mode == 'face':
            if opt.joint_all:  # opt.joint
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
                self.generated_parsing_face = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)

                "attention please"
                generated_parsing_ = torch.argmax(generated_parsing,
                                                  1,
                                                  keepdim=True)
                self.generated_parsing_argmax = torch.Tensor()

                for _ in range(20):
                    self.generated_parsing_argmax = torch.cat([
                        self.generated_parsing_argmax.float().cuda(),
                        (generated_parsing_ == _).float()
                    ],
                                                              dim=1)

                # self.generated_parsing_face = generated_parsing_c
                self.generated_parsing_face = self.target_parse

            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 generated_parsing), 1).cuda()

            with torch.no_grad():
                self.generated_inter = self.generator_appearance(
                    self.input_appearance)
                p_rendered, m_composite = torch.split(self.generated_inter, 3,
                                                      1)
                p_rendered = F.tanh(p_rendered)
                m_composite = F.sigmoid(m_composite)
                self.generated_image = self.warped_cloth * m_composite + p_rendered * (
                    1 - m_composite)

            self.source_face = create_part(self.source_image,
                                           self.source_parse, 'face', False)
            self.target_face_real = create_part(self.target_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.target_face_fake = create_part(self.generated_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.generated_image_without_face = self.generated_image - self.target_face_fake

            self.input_face = torch.cat(
                (self.source_face, self.target_face_fake), 1).cuda()
            self.real_s = self.source_face

        elif opt.train_mode == 'joint':
            self.input_joint = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()

        self.t3 = time()
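
The 20-iteration torch.cat loops above assemble a one-hot parsing tensor from the argmax one channel at a time. A behaviorally equivalent vectorized sketch using scatter_:

import torch

def onehot_from_argmax(label_map, num_classes=20):
    # label_map: (B, 1, H, W) integer class indices from torch.argmax
    b, _, h, w = label_map.shape
    onehot = torch.zeros(b, num_classes, h, w, device=label_map.device)
    return onehot.scatter_(1, label_map.long(), 1.0)
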
Example #5
    def set_input(self, opt, result):

        self.t2 = time()

        # Input data returned by dataloader
        self.source_pose_embedding = result['source_pose_embedding'].float(
        ).cuda()
        self.target_pose_embedding = result['target_pose_embedding'].float(
        ).cuda()
        self.source_densepose_data = result['source_densepose_data'].float(
        ).cuda()
        self.target_densepose_data = result['target_densepose_data'].float(
        ).cuda()
        self.source_image = result['source_image'].float().cuda()
        self.target_image = result['target_image'].float().cuda()
        self.source_parse = result['source_parse'].float().cuda()
        self.target_parse = result['target_parse'].float().cuda()
        self.cloth_image = result['cloth_image'].float().cuda()
        self.cloth_parse = result['cloth_parse'].float().cuda()
        # self.warped_cloth = result['warped_cloth_image'].float().cuda() # preprocess warped image from gmm model
        self.target_parse_cloth = result['target_parse_cloth'].float().cuda()
        self.target_pose_img = result['target_pose_img']
        self.image_without_cloth = create_part(self.source_image,
                                               self.source_parse,
                                               'image_without_cloth', False)
        self.im_c = result['im_c'].float().cuda()  # target warped cloth

        # input_parsing input to the parsing transformation network
        index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]
        real_s_ = torch.index_select(self.source_parse, 1,
                                     torch.tensor(index).cuda())
        self.input_parsing = torch.cat(
            (real_s_, self.target_densepose_data, self.cloth_parse), 1).cuda()

        if opt.train_mode != 'parsing' and opt.train_mode != 'gmm':
            self.warped_cloth = warped_image(self.gmm_model, result)

        ######################################
        # Part 1 GMM
        ######################################
        # For GMM training we need agnostic cloth_represent(source_head, densepose) original_cloth (from dataloader)
        if opt.train_mode == 'gmm':
            self.im_h = result['im_h'].float().cuda()
            self.source_parse_shape = result['source_parse_shape'].float(
            ).cuda()
            self.agnostic = torch.cat((self.source_parse_shape, self.im_h,
                                       self.target_pose_embedding),
                                      dim=1)

        ######################################
        # Part 2 PARSING
        ######################################
        # For parsing training
        # Input  input_parsing
        # output is the target parse
        elif opt.train_mode == 'parsing':
            self.real_s = self.input_parsing
            self.source_parse_vis = result['source_parse_vis'].float().cuda()
            self.target_parse_vis = result['target_parse_vis'].float().cuda()

        ######################################
        # Part 3 APPEARANCE
        ######################################
        # For appearance training
        # Input generated parse + warped_cloth + generated_parsing
        # Output corse render image(compare with target image) and composition mask (compare with warped_cloth_parse(this is generated from parsing network))

        elif opt.train_mode == 'appearance':

            # If join all training then train flow gradients else don't flow
            if opt.joint_all:
                self.generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                with torch.no_grad():
                    self.generated_parsing = F.softmax(
                        self.generator_parsing(self.input_parsing), 1)

            # Input to the generated appearance network
            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()
            "attention please"
            generated_parsing_ = torch.argmax(self.generated_parsing,
                                              1,
                                              keepdim=True)

            # input to the generator appearance
            self.generated_parsing_argmax = torch.Tensor()

            # create the warped_cloth_parse from the parsing network
            for _ in range(20):
                self.generated_parsing_argmax = torch.cat([
                    self.generated_parsing_argmax.float().cuda(),
                    (generated_parsing_ == _).float()
                ],
                                                          dim=1)
            self.warped_cloth_parse = (
                (generated_parsing_ == 5) + (generated_parsing_ == 6) +
                (generated_parsing_ == 7)).float().cuda()

            # For visualization
            if opt.save_time:
                self.generated_parsing_vis = torch.Tensor([0]).expand_as(
                    self.target_image)
            else:
                # decode labels cost much time
                _generated_parsing = torch.argmax(self.generated_parsing,
                                                  1,
                                                  keepdim=True)
                _generated_parsing = _generated_parsing.permute(
                    0, 2, 3, 1).contiguous().int()
                self.generated_parsing_vis = pose_utils.decode_labels(
                    _generated_parsing)  # array

            # For gan training
            self.real_s = self.source_image

        ######################################
        # Part 4 FACE
        ######################################

        elif opt.train_mode == 'face':
            if opt.joint_all:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
                self.generated_parsing_face = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)

                "attention please"
                generated_parsing_ = torch.argmax(generated_parsing,
                                                  1,
                                                  keepdim=True)
                self.generated_parsing_argmax = torch.Tensor()

                for _ in range(20):
                    self.generated_parsing_argmax = torch.cat([
                        self.generated_parsing_argmax.float().cuda(),
                        (generated_parsing_ == _).float()
                    ],
                                                              dim=1)

                # self.generated_parsing_face = generated_parsing_c
                self.generated_parsing_face = self.target_parse

            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 generated_parsing), 1).cuda()

            with torch.no_grad():
                self.generated_inter = self.generator_appearance(
                    self.input_appearance)
                p_rendered, m_composite = torch.split(self.generated_inter, 3,
                                                      1)
                p_rendered = F.tanh(p_rendered)
                m_composite = F.sigmoid(m_composite)
                self.generated_image = self.warped_cloth * \
                    m_composite + p_rendered * (1 - m_composite)

            self.source_face = create_part(self.source_image,
                                           self.source_parse, 'face', False)
            self.target_face_real = create_part(self.target_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.target_face_fake = create_part(self.generated_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.generated_image_without_face = self.generated_image - self.target_face_fake

            self.input_face = torch.cat(
                (self.source_face, self.target_face_fake), 1).cuda()
            self.real_s = self.source_face

        self.t3 = time()
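
F.tanh and F.sigmoid, used in every rendering step of these examples, are deprecated in favor of torch.tanh and torch.sigmoid (since PyTorch 0.4.1). A drop-in sketch of the render-and-composite step with the modern calls, same math as the snippets:

import torch

def compose_tryon(generated_inter, warped_cloth):
    # split the 4-channel generator output into an RGB render and a mask
    p_rendered, m_composite = torch.split(generated_inter, 3, dim=1)
    p_rendered = torch.tanh(p_rendered)
    m_composite = torch.sigmoid(m_composite)
    # blend the warped cloth over the render with the learned mask
    return warped_cloth * m_composite + p_rendered * (1 - m_composite)
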
Example #6
def upload():
    target = os.path.join(APP_ROOT)

    # create image directory if not found
    if not os.path.isdir(target):
        os.mkdir(target)

    # retrieve file from html file-picker
    upload = request.files.getlist("file")[0]
    print("File name: {}".format(upload.filename))
    filename = upload.filename

    # file support verification
    ext = os.path.splitext(filename)[1]
    if (ext == ".jpg") or (ext == ".jpeg") or (ext == ".png") or (ext == ".bmp"):
        print("File accepted")
    else:
        return render_template("error.html", message="The selected file is not supported"), 400

    # save file
    destination = "/".join([target, filename])
    upload.save("static/images/temp.jpg")    
    
    im = Image.open("static/images/temp.jpg")
    if im.mode in ("RGBA", "P"):
      im = im.convert("RGB")
    new_width = 192
    new_height = 256
    im = im.resize((new_width,new_height),Image.ANTIALIAS)
    im.save("dataset/cloth_image/dress.jpg")
    # load our serialized edge detector from disk (the HED net used
    # below; see the loader sketch after this example)
    # load the input image and grab its dimensions
    image = cv2.imread("dataset/cloth_image/dress.jpg")

    (H, W) = image.shape[:2]

    # construct a blob out of the input image for the Holistically-Nested
    # Edge Detector
    blob = cv2.dnn.blobFromImage(image, scalefactor=1.0, size=(W, H),
            mean=(104.00698793, 116.66876762, 122.67891434),
            swapRB=False, crop=False)

    # set the blob as the input to the network and perform a forward pass
    # to compute the edges
    print("[INFO] performing holistically-nested edge detection...")
    net.setInput(blob)
    hed = net.forward()
    hed = cv2.resize(hed[0, 0], (W, H))
    hed = (255 * hed).astype("uint8")

    # show the output edge detection results for Canny and
    # Holistically-Nested Edge Detection
    '''cv2.imshow("Input", image)
    cv2.imshow("Canny", canny)
    cv2.imshow("HED", hed)'''

    cv2.imwrite("dataset/cloth_mask/dress_mask.png",hed)
    
    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
                                    # transforms.Resize(256),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
            ]) # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
                                # transforms.Resize(256),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
        ]) # change to [C, H, W]

        augment['1'] = transforms.Compose([
                                # transforms.Resize(256),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
        ]) # change to [C, H, W]
    
    
    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(
                    val_dataset,
                    shuffle=False,
                    drop_last=False,
                    num_workers=opt.num_workers,
                    batch_size = opt.batch_size_v,
                    pin_memory=True)
    
    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            'warped cloth'
            warped_cloth = warped_image(gmm, result) 
            if opt.warp_cloth:
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d'%i)
                continue 
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float().cuda()
            source_image = result['source_image'].float().cuda()
            cloth_parse = result['cloth_parse'].cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            "filter add cloth infomation"
            real_s = source_parse   
            index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]
            real_s_ = torch.index_select(real_s, 1, torch.tensor(index).cuda())
            input_parse = torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda()
            
            'P'
            generate_parse = generator_parsing(input_parse) # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            generate_parse_argmax = torch.argmax(generate_parse, dim=1, keepdim=True).float()
            res = []
            for index in range(20):
                res.append(generate_parse_argmax == index)
            generate_parse_argmax = torch.cat(res, dim=1).float()

            "A"
            image_without_cloth = create_part(source_image, source_parse, 'image_without_cloth', False)
            input_app = torch.cat((image_without_cloth , warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1) 
            p_rendered = F.tanh(p_rendered)
            m_composite = F.sigmoid(m_composite)
            p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
            refine_img = p_tryon

            "F"
            generate_face = create_part(refine_img, generate_parse_argmax, 'face', False)
            source_face = create_part(source_image, generate_parse_argmax, 'face', False)  # unused below
            source_face_new = create_part(source_image, source_parse, 'face', False)
            input_face = torch.cat((source_face_new, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face', False)
            generate_img_without_face = refine_img - generate_face

            refine_img = fake_face + generate_img_without_face
            "generate parse vis"
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse, dim=1, keepdim=True).permute(0,2,3,1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(generate_parse_vis)
            "save results"
            images = [source_image, cloth_image, refine_img]
            pose_utils.save_img(images, os.path.join(refine_path, '%d.jpg')%(i))

    torch.cuda.empty_cache()

    #cv2.imwrite("static/images/temp.jpg", image_new)

    return send_image('0.jpg')
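
Example #6 forwards a blob through a global net that is never constructed in the snippet. A standard way to load the Holistically-Nested Edge Detection model with OpenCV's dnn module is sketched below; the file names are assumptions, and HED needs its Caffe Crop layer registered by hand:

import cv2

class CropLayer(object):
    # HED's Caffe graph uses a Crop layer OpenCV does not implement;
    # this center-crops the first input to the second input's size
    def __init__(self, params, blobs):
        self.startX = self.startY = self.endX = self.endY = 0

    def getMemoryShapes(self, inputs):
        input_shape, target_shape = inputs[0], inputs[1]
        batch, channels = input_shape[0], input_shape[1]
        height, width = target_shape[2], target_shape[3]
        self.startX = (input_shape[3] - width) // 2
        self.startY = (input_shape[2] - height) // 2
        self.endX = self.startX + width
        self.endY = self.startY + height
        return [[batch, channels, height, width]]

    def forward(self, inputs):
        return [inputs[0][:, :, self.startY:self.endY, self.startX:self.endX]]

cv2.dnn_registerLayer('Crop', CropLayer)
net = cv2.dnn.readNetFromCaffe('deploy.prototxt', 'hed_pretrained_bsds.caffemodel')
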
Example #7
def upload():
    target = os.path.join(APP_ROOT)

    # create image directory if not found
    if not os.path.isdir(target):
        os.mkdir(target)

    # retrieve file from html file-picker
    upload = request.files.getlist("file")[0]
    print("File name: {}".format(upload.filename))
    filename = upload.filename

    # file support verification
    ext = os.path.splitext(filename)[1]
    if (ext == ".jpg") or (ext == ".jpeg") or (ext == ".png") or (ext
                                                                  == ".bmp"):
        print("File accepted")
    else:
        return render_template(
            "error.html", message="The selected file is not supported"), 400

    # save file
    destination = "/".join([target, filename])
    upload.save("static/images/temp.jpg")
    im = Image.open("static/images/temp.jpg")
    im.save("temp.jpg")
    subprocess.call(
        shlex.split('removebg --api-key  YeEiA6Sxr7ej1aznnERxguPc temp.jpg'))

    im = cv2.imread("temp-removebg.png", cv2.IMREAD_UNCHANGED)
    ret, mask = cv2.threshold(im[:, :, 3], 0, 255, cv2.THRESH_BINARY)
    cv2.imwrite("temp_mask.png", mask)

    im = Image.open("temp.jpg")
    if im.mode in ("RGBA", "P"):
        im = im.convert("RGB")
    new_width = 192
    new_height = 256
    im = im.resize((new_width, new_height), Image.ANTIALIAS)
    im.save("dataset/cloth_image/dress.jpg")
    # load our serialized edge detector from disk

    # load the input image and grab its dimensions
    im = Image.open("temp_mask.png")
    if im.mode in ("RGBA", "P"):
        im = im.convert("RGB")
    new_width = 192
    new_height = 256
    im = im.resize((new_width, new_height), Image.ANTIALIAS)
    im.save("dataset/cloth_mask/dress_mask.png")

    augment = {}

    if '0.4' in torch.__version__:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]
        augment['1'] = augment['3']

    else:
        augment['3'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])  # change to [C, H, W]

        augment['1'] = transforms.Compose([
            # transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])  # change to [C, H, W]

    val_dataset = DemoDataset(opt, augment=augment)
    val_dataloader = DataLoader(val_dataset,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers,
                                batch_size=opt.batch_size_v,
                                pin_memory=True)

    with torch.no_grad():
        for i, result in enumerate(val_dataloader):
            'warped cloth'
            warped_cloth = warped_image(gmm, result)
            if opt.warp_cloth:
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth',
                                                 warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):
                    os.makedirs(os.path.split(warped_cloth_path)[0])
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d' % i)
                continue
            source_parse = result['source_parse'].float().cuda()
            target_pose_embedding = result['target_pose_embedding'].float(
            ).cuda()
            source_image = result['source_image'].float().cuda()
            cloth_parse = result['cloth_parse'].cuda()
            cloth_image = result['cloth_image'].cuda()
            target_pose_img = result['target_pose_img'].float().cuda()
            cloth_parse = result['cloth_parse'].float().cuda()
            source_parse_vis = result['source_parse_vis'].float().cuda()

            "filter add cloth infomation"
            real_s = source_parse
            index = [
                x for x in list(range(20)) if x != 5 and x != 6 and x != 7
            ]
            real_s_ = torch.index_select(real_s, 1, torch.tensor(index).cuda())
            input_parse = torch.cat(
                (real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            'P'
            generate_parse = generator_parsing(input_parse)  # tanh
            generate_parse = F.softmax(generate_parse, dim=1)

            generate_parse_argmax = torch.argmax(generate_parse,
                                                 dim=1,
                                                 keepdim=True).float()
            res = []
            for index in range(20):
                res.append(generate_parse_argmax == index)
            generate_parse_argmax = torch.cat(res, dim=1).float()

            "A"
            image_without_cloth = create_part(source_image, source_parse,
                                              'image_without_cloth', False)
            input_app = torch.cat(
                (image_without_cloth, warped_cloth, generate_parse), 1).cuda()
            generate_img = generator_app_cpvton(input_app)
            p_rendered, m_composite = torch.split(generate_img, 3, 1)
            p_rendered = F.tanh(p_rendered)
            m_composite = F.sigmoid(m_composite)
            p_tryon = warped_cloth * m_composite + p_rendered * (1 -
                                                                 m_composite)
            refine_img = p_tryon

            "F"
            generate_face = create_part(refine_img, generate_parse_argmax,
                                        'face', False)

            source_face = create_part(source_image, generate_parse_argmax,
                                      'face', False)
            source_face_new = create_part(source_image, source_parse, 'face',
                                          False)
            input_face = torch.cat((source_face_new, generate_face), 1)
            fake_face = generator_face(input_face)
            fake_face = create_part(fake_face, generate_parse_argmax, 'face',
                                    False)
            generate_img_without_face = refine_img - generate_face

            # note: the source face is composited back directly here;
            # fake_face computed above is left unused in this variant
            refine_img = source_face + generate_img_without_face
            "generate parse vis"
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse,
                                                  dim=1,
                                                  keepdim=True).permute(
                                                      0, 2, 3, 1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(
                    generate_parse_vis)
            "save results"
            images = [source_image, cloth_image, refine_img]
            pose_utils.save_img(images,
                                os.path.join(refine_path, '%d.jpg') % (i))

    torch.cuda.empty_cache()

    #cv2.imwrite("static/images/temp.jpg", image_new)

    return send_image('0.jpg')
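
send_image is not defined in either upload handler. A minimal Flask-style helper it plausibly maps to (the results directory name is an assumption):

from flask import send_from_directory

def send_image(filename):
    # serve a generated try-on result back to the browser
    return send_from_directory('results', filename)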