示例#1
0
    def onInitialize(self):
        """Create the pose estimator and, in training mode, the sample generators.

        The source generator emits three samples per item (a warped image with
        motion blur, the matching transformed image, and the pitch/yaw/roll
        target); the destination generator emits a plain BGR image and its
        pitch/yaw/roll target.
        """
        exec(nnlib.import_all(), locals(), globals())
        self.set_vram_batch_requirements({4: 64})

        self.resolution = 128
        is_full_face = self.options['face_type'] == 'f'
        self.face_type = FaceType.FULL if is_full_face else FaceType.HALF

        self.pose_est = PoseEstimator(self.resolution,
                                      FaceType.toString(self.face_type),
                                      load_weights=not self.is_first_run(),
                                      weights_file_root=self.get_model_root_path(),
                                      training=True)

        if not self.is_training_mode:
            return

        t = SampleProcessor.Types
        sample_face_type = t.FACE_TYPE_FULL if is_full_face else t.FACE_TYPE_HALF
        res = self.resolution

        def build_generator(data_path, sample_types):
            # rotation_range=[0, 0] keeps rotation augmentation off;
            # random_flip is deliberately left disabled.
            return SampleGeneratorFace(data_path,
                                       debug=self.is_debug(),
                                       batch_size=self.batch_size,
                                       generators_count=4,
                                       sample_process_options=SampleProcessor.Options(rotation_range=[0, 0]),
                                       output_sample_types=sample_types)

        src_sample_types = [
            {'types': (t.IMG_WARPED_TRANSFORMED, sample_face_type, t.MODE_BGR_SHUFFLE),
             'resolution': res,
             'motion_blur': (25, 1)},
            {'types': (t.IMG_TRANSFORMED, sample_face_type, t.MODE_BGR_SHUFFLE),
             'resolution': res},
            {'types': (t.IMG_PITCH_YAW_ROLL,)},
        ]
        dst_sample_types = [
            {'types': (t.IMG_TRANSFORMED, sample_face_type, t.MODE_BGR),
             'resolution': res},
            {'types': (t.IMG_PITCH_YAW_ROLL,)},
        ]

        self.set_training_data_generators([
            build_generator(self.training_data_src_path, src_sample_types),
            build_generator(self.training_data_dst_path, dst_sample_types),
        ])
示例#2
0
class Model(ModelBase):
    """Trains a PoseEstimator to predict pitch/yaw/roll from face images.

    Only the source sample set drives training (see ``onTrainOneIter``); the
    destination set is sampled for the "evaluating data" preview.
    """

    def __init__(self, *args, **kwargs):
        # Suppress the ModelBase prompts that do not apply to this model.
        super().__init__(*args,
                         **kwargs,
                         ask_write_preview_history=False,
                         ask_target_iter=False,
                         ask_sort_by_yaw=False,
                         ask_random_flip=False,
                         ask_src_scale_mod=False)

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        """Prompt for (first run) or restore the 'face_type' option ('h' or 'f')."""
        default_face_type = 'f'
        if not is_first_run:
            # Reuse the stored choice, falling back to the default when absent.
            self.options['face_type'] = self.options.get('face_type', default_face_type)
            return

        answer = io.input_str(
            "Half or Full face? (h/f, ?:help skip:f) : ",
            default_face_type, ['h', 'f'],
            help_message="Half face has better resolution, but covers less area of cheeks.")
        self.options['face_type'] = answer.lower()

    #override
    def onInitialize(self):
        """Build the PoseEstimator and, in training mode, the src/dst generators."""
        exec(nnlib.import_all(), locals(), globals())
        self.set_vram_batch_requirements({4: 32})

        self.resolution = 128
        is_full_face = self.options['face_type'] == 'f'
        self.face_type = FaceType.FULL if is_full_face else FaceType.HALF

        self.pose_est = PoseEstimator(self.resolution,
                                      FaceType.toString(self.face_type),
                                      load_weights=not self.is_first_run(),
                                      weights_file_root=self.get_model_root_path(),
                                      training=True)

        if not self.is_training_mode:
            return

        t = SampleProcessor.Types
        sample_face_type = t.FACE_TYPE_FULL if is_full_face else t.FACE_TYPE_HALF
        res = self.resolution

        def build_generator(data_path, sample_types):
            # rotation_range=[0, 0] keeps rotation augmentation off;
            # random_flip is deliberately left disabled.
            return SampleGeneratorFace(data_path,
                                       debug=self.is_debug(),
                                       batch_size=self.batch_size,
                                       generators_count=4,
                                       sample_process_options=SampleProcessor.Options(rotation_range=[0, 0]),
                                       output_sample_types=sample_types)

        src_sample_types = [
            {'types': (t.IMG_TRANSFORMED, sample_face_type, t.MODE_BGR_SHUFFLE),
             'resolution': res,
             'motion_blur': (25, 1)},
            {'types': (t.IMG_PITCH_YAW_ROLL, )},
        ]
        dst_sample_types = [
            {'types': (t.IMG_TRANSFORMED, sample_face_type, t.MODE_BGR_SHUFFLE),
             'resolution': res},
            {'types': (t.IMG_PITCH_YAW_ROLL, )},
        ]

        self.set_training_data_generators([
            build_generator(self.training_data_src_path, src_sample_types),
            build_generator(self.training_data_dst_path, dst_sample_types),
        ])

    #override
    def onSave(self):
        """Persist the pose estimator's weights."""
        self.pose_est.save_weights()

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        """Train on one source batch and report the pitch/yaw/roll losses."""
        images, angles = generators_samples[0]

        pitch, yaw, roll = self.pose_est.train_on_batch(images, angles)

        return (('pitch_loss', pitch),
                ('yaw_loss', yaw),
                ('roll_loss', roll))

    #override
    def onGetPreview(self, generators_samples):
        """Render up to 4 samples per generator next to true and predicted angles.

        Each preview row is the image followed by a dark panel whose first text
        line is the ground-truth angles and whose second is the prediction.
        """
        src_imgs = generators_samples[0][0][0:4]
        src_pyr = generators_samples[0][1][0:4]
        dst_imgs = generators_samples[1][0][0:4]
        dst_pyr = generators_samples[1][1][0:4]

        h = w = self.resolution
        c = 3
        h_line = 13  # pixel height of one caption line

        previews = []
        for title, imgs, pyr_true in (('training data', src_imgs, src_pyr),
                                      ('evaluating data', dst_imgs, dst_pyr)):
            pyr_pred = self.pose_est.extract(imgs)

            rows = []
            for i, face in enumerate(imgs):
                panel = np.ones((h, w, c)) * 0.1
                captions = [str(pyr_true[i]), str(pyr_pred[i])]

                for line_no, caption in enumerate(captions):
                    top = line_no * h_line
                    panel[top:top + h_line, 0:w] += imagelib.get_text_image(
                        (h_line, w, c), caption, color=[0.8] * c)

                rows.append(np.concatenate((face[:, :, 0:3], panel), axis=1))

            previews.append((title, np.concatenate(rows, axis=0)))

        return previews
    def onInitialize(self, batch_size=-1, **in_options):
        """Build a pose-conditioned two-domain (A <-> B) GAN and its K.functions.

        Creates a shared encoder with two decoders (GA -> domain B, GB -> domain
        A), one discriminator per domain, and compiles:
        ``GA_train``, ``GB_train``, ``DA_train`` and ``DB_train`` (all taking
        ``[real_A0, real_A0m, real_B0, real_B0m]``), plus ``G_view`` for
        previews. In training mode it also registers the src/dst sample
        generators; otherwise it builds ``G_convert``.

        ``batch_size`` and ``in_options`` are accepted for interface
        compatibility but are not used here.

        NOTE(review): this method sits at method indentation directly after
        ``Model.onGetPreview`` with no class header in between, so as laid out
        it would replace ``Model.onInitialize``. It reads
        ``self.options['resolution']`` and ``AVATARModel.*``, which suggests it
        belongs to a different model class -- confirm where it should live.
        """
        exec(nnlib.code_import_all, locals(), globals())
        self.set_vram_batch_requirements({4: 4})

        resolution = self.options['resolution']
        bgr_shape = (resolution, resolution, 3)
        mask_shape = (resolution, resolution, 1)

        ndf = 64          # base filter count of the discriminators
        lambda_A = 100    # weight of the pose/cycle terms in loss_GA
        lambda_B = 100    # weight of the pose/cycle terms in loss_GB

        use_batch_norm = True

        poseest = self.poseest = PoseEstimator(
            resolution, FaceType.toString(FaceType.FULL))

        # Shared encoder: image + three pose inputs.
        # NOTE(review): all three pose inputs are sized class_nums[0]; confirm
        # pitch/yaw/roll really share one class count.
        self.enc = modelify(AVATARModel.DFEncFlow())([
            Input(bgr_shape),
            Input((poseest.class_nums[0], )),
            Input((poseest.class_nums[0], )),
            Input((poseest.class_nums[0], ))
        ])
        dec_Inputs = [Input(K.int_shape(x)[1:]) for x in self.enc.outputs]
        self.decA = modelify(AVATARModel.DFDecFlow(bgr_shape[2]))(dec_Inputs)
        self.decB = modelify(AVATARModel.DFDecFlow(bgr_shape[2]))(dec_Inputs)

        def GA(x):
            return self.decA(self.enc(x))

        self.GA = GA

        def GB(x):
            return self.decB(self.enc(x))

        self.GB = GB

        self.DA = modelify(
            AVATARModel.NLayerDiscriminator(use_batch_norm,
                                            ndf=ndf))(Input(bgr_shape))
        self.DB = modelify(
            AVATARModel.NLayerDiscriminator(use_batch_norm,
                                            ndf=ndf))(Input(bgr_shape))

        if not self.is_first_run():
            weights_to_load = [
                (self.enc, 'enc.h5'),
                (self.decA, 'decA.h5'),
                (self.decB, 'decB.h5'),
                (self.DA, 'DA.h5'),
                (self.DB, 'DB.h5'),
            ]
            self.load_weights_safe(weights_to_load)

        real_A0 = Input(bgr_shape)
        real_A0m = Input(mask_shape)
        real_B0 = Input(bgr_shape)
        real_B0m = Input(mask_shape)

        real_A0p, real_A0y, real_A0r = poseest.flow(real_A0)
        real_B0p, real_B0y, real_B0r = poseest.flow(real_B0)

        def DLoss(labels, logits):
            return K.mean(K.binary_crossentropy(labels, logits))

        def CycleLOSS(t1, t2):
            # Samples are generated with normalize_tanh=True (values in [-1, 1]);
            # shift into [0, 2] to match max_value=2.0 for DSSIM.
            return dssim(kernel_size=int(resolution / 11.6),
                         max_value=2.0)(t1 + 1, t2 + 1)

        fake_B0 = self.GA([real_A0, real_B0p, real_B0y, real_B0r])
        fake_A0 = self.GB([real_B0, real_A0p, real_A0y, real_A0r])

        fake_B0p, fake_B0y, fake_B0r = poseest.flow(fake_B0)
        fake_A0p, fake_A0y, fake_A0r = poseest.flow(fake_A0)

        real_A0_d = self.DA(real_A0)
        real_A0_d_ones = K.ones_like(real_A0_d)

        fake_A0_d = self.DA(fake_A0)
        fake_A0_d_ones = K.ones_like(fake_A0_d)
        fake_A0_d_zeros = K.zeros_like(fake_A0_d)

        real_B0_d = self.DB(real_B0)
        real_B0_d_ones = K.ones_like(real_B0_d)

        fake_B0_d = self.DB(fake_B0)
        fake_B0_d_ones = K.ones_like(fake_B0_d)
        fake_B0_d_zeros = K.zeros_like(fake_B0_d)

        rec_A0 = self.GB([fake_B0, real_A0p, real_A0y, real_A0r])
        rec_B0 = self.GA([fake_A0, real_B0p, real_B0y, real_B0r])

        # NOTE(review): fake_B0 is generated with B's pose (real_B0p/y/r), yet
        # its estimated pose is compared against A's pose here (and fake_A0
        # against B's pose in loss_GB). Confirm which pose target is intended.
        loss_GA = DLoss(fake_B0_d_ones, fake_B0_d) + \
                  lambda_A * 0.1 * K.mean(K.square(fake_B0p - real_A0p) + K.square(fake_B0y - real_A0y) + K.square(fake_B0r - real_A0r)) + \
                  lambda_A * CycleLOSS(rec_B0, real_B0)

        weights_GA = self.enc.trainable_weights + self.decA.trainable_weights

        loss_GB = DLoss(fake_A0_d_ones, fake_A0_d) + \
                  lambda_B * 0.1 * K.mean(K.square(fake_A0p - real_B0p) + K.square(fake_A0y - real_B0y) + K.square(fake_A0r - real_B0r)) + \
                  lambda_B * CycleLOSS(rec_A0, real_A0)

        weights_GB = self.enc.trainable_weights + self.decB.trainable_weights

        def opt():
            # A fresh optimizer per trained sub-network.
            return Adam(lr=2e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)

        # All train functions share one input signature; the mask inputs are
        # accepted but not referenced by any loss.
        train_inputs = [real_A0, real_A0m, real_B0, real_B0m]

        self.GA_train = K.function(train_inputs, [loss_GA],
                                   opt().get_updates(loss_GA, weights_GA))

        self.GB_train = K.function(train_inputs, [loss_GB],
                                   opt().get_updates(loss_GB, weights_GB))

        loss_D_A = (DLoss(real_A0_d_ones, real_A0_d) +
                    DLoss(fake_A0_d_zeros, fake_A0_d)) * 0.5

        self.DA_train = K.function(
            train_inputs, [loss_D_A],
            opt().get_updates(loss_D_A, self.DA.trainable_weights))

        loss_D_B = (DLoss(real_B0_d_ones, real_B0_d) +
                    DLoss(fake_B0_d_zeros, fake_B0_d)) * 0.5

        self.DB_train = K.function(
            train_inputs, [loss_D_B],
            opt().get_updates(loss_D_B, self.DB.trainable_weights))

        self.G_view = K.function(train_inputs,
                                 [fake_A0, rec_A0, fake_B0, rec_B0])

        if self.is_training_mode:
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL

            output_sample_types = [
                {
                    'types': (t.IMG_SOURCE, face_type, t.MODE_BGR),
                    'resolution': resolution
                },
                {
                    'types':
                    (t.IMG_SOURCE, face_type, t.MODE_M, t.FACE_MASK_FULL),
                    'resolution': resolution
                },
            ]

            self.set_training_data_generators([
                SampleGeneratorFace(
                    self.training_data_src_path,
                    debug=self.is_debug(),
                    batch_size=self.batch_size,
                    sample_process_options=SampleProcessor.Options(
                        random_flip=self.random_flip, normalize_tanh=True),
                    output_sample_types=output_sample_types),
                SampleGeneratorFace(
                    self.training_data_dst_path,
                    debug=self.is_debug(),
                    batch_size=self.batch_size,
                    sample_process_options=SampleProcessor.Options(
                        random_flip=self.random_flip, normalize_tanh=True),
                    output_sample_types=output_sample_types)
            ])
        else:
            # fake_B0 takes its target pose from poseest.flow(real_B0), so
            # real_B0 must be an input of the function. real_B0m is not part
            # of fake_B0's graph, and feeding it would leave real_B0 unfed.
            self.G_convert = K.function([real_A0, real_B0], [fake_B0])