예제 #1
0
파일: models.py 프로젝트: wonsang/placenta
    def _new_model(self):
        """Build a four-level 3-D U-Net and store it on ``self.model``.

        Encoder: at each of four resolution levels, two 3x3x3 ReLU
        convolutions (16, 32, 64, 128 filters) with 2x2x2 max pooling between
        levels. Decoder: three 2x2x2 stride-2 transposed convolutions, each
        concatenated with the encoder feature map of matching resolution and
        followed by two 3x3x3 ReLU convolutions. A final 1x1x1 sigmoid
        convolution yields a single-channel output.

        ``self.input_size`` is used as the ``Input`` shape. NOTE(review): the
        skip concatenations presumably require every spatial dimension to be
        divisible by 8 (three pooling stages) -- confirm against callers.
        """
        inputs = layers.Input(shape=self.input_size)

        # Encoder level 1 (full resolution).
        conv1 = layers.Conv3D(16, (3, 3, 3), activation='relu',
                              padding='same')(inputs)
        conv1 = layers.Conv3D(16, (3, 3, 3), activation='relu',
                              padding='same')(conv1)
        pool1 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)

        # Encoder level 2.
        conv2 = layers.Conv3D(32, (3, 3, 3), activation='relu',
                              padding='same')(pool1)
        conv2 = layers.Conv3D(32, (3, 3, 3), activation='relu',
                              padding='same')(conv2)
        pool2 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)

        # Encoder level 3.
        conv3 = layers.Conv3D(64, (3, 3, 3), activation='relu',
                              padding='same')(pool2)
        conv3 = layers.Conv3D(64, (3, 3, 3), activation='relu',
                              padding='same')(conv3)
        pool3 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)

        # Bottleneck (no pooling afterwards).
        conv4 = layers.Conv3D(128, (3, 3, 3),
                              activation='relu',
                              padding='same')(pool3)
        conv4 = layers.Conv3D(128, (3, 3, 3),
                              activation='relu',
                              padding='same')(conv4)

        # Decoder: upsample the bottleneck and fuse with encoder level 3.
        up5 = layers.Conv3DTranspose(64, (2, 2, 2),
                                     strides=(2, 2, 2),
                                     padding='same')(conv4)
        conc5 = layers.concatenate([up5, conv3])
        # This convolution must consume ``conc5``; it previously read
        # ``conc6``, which is only assigned at the next decoder stage and
        # therefore raised UnboundLocalError.
        conv5 = layers.Conv3D(64, (3, 3, 3), activation='relu',
                              padding='same')(conc5)
        conv5 = layers.Conv3D(64, (3, 3, 3), activation='relu',
                              padding='same')(conv5)

        # Decoder: fuse with encoder level 2.
        up6 = layers.Conv3DTranspose(32, (2, 2, 2),
                                     strides=(2, 2, 2),
                                     padding='same')(conv5)
        conc6 = layers.concatenate([up6, conv2])
        conv6 = layers.Conv3D(32, (3, 3, 3), activation='relu',
                              padding='same')(conc6)
        conv6 = layers.Conv3D(32, (3, 3, 3), activation='relu',
                              padding='same')(conv6)

        # Decoder: fuse with encoder level 1 (full resolution).
        up7 = layers.Conv3DTranspose(16, (2, 2, 2),
                                     strides=(2, 2, 2),
                                     padding='same')(conv6)
        conc7 = layers.concatenate([up7, conv1])
        conv7 = layers.Conv3D(16, (3, 3, 3), activation='relu',
                              padding='same')(conc7)
        conv7 = layers.Conv3D(16, (3, 3, 3), activation='relu',
                              padding='same')(conv7)

        # Per-voxel single-channel probability map.
        outputs = layers.Conv3D(1, (1, 1, 1), activation='sigmoid')(conv7)

        self.model = Model(inputs=inputs, outputs=outputs)
예제 #2
0
def build_cnn(optimizer='adam', lr=0.00002):
    """Set up the projection-to-volume CNN and return the compiled model.

    The ``config`` module is reloaded first so that edits to it take effect
    without restarting the interpreter. The network maps an input of shape
    ``C.proj_dims`` to an output reshaped to ``C.world_dims``.

    Args:
        optimizer: Currently ignored -- the model is always compiled with
            RMSprop. Kept so existing callers passing it keep working.
        lr: Learning rate handed to RMSprop (which also uses ``decay=0.1``).

    Returns:
        The compiled Keras model (mean-squared-error loss).
    """
    importlib.reload(config)

    C = config.Config()

    proj = layers.Input(C.proj_dims)
    # Flatten everything after the first axis; Dense then acts on each
    # index along that first axis independently.
    x = layers.Reshape((C.proj_dims[0], -1))(proj)
    x = layers.Dense(1024, activation='tanh')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(1024, activation='tanh')(x)
    x = layers.BatchNormalization()(x)
    # Reinterpret the 1024 features of each slice as a 32x32 grid so 3-D
    # convolutions can be applied.
    x = layers.Reshape((C.proj_dims[0], 32, 32, -1))(x)
    x = layers.Conv3D(64, 3, activation='relu', padding='same')(x)
    # Halve only the first spatial axis.
    x = layers.MaxPooling3D((2, 1, 1))(x)
    x = layers.Conv3D(64, 3, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3DTranspose(1, 3, activation='sigmoid', padding='same')(x)
    img = layers.Reshape(C.world_dims)(x)

    model = Model(proj, img)
    model.compile(optimizer=RMSprop(lr=lr, decay=0.1), loss='mse')

    return model
def SkipDeconvBlock(x, skip_x, channels):
    """Upsample ``x``, merge it with ``skip_x`` and add a residual refinement.

    Pipeline:
      1. 2x2x2 stride-2 transposed convolution with ``channels // 2`` filters
         ('valid' padding), followed by LeakyReLU.
      2. Concatenation with ``skip_x`` along the last axis.
      3. 3x3x3 convolution with ``channels`` filters ('same' padding),
         followed by LeakyReLU.
      4. Element-wise sum of step 3 with the concatenation from step 2.

    For the sum in step 4 to line up, the concatenation must itself have
    ``channels`` channels, i.e. ``skip_x`` should contribute
    ``channels - channels // 2`` of them.
    """
    upsampled = layers.Conv3DTranspose(channels // 2, (2, 2, 2),
                                       padding='valid', strides=2)(x)
    upsampled = layers.LeakyReLU()(upsampled)

    merged = layers.Concatenate(axis=-1)([upsampled, skip_x])

    refined = layers.Conv3D(channels, (3, 3, 3),
                            padding='same', strides=1)(merged)
    refined = layers.LeakyReLU()(refined)

    residual_sum = layers.Add()([refined, merged])
    return residual_sum
예제 #4
0
def _cnn_conv_bn_relu(x, filters, conv_layer):
    """Apply a 3x3x3 'same' ``conv_layer``, then BatchNormalization and ReLU."""
    x = conv_layer(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    return Activation('relu')(x)


def CNN():
    """Build (without compiling) a two-level 3-D U-Net-style model.

    The input has shape ``(None, None, None, 1)``: a single-channel volume of
    any spatial size. Encoder levels use Conv3D+BN+ReLU pairs (32, then 64
    filters) with 2x2x2 max pooling. The bottleneck has 128 filters and is
    upsampled back with ``UpSampling3D``. Decoder levels concatenate with the
    matching encoder output and use stride-1 Conv3DTranspose+BN+ReLU pairs.
    The output is a single channel with ReLU activation.

    NOTE(review): the concatenations presumably require spatial dimensions
    divisible by 4 (two pooling stages) -- confirm for the intended inputs.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=(None, None, None, 1))

    # Encoder level 1 (full resolution).
    conv1 = _cnn_conv_bn_relu(inputs, 32, layers.Conv3D)
    conv1 = _cnn_conv_bn_relu(conv1, 32, layers.Conv3D)
    pool1 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    # Encoder level 2 (1/2 resolution).
    conv2 = _cnn_conv_bn_relu(pool1, 64, layers.Conv3D)
    conv2 = _cnn_conv_bn_relu(conv2, 64, layers.Conv3D)
    pool2 = layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)

    # Bottleneck (1/4 resolution), upsampled back to 1/2 resolution.
    conv3 = _cnn_conv_bn_relu(pool2, 128, layers.Conv3D)
    conv3 = _cnn_conv_bn_relu(conv3, 128, layers.Conv3D)
    conv3 = layers.UpSampling3D(size=(2, 2, 2))(conv3)

    # Decoder level 2. Stride-1 'same' transposed convolutions keep the
    # spatial size unchanged.
    up4 = layers.concatenate([conv3, conv2])
    conv4 = _cnn_conv_bn_relu(up4, 64, layers.Conv3DTranspose)
    conv4 = _cnn_conv_bn_relu(conv4, 64, layers.Conv3DTranspose)
    # 1x1x1 projection (no BN/activation) before returning to full size.
    conv4 = layers.Conv3DTranspose(64, 1, padding='same')(conv4)
    conv4 = layers.UpSampling3D(size=(2, 2, 2))(conv4)

    # Decoder level 1 (full resolution).
    up5 = layers.concatenate([conv4, conv1])
    conv5 = _cnn_conv_bn_relu(up5, 32, layers.Conv3DTranspose)
    conv5 = _cnn_conv_bn_relu(conv5, 32, layers.Conv3DTranspose)
    # Single-channel output; ReLU keeps it non-negative.
    conv5 = layers.Conv3DTranspose(1, 1, padding='same',
                                   activation='relu')(conv5)

    return Model(inputs=inputs, outputs=conv5)
예제 #5
0
def _gm_conv(x, filters, init):
    """5x5x5 'same' Conv3D with ReLU and the given kernel initializer."""
    return layers.Conv3D(
        filters=filters,
        kernel_size=5,
        activation="relu",
        kernel_initializer=init,
        padding="same",
    )(x)


def _gm_downsample(x):
    """2x2x2 stride-2 max pooling followed by 30% dropout."""
    x = layers.MaxPooling3D(pool_size=2, strides=2)(x)
    return layers.Dropout(rate=0.3)(x)


def _gm_upsample(x, skip, filters, init):
    """Stride-2 transposed conv, concat with ``skip``, one conv, then dropout."""
    x = layers.Conv3DTranspose(
        filters=filters,
        kernel_size=5,
        strides=2,
        kernel_initializer=init,
        padding="same",
        use_bias=False,
    )(x)
    x = layers.concatenate([x, skip], axis=-1)
    x = _gm_conv(x, filters, init)
    return layers.Dropout(rate=0.3)(x)


def generate_model():
    """Build and compile a three-level 3-D U-Net for binary voxel masks.

    Input: a float32 tensor named ``"img"`` of shape ``(SIZE, SIZE, SIZE, 1)``.
    Encoder filter counts are 8, ``SIZE // 8`` and ``SIZE // 4`` per level, so
    the network width scales with ``SIZE``. Every convolution shares one
    Glorot-uniform initializer instance. The output is a per-voxel sigmoid.

    NOTE(review): three 2x pooling stages presumably require ``SIZE`` to be
    divisible by 8 for the skip concatenations to line up -- confirm.

    Returns:
        The model compiled with Adam, binary cross-entropy and accuracy.
    """
    init = keras.initializers.glorot_uniform()

    input_ = layers.Input(shape=(SIZE, SIZE, SIZE, 1),
                          dtype="float32",
                          name="img")

    # Encoder level 1 (full resolution).
    out = _gm_conv(input_, 8, init)
    conv1 = _gm_conv(out, 8, init)

    # Encoder level 2.
    out = _gm_downsample(conv1)
    out = _gm_conv(out, SIZE // 8, init)
    conv2 = _gm_conv(out, SIZE // 8, init)

    # Encoder level 3.
    out = _gm_downsample(conv2)
    out = _gm_conv(out, SIZE // 4, init)
    conv3 = _gm_conv(out, SIZE // 4, init)

    # Bottleneck: pooling and dropout only, no convolution.
    out = _gm_downsample(conv3)

    # Decoder: each stage doubles the resolution and fuses one skip.
    out = _gm_upsample(out, conv3, SIZE // 4, init)
    out = _gm_upsample(out, conv2, SIZE // 8, init)
    out = _gm_upsample(out, conv1, 8, init)

    out = layers.Conv3D(filters=1,
                        kernel_size=1,
                        kernel_initializer=init,
                        padding="same")(out)
    # Dense(1) on the single remaining channel acts per voxel as a learned
    # scale and bias before the sigmoid.
    out = layers.Dense(1, activation="sigmoid")(out)

    model = keras.models.Model(input_, out)
    model.compile("adam", loss="binary_crossentropy", metrics=['accuracy'])

    return model
예제 #6
0
파일: models.py 프로젝트: wonsang/placenta
    def _new_model(self):
        """Assemble a six-level 3-D U-Net and assign it to ``self.model``.

        Contracting path: five levels of two 3x3x3 ReLU convolutions
        (16, 32, 64, 128, 256 filters), each followed by 2x2x2 max pooling.
        Bottleneck: two 3x3x3 ReLU convolutions with 512 filters.
        Expanding path: for each level in reverse, a 2x2x2 stride-2
        transposed convolution, concatenation with that level's encoder
        output, and two 3x3x3 ReLU convolutions of the same width.
        Head: a 1x1x1 sigmoid convolution producing one channel.

        The ``Input`` shape is taken from ``self.input_size``.
        """
        level_widths = (16, 32, 64, 128, 256)

        inputs = layers.Input(shape=self.input_size)

        # Contracting path -- remember each level's features for the skips.
        skip_tensors = []
        feat = inputs
        for width in level_widths:
            for _ in range(2):
                feat = layers.Conv3D(width, (3, 3, 3),
                                     activation='relu',
                                     padding='same')(feat)
            skip_tensors.append(feat)
            feat = layers.MaxPooling3D(pool_size=(2, 2, 2))(feat)

        # Bottleneck.
        for _ in range(2):
            feat = layers.Conv3D(512, (3, 3, 3),
                                 activation='relu',
                                 padding='same')(feat)

        # Expanding path, deepest level first.
        for width, skip in zip(reversed(level_widths),
                               reversed(skip_tensors)):
            feat = layers.Conv3DTranspose(width, (2, 2, 2),
                                          strides=(2, 2, 2),
                                          padding='same')(feat)
            feat = layers.concatenate([feat, skip])
            for _ in range(2):
                feat = layers.Conv3D(width, (3, 3, 3),
                                     activation='relu',
                                     padding='same')(feat)

        outputs = layers.Conv3D(1, (1, 1, 1), activation='sigmoid')(feat)

        self.model = Model(inputs=inputs, outputs=outputs)
예제 #7
0
    def _new_model(self):
        """Assemble a 3-D U-Net with an auxiliary autoencoder head.

        Shared encoder: four levels of two 3x3x3 ReLU convolutions
        (32, 64, 128, 256 filters) with 2x2x2 max pooling, then a 512-filter
        bottleneck of two 3x3x3 ReLU convolutions.

        Two decoders start from the bottleneck:
          * segmentation head (``'outputs'``): upsample, concatenate with the
            matching encoder features, two convolutions per level, 1x1x1
            sigmoid convolution;
          * autoencoder head (``'ae_outputs'``): the same upsampling path but
            without skip connections, ending in its own 1x1x1 sigmoid
            convolution.

        The resulting two-output model is stored on ``self.model``; its
        ``Input`` shape comes from ``self.input_size``.
        """
        def double_conv(tensor, width):
            # Two 3x3x3 ReLU convolutions with 'same' padding.
            for _ in range(2):
                tensor = layers.Conv3D(width, (3, 3, 3),
                                       activation='relu',
                                       padding='same')(tensor)
            return tensor

        def upsample(tensor, width):
            # 2x2x2 stride-2 transposed convolution.
            return layers.Conv3DTranspose(width, (2, 2, 2),
                                          strides=(2, 2, 2),
                                          padding='same')(tensor)

        widths = (32, 64, 128, 256)

        inputs = layers.Input(shape=self.input_size)

        # Shared contracting path.
        encoder_features = []
        tensor = inputs
        for width in widths:
            tensor = double_conv(tensor, width)
            encoder_features.append(tensor)
            tensor = layers.MaxPooling3D(pool_size=(2, 2, 2))(tensor)

        bottleneck = double_conv(tensor, 512)

        # Segmentation decoder with skip connections.
        seg = bottleneck
        for width, skip in zip(reversed(widths), reversed(encoder_features)):
            seg = layers.concatenate([upsample(seg, width), skip])
            seg = double_conv(seg, width)
        outputs = layers.Conv3D(1, (1, 1, 1),
                                activation='sigmoid',
                                name='outputs')(seg)

        # Autoencoder decoder: same upsampling path, no skip connections.
        recon = bottleneck
        for width in reversed(widths):
            recon = double_conv(upsample(recon, width), width)
        ae_outputs = layers.Conv3D(1, (1, 1, 1),
                                   activation='sigmoid',
                                   name='ae_outputs')(recon)

        self.model = Model(inputs=inputs, outputs=[outputs, ae_outputs])
def Conv_VAE3D(n_epochs=2,
               batch_size=10,
               learning_rate=0.001,
               decay_rate=0.0,
               latent_dim=8,
               name='stats.pickle'):
    """Build and train a 3-D convolutional VAE, then pickle its loss curves.

    The model takes 50x50x50 inputs with 4 channels (fixed by the ``Input``
    layer). Training batches come from ``gen_batches(batch_size)`` and
    validation batches from ``gen_batches_validation(batch_size)``; both are
    defined elsewhere. Steps per epoch assume 1500 training and 1000
    validation samples.

    Args:
        n_epochs: Number of training epochs.
        batch_size: Batch size passed to both generators.
        learning_rate: Adam learning rate.
        decay_rate: Adam learning-rate decay.
        latent_dim: Dimension of the latent vector ``z``.
        name: Path of the pickle file that receives a dict with keys
            ``'train_loss'``, ``'val_loss'`` and ``'epoch_time'``.

    Returns:
        ``history.history`` from the training run.
    """
    # Prepare session:
    K.clear_session()

    # Number of samples to use for training and validation:
    n_train = 1500
    n_val = 1000

    # ENCODER: ---------------------------------------------------------------
    # Spatial size: 50 -> 25 (MP1) -> 12 (MP2) -> 10 (valid conv) -> 5 (MP3).

    input_img = Input(shape=(50, 50, 50, 4), name="Init_Input")
    x = layers.Conv3D(32, (3, 3, 3),
                      padding="same",
                      activation='relu',
                      name='E_Conv1')(input_img)
    x = layers.MaxPooling3D((2, 2, 2), name='E_MP1')(x)
    x = layers.Conv3D(64, (3, 3, 3),
                      padding="same",
                      activation='relu',
                      name='E_Conv2')(x)
    x = layers.MaxPooling3D((2, 2, 2), name='E_MP2')(x)
    x = layers.Conv3D(64, (3, 3, 3),
                      padding="valid",
                      activation='relu',
                      name='E_Conv3')(x)
    x = layers.MaxPooling3D((2, 2, 2), name='E_MP3')(x)
    x = layers.Conv3D(128, (3, 3, 3),
                      padding="same",
                      activation='relu',
                      name='E_Conv4')(x)

    # Remembered so the decoder can undo the Flatten.
    shape_before_flattening = K.int_shape(x)

    x = layers.Flatten()(x)
    x = layers.Dense(32, activation='relu')(x)

    encoder = Model(input_img, x)

    # summary() prints by itself and returns None; wrapping it in print()
    # only added a stray "None" line.
    encoder.summary()

    # VARIATIONAL LAYER: ------------------------------------------------------

    z_mean = layers.Dense(latent_dim, name='V_Mean')(x)
    z_log_var = layers.Dense(latent_dim, name='V_Sig')(x)

    def sampling(args):
        """Reparameterisation trick: z = mean + std * eps with eps ~ N(0, I)."""
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                                  mean=0.,
                                  stddev=1.)
        # z_log_var is a log-variance (the KL term below uses exp(z_log_var)
        # as the variance), so the standard deviation is exp(0.5 * z_log_var).
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    z = layers.Lambda(sampling, name='V_Var')([z_mean, z_log_var])
    variation = Model(input_img, z)

    variation.summary()

    # DECODER: ---------------------------------------------------------------
    # Spatial size: 5 -> 10 (UpSampling x2) -> 50 (UpSampling x5), 4 channels.

    decoder_input = layers.Input(shape=(latent_dim, ), name='D_Input')

    x = layers.Dense(np.prod(shape_before_flattening[1:]),
                     activation='relu',
                     name='D_Dense')(decoder_input)
    x = layers.Reshape(shape_before_flattening[1:], name='D_UnFlatten')(x)
    x = layers.Conv3DTranspose(32,
                               3,
                               padding='same',
                               activation='relu',
                               name='D_DeConv1')(x)
    x = layers.UpSampling3D((2, 2, 2))(x)
    x = layers.Conv3D(4,
                      3,
                      padding='same',
                      activation='sigmoid',
                      name='D_Conv1')(x)
    x = layers.UpSampling3D((5, 5, 5))(x)
    x = layers.Conv3D(4,
                      3,
                      padding='same',
                      activation='sigmoid',
                      name='D_Conv2')(x)

    decoder = Model(decoder_input, x)

    decoder.summary()

    # CALLBACKS: --------------------------------------------------------------

    class TimeHistory(keras.callbacks.Callback):
        """Record the wall-clock duration of every epoch in ``self.times``."""

        def on_train_begin(self, logs=None):
            # Per-instance list (a class-level list would be shared by every
            # instance of this class).
            self.times = []

        def on_epoch_begin(self, epoch, logs=None):
            self.start = time.time()

        def on_epoch_end(self, epoch, logs=None):
            self.end = time.time()
            self.times.append(self.end - self.start)

    # CUSTOM LAYERS: ----------------------------------------------------------

    class CustomVariationalLayer(keras.layers.Layer):
        """Identity layer that attaches the VAE loss via ``add_loss``.

        The loss reads ``z_mean`` and ``z_log_var`` from the enclosing
        function's scope.
        """

        def vae_loss(self, x, z_decoded):
            x = K.flatten(x)
            z_decoded = K.flatten(z_decoded)
            # Reconstruction term.
            xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
            # KL divergence to N(0, I), down-weighted by 5e-4.
            kl_loss = -5e-4 * K.mean(
                1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
            return K.mean(xent_loss + kl_loss)

        def call(self, inputs):
            x = inputs[0]
            z_decoded = inputs[1]
            loss = self.vae_loss(x, z_decoded)
            self.add_loss(loss, inputs=inputs)
            return x

    # DEFINE FINAL MODEL: ----------------------------------------------------

    z_encoded = variation(input_img)
    z_decoded = decoder(z_encoded)

    # Construct Final Model:
    y = CustomVariationalLayer()([input_img, z_decoded])
    vae = Model(input_img, y)

    vae.summary()

    # Define Optimizer:
    vae_optimizer = keras.optimizers.Adam(lr=learning_rate,
                                          beta_1=0.9,
                                          beta_2=0.999,
                                          decay=decay_rate,
                                          amsgrad=False)

    # No compile-time loss: the VAE loss is registered through add_loss()
    # inside CustomVariationalLayer.
    vae.compile(optimizer=vae_optimizer, loss=None)

    # Define time callback:
    time_callback = TimeHistory()

    steps = n_train // batch_size
    val_steps = n_val // batch_size
    # FIT MODEL: --------------------------------------------------------------
    # NOTE(review): Keras only honours ``shuffle`` for keras.utils.Sequence
    # inputs -- confirm what gen_batches returns.
    history = vae.fit_generator(
        gen_batches(batch_size),
        shuffle=True,
        epochs=n_epochs,
        steps_per_epoch=steps,
        callbacks=[time_callback],
        validation_data=gen_batches_validation(batch_size),
        validation_steps=val_steps)

    # OUTPUTS: -------------------------------------------------------------

    history_dict = history.history

    loss_values = history_dict['loss']
    val_loss_values = history_dict['val_loss']
    times = time_callback.times
    data = {
        'train_loss': loss_values,
        'val_loss': val_loss_values,
        'epoch_time': times
    }

    # Context manager guarantees the file is closed even if dump() fails.
    with open(name, "wb") as pickle_out:
        pickle.dump(data, pickle_out)

    K.clear_session()
    return history_dict