Ejemplo n.º 1
0
    def build_model():
        """Build and compile the two-stage IUV -> heatmap Hourglass model.

        A first Hourglass predicts a 6-channel IUV map from the RGB image; the
        image and that prediction are concatenated and fed to a second
        Hourglass predicting 68 heatmap channels.

        Returns:
            The compiled ``dm.DeepMachine`` with outputs
            ``[iuv_prediction, hm_prediction]``.
        """
        image_in = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                   name='input_image')

        # Both stages share the same Hourglass configuration.
        hourglass_kwargs = dict(depth=4, batch_norm=True, use_coordconv=False)

        # Stage 1: IUV regression from the raw image.
        iuv_out = dm.networks.Hourglass(
            image_in, [INPUT_SHAPE, INPUT_SHAPE, 6], **hourglass_kwargs)

        # Stage 2: heatmaps conditioned on the image and the IUV estimate.
        stage2_in = dm.layers.Concatenate()([image_in, iuv_out])
        hm_out = dm.networks.Hourglass(
            stage2_in, [INPUT_SHAPE, INPUT_SHAPE, 68], **hourglass_kwargs)

        model = dm.DeepMachine(inputs=image_in, outputs=[iuv_out, hm_out])
        model.compile(
            optimizer=dm.optimizers.Adam(lr=LR),
            loss=[dm.losses.loss_iuv_regression,
                  dm.losses.loss_heatmap_regression],
        )
        return model
Ejemplo n.º 2
0
    def build_model():
        """Build and compile a variational autoencoder over 6-channel images.

        Returns:
            The compiled ``dm.DeepMachine`` mapping ``input_image`` to its
            reconstruction, trained on mean reconstruction MSE plus a KL term.
        """
        image_in = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 6],
                                   name='input_image')
        reconstruction, (encoder, _decoder) = dm.networks.VAE(
            image_in, nf=64, depth=4, embedding=1024, latent=512,
            return_models=True)
        # The encoder has three outputs; only the first two enter the KL term.
        z_mean, z_log_var, _unused = encoder.outputs

        model = dm.DeepMachine(inputs=[image_in], outputs=[reconstruction])

        def vae_loss(y_true, y_pred):
            # Mean reconstruction error plus KL divergence of the posterior.
            recon_term = dm.K.mean(dm.losses.mse(y_true, y_pred))
            kl_term = dm.losses.loss_kl(z_mean, z_log_var)
            return recon_term + kl_term

        model.compile(optimizer=dm.optimizers.Adam(lr=LR), loss=vae_loss)
        return model
Ejemplo n.º 3
0
    def model_builder():
        """Build and compile the two-stage IUV + keypoint-heatmap model.

        Stage 1 predicts a 75-channel IUV map and a 17-channel heatmap from the
        RGB image independently. Stage 2 refines both from the concatenation
        of the image and the two stage-1 predictions.

        Returns:
            The compiled ``dm.DeepMachine`` with outputs
            ``[iuv, heatmap, iuv_refined, heatmap_refined]``.
        """
        input_image = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                      name='input_image')

        def hourglass(inputs, n_channels):
            # Output spatial size follows INPUT_SHAPE; it was hard-coded to 256,
            # which mismatched the input for any other input size.
            return dm.networks.Hourglass(
                inputs, [INPUT_SHAPE, INPUT_SHAPE, n_channels],
                nf=64,
                batch_norm='InstanceNormalization2D')

        # Stage 1: independent initial predictions.
        iuv_prediction = hourglass(input_image, 75)
        hm_prediction = hourglass(input_image, 17)

        # Stage 2: refine both from the image plus the stage-1 estimates.
        merged_inputs = dm.layers.Concatenate()(
            [input_image, hm_prediction, iuv_prediction])
        iuv_prediction_refine = hourglass(merged_inputs, 75)
        hm_prediction_refine = hourglass(merged_inputs, 17)

        train_model = dm.DeepMachine(inputs=input_image,
                                     outputs=[
                                         iuv_prediction, hm_prediction,
                                         iuv_prediction_refine,
                                         hm_prediction_refine
                                     ])
        train_model.compile(optimizer=dm.optimizers.Adam(lr=LR),
                            loss=[
                                dm.losses.loss_iuv_regression,
                                dm.losses.loss_heatmap_regression,
                                dm.losses.loss_iuv_regression,
                                dm.losses.loss_heatmap_regression
                            ],
                            loss_weights=[1, 1, 1, 1])

        return train_model
Ejemplo n.º 4
0
    def build_model():
        """Build and compile the ArcFace recognition model.

        Returns:
            The compiled model with outputs ``[embedding, class_scores]``; the
            embedding is trained with ``dm.losses.dummy`` and the class scores
            with the ArcFace margin loss below. Wrapped in
            ``multi_gpu_model`` when more than one GPU is listed in
            ``FLAGS.gpu``.
        """
        input_image = dm.layers.Input(
            shape=[INPUT_SHAPE, INPUT_SHAPE, INPUT_CHANNELS],
            name='input_image')

        embeding, softmax = dm.networks.ArcFace(
            [input_image],
            512,
            nf=NF,
            n_classes=N_CLASSES,
            batch_norm='BatchNormalization')

        train_model = dm.DeepMachine(inputs=[input_image],
                                     outputs=[embeding, softmax])

        n_gpu = len(FLAGS.gpu.split(','))
        if n_gpu > 1:
            train_model = multi_gpu_model(train_model, gpus=n_gpu)

        def arc_loss(y_true, y_pred, s=64., m1=1., m2=0.3, m3=0.):
            """Combined-margin ArcFace loss: s * (cos(m1*theta + m2) - m3).

            Assumes ``y_pred`` holds per-class cosine similarities and
            ``y_true`` is one-hot (NOTE(review): confirm against
            ``dm.networks.ArcFace``). The margin is applied to the target
            class only; non-target classes keep their plain cosine. The
            previous version multiplied ``y_pred`` by ``y_true`` first, which
            turned every non-target logit into the constant
            ``s * (cos(pi/2 * m1 + m2) - m3)``.
            """
            # Keep acos inside its domain: its gradient is infinite at +-1 and
            # round-off can push cosines slightly past it (NaN).
            eps = 1e-7
            cos_theta = tf.clip_by_value(y_pred, -1. + eps, 1. - eps)
            theta = tf.acos(cos_theta)
            margin_cos = tf.cos(theta * m1 + m2) - m3

            # Margin on the target entry, plain cosine elsewhere.
            logits = y_true * margin_cos + (1. - y_true) * cos_theta
            logits = logits * s

            pred_softmax = dm.K.softmax(logits)
            return dm.losses.categorical_crossentropy(y_true, pred_softmax)

        train_model.compile(
            optimizer=dm.optimizers.Adam(lr=LR),
            loss=[dm.losses.dummy, arc_loss],
        )

        return train_model
Ejemplo n.º 5
0
    def cyclegan_model():
        """Build the CycleGAN generators, discriminators and combined model.

        Returns:
            ``(generator_model, generator_AB, generator_BA, disc_A, disc_B)``.
            ``disc_A``/``disc_B`` are compiled on their own. ``generator_model``
            is compiled after both discriminators are set non-trainable. It
            outputs ``[valid_A, valid_B, rec_A, rec_B, id_A, id_B]``; the
            identity terms currently have weight 0.
        """
        def build_generator(nf=GENERATOR_CH, depth=DEPTH, name=None, ks=4):
            # Image-to-image generator on a ResNet50 backbone. `depth` and `ks`
            # are kept for interface compatibility; ResNet50 takes neither
            # (`ks` was used only by the UNet alternative). Other backbones
            # tried: dm.networks.UNet(..., ks=ks) and
            # dm.networks.Hourglass(..., batch_norm='InstanceNormalization2D').
            inputs = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
            outputs = dm.networks.ResNet50(inputs,
                                           [INPUT_SHAPE, INPUT_SHAPE, 3],
                                           nf=nf)
            return dm.Model(inputs, outputs, name=name)

        def build_discriminator(nf=DISCRIMINATOR_CH, depth=DEPTH, ks=4):
            # Forward the parameters; they were previously ignored in favour
            # of DEPTH and a literal 4.
            inputs = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
            validity = dm.networks.Discriminator(inputs,
                                                 nf=nf,
                                                 depth=depth,
                                                 ks=ks)

            return dm.Model(inputs, validity)

        input_A = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
        input_B = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])

        disc_A = dm.DeepMachine(build_discriminator(), name="disc_A")
        disc_B = dm.DeepMachine(build_discriminator(), name="disc_B")

        # One optimizer per compiled model. A shared instance shares its
        # iteration counter, so `decay` would advance once per update of
        # EITHER discriminator.
        optimizer_disc_A = dm.optimizers.Adam(LR * W_DISC, 0.5, decay=LR_DECAY)
        optimizer_disc_B = dm.optimizers.Adam(LR * W_DISC, 0.5, decay=LR_DECAY)
        optimizer_gen = dm.optimizers.Adam(LR, 0.5, decay=LR_DECAY)

        disc_A.compile(
            optimizer=optimizer_disc_A,
            loss=['mse'],
            metrics=['accuracy'],
        )
        disc_B.compile(
            optimizer=optimizer_disc_B,
            loss=['mse'],
            metrics=['accuracy'],
        )

        # Freeze the discriminators inside the combined model compiled below.
        disc_A.trainable = False
        disc_B.trainable = False

        generator_AB = build_generator(name="generator_AB")
        generator_BA = build_generator(name="generator_BA")

        # Translations, cycle reconstructions and identity mappings.
        fake_A = generator_BA(input_B)
        fake_B = generator_AB(input_A)

        rec_A = generator_BA(fake_B)
        rec_B = generator_AB(fake_A)

        id_A = generator_BA(input_A)
        id_B = generator_AB(input_B)

        valid_A = disc_A(fake_A)
        valid_B = disc_B(fake_B)

        generator_model = dm.DeepMachine(
            inputs=[input_A, input_B],
            outputs=[valid_A, valid_B, rec_A, rec_B, id_A, id_B])

        generator_model.compile(
            optimizer=optimizer_gen,
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
            loss_weights=[2, 2, 10., 10., 0., 0.],
        )

        return generator_model, generator_AB, generator_BA, disc_A, disc_B
Ejemplo n.º 6
0
    def build_model(inputs_channels=6, n_gpu=n_gpu):
        """Build the mesh-autoencoder stream and the image-to-render stream.

        The mesh stream (mesh encoder -> shared decoder) is compiled first and
        then its models are set non-trainable. The image stream (image encoder
        -> shared decoder -> renderer) is compiled afterwards.

        Args:
            inputs_channels: Per-vertex channels of the mesh. The vertex/colour
                split in the image stream slices ``[:3]`` and ``[3:]`` and sets
                both to 3 channels, so this is effectively required to be 6.
            n_gpu: Number of GPUs; ``> 1`` wraps each model with
                ``multi_gpu_model``.

        Returns:
            ``(render_model, mesh_ae_model, img_encoder_model)``.
        """

        # define components
        ## image encoder
        def build_img_encoder():
            # Encodes an RGB image into a mesh embedding plus camera parameters
            # squashed to (-3, 3) by tanh.
            input_img = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3], name='input_img')

            img_embedding = dm.networks.Encoder2D(
                input_img, EMBEDING + CAMERA_PARAM, depth=4, nf=64)
            mesh_rec_embeding = dm.layers.Lambda(lambda x: x[..., :EMBEDING])(img_embedding)
            cam_rec_embeding = dm.layers.Lambda(lambda x: dm.K.tanh(x[..., EMBEDING:]) * 3)(img_embedding)

            return dm.Model(input_img, [mesh_rec_embeding, cam_rec_embeding], name='image_encoder')

        ## mesh encoder
        def build_mesh_encoder():
            input_mesh = dm.layers.Input(shape=[N_VERTICES, inputs_channels], name='input_mesh')
            mesh_embedding = dm.networks.MeshEncoder(
                input_mesh, EMBEDING, graph_laplacians, downsampling_matrices, filter_list=FILTERS)

            return dm.Model(input_mesh, mesh_embedding, name='mesh_encoder')

        ## common decoder (shared by both streams)
        def build_decoder():
            input_embeding = dm.layers.Input(shape=[EMBEDING], name='input_embeding')
            output_mesh = dm.networks.MeshDecoder(
                input_embeding,
                inputs_channels,
                graph_laplacians,
                adj_matrices,
                upsamling_matrices,
                polynomial_order=6,
                filter_list=FILTERS)

            return dm.Model(input_embeding, output_mesh, name='decoder')

        ## renderer
        def build_renderer(mesh_vertices, vertex_color, cam_parameter):
            # Renders coloured vertices with a camera/light taken from
            # cam_parameter: [0:3] eye, [3:6] look-at, [6:9] up, [9:12] light
            # position. Returns the first 3 channels of the rendered image.
            mesh_vertices.set_shape([BATCH_SIZE, N_VERTICES, 3])
            vertex_color.set_shape([BATCH_SIZE, N_VERTICES, 3])
            cam_parameter.set_shape([BATCH_SIZE, CAMERA_PARAM])

            # "Normals" are the unit-length vertex positions.
            mesh_normals = tf.nn.l2_normalize(mesh_vertices, axis=2)

            # rendering output
            mesh_triangles = tf.constant(trilist, dtype=tf.int32)

            # camera position:
            eye = cam_parameter[...,:3]
            center = cam_parameter[...,3:6]
            world_up = cam_parameter[...,6:9]
            light_positions = cam_parameter[:,None,9:12]

            ambient_colors = tf.ones([BATCH_SIZE, 3], dtype=tf.float32) * 0.1
            light_intensities = tf.ones([BATCH_SIZE, 1, 3], dtype=tf.float32)

            render_mesh = dm.layers.Renderer(
                # image size
                image_width=INPUT_SHAPE,
                image_height=INPUT_SHAPE,
                # mesh definition
                triangles=mesh_triangles,
                normals=mesh_normals,
                # colour definition
                diffuse_colors=vertex_color,
                ambient_color=ambient_colors,
                # camera definition
                camera_position=eye,
                camera_lookat=center,
                camera_up=world_up,
                # light definition
                light_positions=light_positions,
                light_intensities=light_intensities,
            )(mesh_vertices)

            render_mesh = dm.layers.Lambda(lambda x: x[..., :3])(render_mesh)

            return render_mesh

        # Mesh AE stream
        ## define inputs (must match the mesh encoder's channel count; this
        ## was previously a literal 6)
        input_mesh_stream = dm.layers.Input(shape=[N_VERTICES, inputs_channels], name='input_mesh_stream')

        ## define components
        mesh_encoder_model = build_mesh_encoder()
        decoder_model = build_decoder()

        ## define connections
        output_mesh = decoder_model(mesh_encoder_model(input_mesh_stream))
        mesh_ae_model = dm.DeepMachine(
            inputs=input_mesh_stream,
            outputs=output_mesh,
            name='MeshStream'
        )

        ## multi gpu support
        if n_gpu > 1:
            mesh_ae_model = multi_gpu_model(mesh_ae_model, gpus=n_gpu)

        ## compile mesh stream
        mesh_ae_model.compile(
            optimizer=dm.optimizers.Adam(lr=LR),
            loss=['mae']
        )

        ## set trainable: done after compiling the mesh stream so that only
        ## the render stream (compiled below) sees the decoder as frozen.
        mesh_ae_model.trainable = False
        decoder_model.trainable = False
        mesh_encoder_model.trainable = False

        # Render Stream
        ## define inputs
        input_image_stream = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3], name='input_image_stream')

        ## define components
        img_encoder_model = build_img_encoder()

        ## define connections
        rec_mesh_emb, rec_cam_emb = img_encoder_model(input_image_stream)
        mesh_with_colour = decoder_model(rec_mesh_emb)

        # Split decoded channels into xyz ([:3]) and colour ([3:]).
        mesh_vert = dm.layers.Lambda(lambda x: x[..., :3])(mesh_with_colour)
        mesh_vert.set_shape([BATCH_SIZE, N_VERTICES, 3])
        mesh_colour = dm.layers.Lambda(lambda x: x[..., 3:])(mesh_with_colour)
        mesh_colour.set_shape([BATCH_SIZE, N_VERTICES, 3])
        rec_render = build_renderer(
            mesh_vert,
            mesh_colour,
            rec_cam_emb
        )

        render_model = dm.DeepMachine(
            inputs=input_image_stream,
            outputs=[rec_render, mesh_with_colour],
            name='ImageStream'
        )

        ## multi gpu support
        if n_gpu > 1:
            render_model = multi_gpu_model(render_model, gpus=n_gpu)

        ## compile render stream
        render_model.compile(
            optimizer=dm.optimizers.Adam(lr=LR),
            loss=['mae', dm.losses.dummy]
        )

        return render_model, mesh_ae_model, img_encoder_model
Ejemplo n.º 7
0
    def model_builder():
        """Build the two-stage UV/XYZ regression model, optionally with an AE.

        When both ``FLAGS.use_ae`` and ``FLAGS.ae_path`` are truthy, a
        convolutional autoencoder is also built and compiled. Its encoder is
        then frozen and applied to the refined prediction as a third output.

        Returns:
            ``(train_model, ae_model, encoder_model, decoder_model)`` when the
            autoencoder is enabled, otherwise ``train_model`` alone.
        """
        use_ae = bool(FLAGS.use_ae and FLAGS.ae_path)

        def make_optimizer():
            # One optimizer per compiled model; sharing an instance shares its
            # iteration counter and state bookkeeping across models.
            return dm.optimizers.Adam(lr=LR, clipnorm=1., decay=0.)

        if use_ae:
            # encoder
            ae_input = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                       name='ae_input')
            embeding_output = dm.networks.Encoder2D(ae_input,
                                                    128,
                                                    depth=8,
                                                    nf=32,
                                                    batch_norm=False)
            encoder_model = dm.Model(inputs=[ae_input],
                                     outputs=[embeding_output],
                                     name='encoder_model')

            # decoder
            input_embeding = dm.layers.Input(shape=[128],
                                             name='ae_input_embeding')
            ae_output = dm.networks.Decoder2D(input_embeding,
                                              [INPUT_SHAPE, INPUT_SHAPE, 3],
                                              depth=8,
                                              nf=32,
                                              batch_norm=False)
            decoder_model = dm.Model(inputs=[input_embeding],
                                     outputs=[ae_output],
                                     name='decoder_model')

            # combined model
            ae_model = dm.DeepMachine(
                inputs=[ae_input],
                outputs=[decoder_model(encoder_model(ae_input))])
            ae_model.compile(optimizer=make_optimizer(), loss=['mae'])

            # Freeze after compiling ae_model. train_model uses encoder_model
            # directly (not ae_model), so the sub-models must be frozen
            # explicitly too.
            ae_model.trainable = False
            encoder_model.trainable = False
            decoder_model.trainable = False

        input_image = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                      name='input_image')
        # Output size follows INPUT_SHAPE (was hard-coded to 256).
        uvxyz_prediction = dm.networks.Hourglass(
            input_image, [INPUT_SHAPE, INPUT_SHAPE, 3],
            nf=64,
            batch_norm='InstanceNormalization2D')

        merged_input = dm.layers.Concatenate()([input_image, uvxyz_prediction])

        uvxyz_prediction_refine = dm.networks.Hourglass(
            merged_input, [INPUT_SHAPE, INPUT_SHAPE, 3],
            nf=64,
            batch_norm='InstanceNormalization2D')

        def weighted_uv_loss(y_true, y_pred):
            # Mean absolute error weighted per element by `weight_mask`.
            return dm.K.mean(weight_mask * dm.K.abs(y_true - y_pred))

        outputs = [uvxyz_prediction, uvxyz_prediction_refine]
        losses = [weighted_uv_loss, weighted_uv_loss]
        if use_ae:
            outputs.append(encoder_model(uvxyz_prediction_refine))
            losses.append('mae')

        train_model = dm.DeepMachine(inputs=input_image, outputs=outputs)

        # One loss (and weight) per output. Previously 3 were always passed,
        # which fails when the AE output is absent.
        train_model.compile(optimizer=make_optimizer(),
                            loss=losses,
                            loss_weights=[1] * len(losses))

        if use_ae:
            return train_model, ae_model, encoder_model, decoder_model
        return train_model
Ejemplo n.º 8
0
def build_model(FLAGS,
                N_VERTICES,
                INPUT_SHAPE,
                EMBEDING,
                FILTERS,
                inputs_channels=6):
    """Build a mesh autoencoder and an image-to-mesh (3DMM) regressor.

    Both models share one mesh decoder. The autoencoder is compiled first. Its
    encoder, the decoder and the autoencoder are then set non-trainable, so the
    3DMM model compiled afterwards trains only its image encoder.

    Args:
        FLAGS: Config; reads ``lr``, ``gpu`` (comma-separated ids),
            ``meta_path`` and ``m_weight``.
        N_VERTICES: Number of mesh vertices.
        INPUT_SHAPE: Square input image size.
        EMBEDING: Size of the shared latent embedding.
        FILTERS: Filter list for the mesh encoder/decoder.
        inputs_channels: Per-vertex channels of the mesh.

    Returns:
        ``(model_ae_mesh, model_3dmm, decoder_model)``.
    """
    LR = FLAGS.lr
    n_gpu = len(FLAGS.gpu.split(','))
    # Precomputed graph operators for the mesh convolutions (pickle from the
    # configured meta path; only load trusted files).
    graph_laplacians, downsampling_matrices, upsamling_matrices, adj_matrices = mio.import_pickle(
        FLAGS.meta_path + '/lsfm_LDUA.pkl', encoding='latin1')

    # define components
    ## image encoder
    def build_img_encoder():
        input_img = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                    name='input_img')

        img_embedding = dm.networks.Encoder2D(input_img,
                                              EMBEDING,
                                              depth=4,
                                              nf=32)

        return dm.Model(input_img, img_embedding, name='image_encoder')

    ## mesh encoder
    def build_mesh_encoder():
        input_mesh = dm.layers.Input(shape=[N_VERTICES, inputs_channels],
                                     name='input_mesh')
        mesh_embedding = dm.networks.MeshEncoder(input_mesh,
                                                 EMBEDING,
                                                 graph_laplacians,
                                                 downsampling_matrices,
                                                 filter_list=FILTERS)

        return dm.Model(input_mesh, mesh_embedding, name='mesh_encoder')

    ## common decoder
    def build_decoder():
        input_embeding = dm.layers.Input(shape=[EMBEDING],
                                         name='input_embeding')
        output_mesh = dm.networks.MeshDecoder(input_embeding,
                                              inputs_channels,
                                              graph_laplacians,
                                              adj_matrices,
                                              upsamling_matrices,
                                              polynomial_order=6,
                                              filter_list=FILTERS)

        return dm.Model(input_embeding, output_mesh, name='decoder')

    # Mesh AE stream
    ## define inputs. The mesh input must match the mesh encoder's channel
    ## count (this was previously a literal 6).
    input_mesh = dm.layers.Input(shape=[N_VERTICES, inputs_channels],
                                 name='input_mesh')
    input_image = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3],
                                  name='input_image')

    ## define components
    img_encoder_model = build_img_encoder()
    mesh_encoder_model = build_mesh_encoder()
    decoder_model = build_decoder()

    ## define connections
    output_mesh_ae = decoder_model(mesh_encoder_model(input_mesh))
    output_mesh_3dmm = decoder_model(img_encoder_model(input_image))

    ## custom losses
    def masked_mae(gt_y, pred_y):
        # The last channel of the target is a mask; the rest is the mesh.
        gt_mask = gt_y[..., -1:]
        gt_mesh = gt_y[..., :-1]
        return dm.losses.mae(gt_mesh * gt_mask, pred_y * gt_mask)

    # model definition
    model_ae_mesh = dm.DeepMachine(inputs=[input_mesh],
                                   outputs=[output_mesh_ae],
                                   name='MeshAutoEncoder')
    if n_gpu > 1:
        model_ae_mesh = multi_gpu_model(model_ae_mesh, gpus=n_gpu)

    model_ae_mesh.compile(optimizer=dm.optimizers.Adam(lr=LR),
                          loss=[masked_mae])

    # Freeze after compiling the AE so only model_3dmm sees them as frozen.
    mesh_encoder_model.trainable = False
    decoder_model.trainable = False
    model_ae_mesh.trainable = False
    model_3dmm = dm.DeepMachine(inputs=[input_image],
                                outputs=[output_mesh_3dmm],
                                name='3DMM')

    ## multi gpu support
    if n_gpu > 1:
        model_3dmm = multi_gpu_model(model_3dmm, gpus=n_gpu)

    ## compile image -> mesh stream
    model_3dmm.compile(optimizer=dm.optimizers.Adam(lr=LR * FLAGS.m_weight),
                       loss=[masked_mae])

    return model_ae_mesh, model_3dmm, decoder_model
Ejemplo n.º 9
0
    def build_model(inputs_channels=3):
        """Build the mesh autoencoder and a fixed-camera mesh renderer.

        Args:
            inputs_channels: Channels per vertex fed to the autoencoder.

        Returns:
            ``(mesh_ae, mesh_render)``. ``mesh_ae`` is the compiled mesh
            autoencoder. ``mesh_render`` is an uncompiled model that renders
            ``(vertices, vertex colours)`` into a 256x256 image from a fixed
            camera and light.
        """
        # ---------------- mesh autoencoder ----------------
        mesh_in = dm.layers.Input(shape=[N_VERTICES, inputs_channels],
                                  name='input_mesh')
        latent = dm.networks.MeshEncoder(mesh_in,
                                         EMBEDING,
                                         graph_laplacians,
                                         downsampling_matrices,
                                         filter_list=FILTERS)
        mesh_out = dm.networks.MeshDecoder(latent,
                                           inputs_channels,
                                           graph_laplacians,
                                           adj_matrices,
                                           upsamling_matrices,
                                           polynomial_order=6,
                                           filter_list=FILTERS)

        mesh_ae = dm.DeepMachine(inputs=mesh_in, outputs=[mesh_out])

        gpu_count = len(FLAGS.gpu.split(','))
        if gpu_count > 1:
            mesh_ae = multi_gpu_model(mesh_ae, gpus=gpu_count)

        mesh_ae.compile(optimizer=dm.optimizers.Adam(lr=LR), loss=['mae'])

        # ---------------- rendering layer ----------------
        batched_xyz = [BATCH_SIZE, N_VERTICES, 3]

        vertices = dm.layers.Input(shape=[N_VERTICES, 3],
                                   name='mesh_to_render')
        vertices.set_shape(batched_xyz)
        colours = dm.layers.Input(shape=[N_VERTICES, 3],
                                  name='vertex_color')
        colours.set_shape(batched_xyz)

        # "Normals" are the unit-length vertex positions.
        normals = tf.nn.l2_normalize(vertices, axis=2)
        normals.set_shape(batched_xyz)

        triangles = tf.constant(trilist, dtype=tf.int32)

        def per_batch(row):
            # Repeat `row` once per batch element as a float32 constant.
            return tf.constant(BATCH_SIZE * [row], dtype=tf.float32)

        # Fixed camera and lighting.
        eye = per_batch([0.0, 0.0, -2.0])
        center = per_batch([0.0, 0.0, 0.0])
        world_up = per_batch([1.0, 0.0, 0.0])
        ambient_colors = per_batch([1., 1., 1.]) * 0.1
        light_positions = per_batch([[2.0, 2.0, 2.0]]) * 3.
        light_intensities = tf.ones([BATCH_SIZE, 1, 3], dtype=tf.float32)

        renderer = dm.layers.Renderer(
            image_width=256,
            image_height=256,
            triangles=triangles,
            normals=normals,
            diffuse_colors=colours,
            ambient_color=ambient_colors,
            camera_position=eye,
            camera_lookat=center,
            camera_up=world_up,
            light_positions=light_positions,
            light_intensities=light_intensities,
        )
        rendered = renderer(vertices)

        mesh_render = dm.DeepMachine(inputs=[vertices, colours],
                                     outputs=[rendered])

        return mesh_ae, mesh_render
Ejemplo n.º 10
0
    def star_gan_model():
        """Build the StarGAN generator, discriminator and combined model.

        Returns:
            ``(generator_model, generator, disc_model)``. ``disc_model`` is
            compiled on its own. ``generator_model`` is compiled after
            ``disc_model`` is set non-trainable; it outputs
            ``[rec_img, valid_img, fake_classes]``.
        """
        def build_generator(nf=GENERATOR_CH, depth=DEPTH, name=None, ks=4):
            # Conditional generator: the target label is repeated over every
            # pixel and concatenated to the image channels. `depth` and `ks`
            # are kept for interface compatibility; ResNet50 takes neither.
            # Other backbones tried: dm.networks.UNet and dm.networks.Hourglass.
            inputs = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
            target_label = dm.layers.Input(shape=[N_CLASSES])
            target_label_conv = dm.layers.RepeatVector(
                INPUT_SHAPE * INPUT_SHAPE)(target_label)
            target_label_conv = dm.layers.Reshape(
                [INPUT_SHAPE, INPUT_SHAPE, N_CLASSES])(target_label_conv)
            merged_inputs = dm.layers.Concatenate()(
                [inputs, target_label_conv])
            outputs = dm.networks.ResNet50(merged_inputs,
                                           [INPUT_SHAPE, INPUT_SHAPE, 3],
                                           nf=nf,
                                           n_residule=6)
            return dm.Model([inputs, target_label], outputs, name=name)

        def build_discriminator(nf=DISCRIMINATOR_CH, depth=DEPTH, ks=4):
            # Outputs (validity, class scores). `ks` is now forwarded; it was
            # previously ignored in favour of a literal 4.
            inputs = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
            validity, conv_feature = dm.networks.Discriminator(
                inputs, nf=nf, depth=depth, ks=ks, return_endpoints=True)
            # A 'valid' conv with kernel DISC_SHAPE, reshaped to N_CLASSES;
            # presumably DISC_SHAPE equals the feature map's spatial size
            # (TODO confirm).
            classes = dm.networks.conv2d(conv_feature,
                                         N_CLASSES,
                                         DISC_SHAPE,
                                         padding='valid',
                                         activation='softmax')
            classes = dm.layers.Reshape([N_CLASSES])(classes)

            return dm.Model(inputs, [validity, classes])

        def binary_crossentropy_none(y_true, y_pred):
            # Per-sample BCE averaged over the last axis, floored at 0.
            return dm.K.maximum(
                dm.K.mean(dm.K.binary_crossentropy(y_true, y_pred), axis=-1),
                0)

        input_image = dm.layers.Input(shape=[INPUT_SHAPE, INPUT_SHAPE, 3])
        target_label = dm.layers.Input(shape=[N_CLASSES])
        original_label = dm.layers.Input(shape=[N_CLASSES])

        disc_model = dm.DeepMachine(build_discriminator(), name="disc")
        gen_optimizer = dm.optimizers.Adam(LR, 0.5, decay=LR_DECAY)
        dis_optimizer = dm.optimizers.Adam(LR, 0.5, decay=LR_DECAY)

        disc_model.compile(
            optimizer=dis_optimizer,
            loss=['mse', binary_crossentropy_none],
        )

        # Freeze the discriminator inside the combined model compiled below.
        disc_model.trainable = False

        generator = build_generator(name="generator")

        # Translate to the target domain, then cycle back to the original.
        fake_img = generator([input_image, target_label])
        rec_img = generator([fake_img, original_label])
        valid_img, fake_classes = disc_model(fake_img)

        generator_model = dm.DeepMachine(
            inputs=[input_image, target_label, original_label],
            outputs=[rec_img, valid_img, fake_classes])

        generator_model.compile(
            optimizer=gen_optimizer,
            loss=['mae', 'mse', 'categorical_crossentropy'],
            loss_weights=[10., 1., 1.],
        )

        return generator_model, generator, disc_model