Exemplo n.º 1
0
class Demo:
    """Wrap a pretrained IPCGANs generator for single-image age translation."""

    def __init__(self, generator_state_pth):
        """Load generator weights from *generator_state_pth* into a fresh IPCGANs."""
        self.model = IPCGANs()
        state_dict = torch.load(generator_state_pth)
        self.model.load_generator_state_dict(state_dict)

    def demo(self, image, target=0):
        """Translate *image* (a PIL image) to age group *target*.

        Args:
            image: input PIL image; resized to 400x400 before inference.
            target: age-group index in [0, 5).

        Returns:
            The generated PIL image.

        Raises:
            ValueError: if *target* is outside [0, 5).
        """
        img_size = 400
        # Explicit exception instead of `assert`: asserts are stripped when
        # Python runs with -O, silently disabling the validation.
        if not 0 <= target < 5:
            raise ValueError("label should be less than 5")

        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize((img_size, img_size)),
            torchvision.transforms.ToTensor(),
            Img_to_zero_center()
        ])
        label_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        image = transforms(image).unsqueeze(0)
        # One-hot spatial condition map (H, W, 5): channel `target` all ones.
        full_one = np.ones((img_size, img_size), dtype=np.float32)
        full_zero = np.zeros((img_size, img_size, 5), dtype=np.float32)
        full_zero[:, :, target] = full_one
        label = label_transforms(full_zero).unsqueeze(0)

        # Only move to GPU when one is available; previously an unconditional
        # .cuda() crashed on CPU-only machines.
        if torch.cuda.is_available():
            img = image.cuda()
            lbl = label.cuda()
            self.model.cuda()
        else:
            img = image
            lbl = label

        res = self.model.test_generate(img, lbl)

        # Undo zero-centering, drop the batch dim, convert CHW -> HWC uint8.
        res = Reverse_zero_center()(res)
        res_img = res.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return Image.fromarray((res_img * 255).astype(np.uint8))
 # NOTE(review): orphaned duplicate of Demo.__init__ — it sits outside any
 # class at a 1-space indent (a bad paste/scrape artifact) and is a syntax
 # error in context. Confirm it is dead and delete it.
 def __init__(self, generator_state_pth):
     self.model = IPCGANs()
     state_dict = torch.load(generator_state_pth)
     self.model.load_generator_state_dict(state_dict)
class Demo:
    """Age-translation demo: align a face with MTCNN, then run IPCGANs."""

    def __init__(self, generator_state_pth):
        # Build the model and load pretrained generator weights from disk.
        self.model = IPCGANs()
        state_dict = torch.load(generator_state_pth)
        self.model.load_generator_state_dict(state_dict)

    def mtcnn_align(self, image):
        """Detect a face in *image* (PIL), warp it onto a canonical 5-point
        landmark template, and return it as a 400x400 PIL image.

        If no face is detected, the input image is returned unchanged; if
        several are detected, the one closest to the image centre is used.
        """
        dst = []
        # Canonical destination landmarks: left eye, right eye, nose, left
        # and right mouth corner. NOTE(review): these coordinates look like
        # the widely-used 96x112 face-alignment template — confirm they match
        # the crop size used during training.
        src = np.array(
            [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366],
             [33.5493, 92.3655], [62.7299, 92.2041]],
            dtype=np.float32)
        threshold = [0.6, 0.7, 0.9]  # per-stage MTCNN score thresholds
        factor = 0.85  # image-pyramid scale factor
        minSize = 20  # minimum detectable face size, in pixels
        imgSize = [120, 100]  # warped crop size as (height, width)
        detector = MTCNN(steps_threshold=threshold,
                         scale_factor=factor,
                         min_face_size=minSize)
        keypoint_list = [
            'left_eye', 'right_eye', 'nose', 'mouth_left', 'mouth_right'
        ]

        npimage = np.array(image)
        dictface_list = detector.detect_faces(
            npimage
        )  # if more than one face is detected, [0] means choose the first face

        if len(dictface_list) > 1:
            # Multiple faces: keep the one whose box centre is nearest the
            # image centre. MTCNN boxes are [x, y, w, h]; the *_y/*_x names
            # below are swapped relative to image rows/cols, but the stacking
            # order matches `center` (row, col), so distances are consistent.
            boxs = []
            for dictface in dictface_list:
                boxs.append(dictface['box'])
            center = np.array(npimage.shape[:2]) / 2
            boxs = np.array(boxs)
            face_center_y = boxs[:, 0] + boxs[:, 2] / 2
            face_center_x = boxs[:, 1] + boxs[:, 3] / 2
            face_center = np.column_stack(
                (np.array(face_center_x), np.array(face_center_y)))
            distance = np.sqrt(np.sum(np.square(face_center - center), axis=1))
            min_id = np.argmin(distance)
            dictface = dictface_list[min_id]
        else:
            if len(dictface_list) == 0:
                # No face found: fall back to the unaligned input image.
                return image
            else:
                dictface = dictface_list[0]
        face_keypoint = dictface['keypoints']
        for keypoint in keypoint_list:
            dst.append(face_keypoint[keypoint])
        dst = np.array(dst).astype(np.float32)
        # Estimate the similarity transform mapping the detected landmarks
        # (dst) onto the template (src); warp with its 2x3 affine part.
        tform = trans.SimilarityTransform()
        tform.estimate(dst, src)
        M = tform.params[0:2, :]
        warped = cv2.warpAffine(npimage,
                                M, (imgSize[1], imgSize[0]),
                                borderValue=0.0)
        warped = cv2.resize(warped, (400, 400))
        return Image.fromarray(warped.astype(np.uint8))

    def demo(self, image, target=0):
        """Align *image* and translate it to age group *target* (0-4).

        Returns the generated PIL image. Requires a CUDA device.
        """
        image = self.mtcnn_align(image)
        assert target < 5 and target >= 0, "label shoule be less than 5"

        # Network input: 128x128 zero-centered tensor with a batch dimension.
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.ToTensor(),
            Img_to_zero_center()
        ])
        label_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        image = transforms(image).unsqueeze(0)
        # One-hot spatial condition map (128, 128, 5): channel `target` = 1.
        full_one = np.ones((128, 128), dtype=np.float32)
        full_zero = np.zeros((128, 128, 5), dtype=np.float32)
        full_zero[:, :, target] = full_one
        label = label_transforms(full_zero).unsqueeze(0)

        img = image.cuda()
        lbl = label.cuda()
        self.model.cuda()

        res = self.model.test_generate(img, lbl)

        # Undo zero-centering, drop the batch dim, convert CHW -> HWC uint8.
        res = Reverse_zero_center()(res)
        res_img = res.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return Image.fromarray((res_img * 255).astype(np.uint8))
Exemplo n.º 4
0
def main():
    """Train IPCGANs on CACD, resuming from ``model_last.pth``.

    Relies on module-level ``args``, ``logger`` and ``writer``. Per step it
    alternates ``args.d_iter`` discriminator updates with ``args.g_iter``
    generator updates, logs losses and parameter histograms to TensorBoard,
    checkpoints every ``args.save_interval`` steps and writes validation
    images every ``args.val_interval`` steps.
    """
    logger.info("Start to train:\n arguments: %s" % str(args))
    # step3: define transforms (image -> zero-centered tensor, label -> tensor)
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    # step4: define train/test dataloaders
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model and optimizers; resume generator from last save
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight)
    #,feature_extractor_path=args.feature_extractor_path)
    model.load_generator_state_dict_custom(model_path="model_last.pth")
    d_optim = model.d_optim
    g_optim = model.g_optim

    for epoch in range(args.max_epoches):
        # BUG FIX: the tqdm/enumerate wrapper is a one-shot iterator. It was
        # previously created once before the epoch loop, so every epoch after
        # the first iterated over nothing. Recreate it per epoch.
        samples_tqdm = tqdm(enumerate(train_loader, 1), position=0, leave=True)
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in samples_tqdm:

            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * len(train_loader) + idx

            # move the whole batch to the GPU
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()

            # train discriminator
            for d_iter in range(args.d_iter):
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                running_d_loss = d_loss
                d_loss.backward()
                d_optim.step()

            # visualize discriminator parameters
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            # train generator
            for g_iter in range(args.g_iter):
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                running_g_loss = g_loss
                g_loss.backward()
                g_optim.step()

            # visualize generator parameters
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            format_str = 'step %d/%d, g_loss = %.3f, d_loss = %.3f'
            samples_tqdm.set_description(
                format_str %
                (idx, len(train_loader), running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                model.save_model(dir="", filename="model_last.pth")
                logger.info('checkpoint has been created!')

            # validation: dump the source batch plus one generated image per
            # age group under saved_validation_folder/epoch_X/idx_Y/
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128,
                              true_label_128) in enumerate(tqdm(test_loader)):
                    save_image(Reverse_zero_center()(source_img_128),
                               fp=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % (val_idx)))

                    # (dead `pic_list` accumulator removed: it was built but
                    # never read)
                    for age in range(args.age_groups):
                        img = model.test_generate(source_img_128,
                                                  true_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   fp=os.path.join(
                                       save_dir, "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                logger.info('validation image has been created!')
Exemplo n.º 5
0
def main(args, logger, writer):
    """Train IPCGANs on CACD, posting progress to a Slack-style webhook.

    Args:
        args: parsed CLI arguments (hyper-parameters, paths, intervals).
        logger: application logger.
        writer: TensorBoard summary writer.

    Per step it alternates ``args.d_iter`` discriminator updates with
    ``args.g_iter`` generator updates; checkpoints every
    ``args.save_interval`` steps, writes validation images every
    ``args.val_interval`` steps, and posts average epoch losses to
    ``webhook_url`` at the end of each epoch.
    """
    logger.info("Start to train:\n arguments: %s" % str(args))
    content = "[202.*.*.150.] [INFO] Start to train - IPCGANs "
    payload = {"text": content}
    requests.post(webhook_url,
                  data=json.dumps(payload),
                  headers={'Content-Type': 'application/json'})

    # step3: define transforms (image -> zero-centered tensor, label -> tensor)
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    # step4: define train/test dataloaders
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # step5: define model and optimizers; resume from explicit checkpoints
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight,
                    generator_path=args.g_checkpoint,
                    discriminator_path=args.d_checkpoint)
    #,feature_extractor_path=args.feature_extractor_path)
    d_optim = model.d_optim
    g_optim = model.g_optim

    for epoch in range(args.max_epoches):
        avr_d_loss = 0.0
        avr_g_loss = 0.0
        count = 0
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in enumerate(train_loader, 1):

            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * len(train_loader) + idx

            # move the whole batch to the GPU
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()

            # train discriminator
            for d_iter in range(args.d_iter):
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                running_d_loss = d_loss
                d_loss.backward()
                d_optim.step()

            # visualize discriminator parameters
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            # train generator
            for g_iter in range(args.g_iter):
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                running_g_loss = g_loss
                g_loss.backward()
                g_optim.step()

            # visualize generator parameters
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)

            format_str = 'step %d/%d, g_loss = %.3f, d_loss = %.3f'
            logger.info(
                format_str %
                (idx, len(train_loader), running_g_loss, running_d_loss))

            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)
            # Accumulate plain floats: summing the loss tensors directly kept
            # their autograd graphs alive for the whole epoch (memory growth)
            # and made str(avr_*_loss) print a tensor repr in the webhook.
            avr_g_loss += running_g_loss.item()
            avr_d_loss += running_d_loss.item()
            count += 1

            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                logger.info('checkpoint has been created!')

            # validation: dump the source batch plus one generated image per
            # age group under saved_validation_folder/epoch_X/idx_Y/
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128,
                              true_label_128) in enumerate(tqdm(test_loader)):
                    # `fp=` matches both the sibling training script and the
                    # torchvision.utils.save_image parameter name (the old
                    # `filename=` keyword was inconsistent with it).
                    save_image(Reverse_zero_center()(source_img_128),
                               fp=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % (val_idx)))

                    # (dead `pic_list` accumulator and commented-out
                    # `post_image` call removed — the latter embedded a
                    # hard-coded API token, which must not live in source;
                    # rotate that credential.)
                    for age in range(args.age_groups):
                        img = model.test_generate(source_img_128,
                                                  true_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   fp=os.path.join(
                                       save_dir, "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                logger.info('validation image has been created!')
        avr_d_loss = avr_d_loss / count
        avr_g_loss = avr_g_loss / count

        # post average epoch losses to the webhook
        content = "[202.*.*.150.]  [INFO] Epoch End : " + str(
            epoch) + ", d_loss : " + str(avr_d_loss) + ", g_loss : " + str(
                avr_g_loss)
        payload = {"text": content}
        requests.post(webhook_url,
                      data=json.dumps(payload),
                      headers={'Content-Type': 'application/json'})