# Example #1
# 0
def main():
    """Train a small ConvLSTM next-frame model on synthetic movies and roll it out.

    The model stacks four ``ConvLSTM2D`` layers (40 filters, 3x3 kernels,
    ``same`` padding, full sequences returned), each followed by
    ``BatchNormalization``, and ends in a sigmoid ``Conv3D`` that produces a
    single-channel output for every frame. It is fitted for ``epochs`` epoch(s)
    on the first 1000 movies from ``generate_movies``. Movie 1004 is then
    extended from its first 7 frames by 16 predicted frames. Finally the
    process exits with status 0.
    """
    import sys

    # Define the model: four ConvLSTM2D + BatchNormalization pairs over a
    # variable-length sequence of 40x40x1 frames, then a sigmoid Conv3D.
    recurrent_stack = []
    for _ in range(4):
        recurrent_stack.append(
            layers.ConvLSTM2D(filters=40,
                              kernel_size=(3, 3),
                              padding="same",
                              return_sequences=True))
        recurrent_stack.append(layers.BatchNormalization())
    model = keras.models.Sequential([
        layers.Input(shape=(None, 40, 40, 1)),
        *recurrent_stack,
        layers.Conv3D(filters=1,
                      kernel_size=(3, 3, 3),
                      activation="sigmoid",
                      padding="same"),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adadelta")
    # summary() prints the table itself and returns None; wrapping it in
    # print() would only emit an extra "None" line.
    model.summary()

    # Generate artificial data (movies with 3 to 7 moving squares
    # inside). The squares are of shape 1x1 or 2x2 pixels, and move
    # linearly over time. For convenience, first create movies with
    # bigger width and height (80x80) and at the end select a 40x40
    # window.
    noisy_movies, shifted_movies = generate_movies(n_samples=1200)

    # Train the model on the first 1000 movies; Keras holds out the last 10%
    # of those for validation.
    epochs = 1  # In practice, would need hundreds of epochs.
    model.fit(noisy_movies[:1000],
              shifted_movies[:1000],
              batch_size=10,
              epochs=epochs,
              verbose=2,
              validation_split=0.1)

    # Test the model on one movie that is not among the 1000 used for fitting.
    movie_index = 1004
    test_movie = noisy_movies[movie_index]

    # Start from first 7 frames.
    track = test_movie[:7, ::, ::, ::]

    # Predict 16 frames: each step feeds the whole track so far (as a batch
    # of one) and appends only the prediction for its last frame.
    for _ in range(16):
        new_pos = model.predict(track[np.newaxis, ::, ::, ::, ::])
        new = new_pos[::, -1, ::, ::, ::]
        track = np.concatenate((track, new), axis=0)

    # Exit the program. sys.exit is used because the builtin exit() is
    # injected by the site module and is not guaranteed to exist.
    sys.exit(0)
# Example #2
# 0
def build_generator():
    """Build the two-branch generator model.

    The input is a ``(L_node, W_node, Time_steps)`` tensor. The shape comments
    below assume ``L_node == W_node == 256`` and ``Time_steps == 5`` —
    confirm against the module-level globals.

    * Branch 1: an inception-style stem (1x1, 1x1->3x3, 1x1->5x5 and
      avg-pool->1x1 towers, concatenated), a strided Conv2D encoder down to
      1/16 resolution and a Conv2DTranspose decoder back to full resolution.
    * Branch 2: two strided convolutions down to 1/4 resolution with 5
      channels. These channels are re-read as a 5-step sequence of
      single-channel frames for a stack of ConvLSTM2D layers. Two
      Conv2DTranspose layers then bring it back to full resolution.

    The branches are summed, and a final 3x3 ``tanh`` convolution produces a
    single-channel image.

    Returns:
        A ``tf.keras.Model`` named ``'generator'``.
    """
    inputs = tf.keras.Input(shape=(L_node, W_node, Time_steps))  # (None, 256, 256, 5)

    l1 = inputs  # (None, 256, 256, 5)

    # Inception-style stem.
    l2_1 = layers.Conv2D(64, 1, 1, 'same', activation='relu')(l1)  # (None, 256, 256, 64)
    l2_1 = layers.BatchNormalization()(l2_1)

    l2_2 = layers.Conv2D(48, 1, 1, 'same', activation='relu')(l1)  # (None, 256, 256, 48)
    l2_2 = layers.BatchNormalization()(l2_2)
    l2_2 = layers.Conv2D(64, 3, 1, 'same', activation='relu')(l2_2)  # (None, 256, 256, 64)
    l2_2 = layers.BatchNormalization()(l2_2)

    l2_3 = layers.Conv2D(48, 1, 1, 'same', activation='relu')(l1)  # (None, 256, 256, 48)
    l2_3 = layers.BatchNormalization()(l2_3)
    l2_3 = layers.Conv2D(64, 5, 1, 'same', activation='relu')(l2_3)  # (None, 256, 256, 64)
    l2_3 = layers.BatchNormalization()(l2_3)

    l2_4 = layers.AvgPool2D(3, 1, 'same')(l1)
    l2_4 = layers.Conv2D(64, 1, 1, 'same', activation='relu')(l2_4)  # (None, 256, 256, 64)
    l2_4 = layers.BatchNormalization()(l2_4)

    l2 = layers.concatenate([l2_1, l2_2, l2_3, l2_4], 3)  # (None, 256, 256, 256)

    l3 = layers.Conv2D(64, 3, 1, 'same', activation='relu')(l2)  # (None, 256, 256, 64)

    # Branch 1: strided encoder.
    l4_1 = layers.Conv2D(128, 3, 2, 'same')(l3)  # (None, 128, 128, 128)
    l4_1 = layers.BatchNormalization()(l4_1)
    l4_1 = layers.LeakyReLU(0.2)(l4_1)

    l5_1 = layers.Conv2D(256, 3, 2, 'same')(l4_1)  # (None, 64, 64, 256)
    l5_1 = layers.BatchNormalization()(l5_1)
    l5_1 = layers.LeakyReLU(0.2)(l5_1)

    l6_1 = layers.Conv2D(512, 3, 2, 'same')(l5_1)  # (None, 32, 32, 512)
    l6_1 = layers.BatchNormalization()(l6_1)
    l6_1 = layers.LeakyReLU(0.2)(l6_1)

    l7_1 = layers.Conv2D(512, 3, 2, 'same')(l6_1)  # (None, 16, 16, 512)
    l7_1 = layers.BatchNormalization()(l7_1)
    l7_1 = layers.LeakyReLU(0.2)(l7_1)

    # Branch 1: transposed-conv decoder.
    l8_1 = layers.Conv2DTranspose(512, 3, 2, 'same')(l7_1)  # (None, 32, 32, 512)
    l8_1 = layers.BatchNormalization()(l8_1)
    l8_1 = layers.Dropout(0.3)(l8_1)
    l8_1 = layers.Activation('relu')(l8_1)

    l9_1 = layers.Conv2DTranspose(256, 3, 2, 'same')(l8_1)  # (None, 64, 64, 256)
    l9_1 = layers.BatchNormalization()(l9_1)
    l9_1 = layers.Dropout(0.3)(l9_1)
    l9_1 = layers.Activation('relu')(l9_1)

    l10_1 = layers.Conv2DTranspose(128, 3, 2, 'same')(l9_1)  # (None, 128, 128, 128)
    l10_1 = layers.BatchNormalization()(l10_1)
    l10_1 = layers.Dropout(0.3)(l10_1)
    l10_1 = layers.Activation('relu')(l10_1)

    l11_1 = layers.Conv2DTranspose(64, 3, 2, 'same')(l10_1)  # (None, 256, 256, 64)
    l11_1 = layers.BatchNormalization()(l11_1)
    l11_1 = layers.Dropout(0.3)(l11_1)
    l11_1 = layers.Activation('tanh')(l11_1)

    # Branch 2: downsample to 5 channels at 1/4 resolution.
    l4_2 = layers.Conv2D(32, 3, 2, 'same')(l3)  # (None, 128, 128, 32)

    l5_2 = layers.Conv2D(5, 3, 2, 'same')(l4_2)  # (None, 64, 64, 5)
    # Re-read the 5 channels as a 5-step sequence of single-channel frames,
    # so that frame t is exactly channel t. A plain reshape of this
    # channels-last tensor to (-1, 5, 64, 64, 1) would only reinterpret
    # memory order and mix pixels from every channel into each "frame". The
    # channel axis must be moved in front of the spatial axes first.
    l5_2 = layers.Permute((3, 1, 2))(l5_2)  # (None, 5, 64, 64)
    l5_2 = layers.Reshape(tuple(l5_2.shape[1:]) + (1,))(l5_2)  # (None, 5, 64, 64, 1)

    l6_2 = layers.ConvLSTM2D(10, 3, 1, 'same', return_sequences=True)(l5_2)  # (None, 5, 64, 64, 10)

    l7_2 = layers.ConvLSTM2D(20, 3, 1, 'same', return_sequences=True)(l6_2)  # (None, 5, 64, 64, 20)

    l8_2 = layers.ConvLSTM2D(20, 3, 1, 'same', return_sequences=True)(l7_2)  # (None, 5, 64, 64, 20)

    l9_2 = layers.ConvLSTM2D(10, 3, 1, 'same', return_sequences=False)(l8_2)  # (None, 64, 64, 10)

    l10_2 = layers.Conv2DTranspose(32, 3, 2, 'same')(l9_2)  # (None, 128, 128, 32)
    l10_2 = layers.BatchNormalization()(l10_2)
    l10_2 = layers.Dropout(0.3)(l10_2)
    l10_2 = layers.Activation('relu')(l10_2)

    l11_2 = layers.Conv2DTranspose(64, 3, 2, 'same')(l10_2)  # (None, 256, 256, 64)
    l11_2 = layers.BatchNormalization()(l11_2)
    l11_2 = layers.Dropout(0.3)(l11_2)
    l11_2 = layers.Activation('relu')(l11_2)

    # Merge both branches and project to a single tanh channel.
    l12 = layers.add([l11_1, l11_2])  # (None, 256, 256, 64)

    l13 = layers.Conv2D(1, 3, 1, 'same', activation='tanh')(l12)  # (None, 256, 256, 1)

    model = tf.keras.Model(inputs, l13, name='generator')
    return model


def generator():
    """Yield strided windows of ``images`` as training samples.

    Each sample starts at an index ``i`` taken from ``index``. It contains
    every ``cfg.tempro_steps_interval``-th element of
    ``images[i : i + 2 * cfg.tempro_steps * cfg.tempro_steps_interval]``.

    NOTE(review): this header is a reconstruction. In the original source
    this loop sat without a ``def`` line directly after ``return`` in
    ``build_generator``. There it was unreachable, and its ``yield`` silently
    turned ``build_generator`` into a generator function. ``train_ds`` passes
    ``generator`` to ``tf.data.Dataset.from_generator`` with no arguments,
    so a zero-argument function is assumed. ``index``, ``images`` and ``cfg``
    are read as module-level globals — confirm against the original project.
    A later ``def generator()`` in this file rebinds the name after
    ``train_ds`` has been built.
    """
    steps = cfg.tempro_steps * cfg.tempro_steps_interval
    for i in index:
        image_x = images[i : i + steps * 2 : cfg.tempro_steps_interval]
        yield image_x

# %%
# Training pipeline: stream float32 samples from `generator`, shuffle them
# within a buffer of `cfg.buffer_size`, batch, and prefetch with an
# automatically tuned buffer.
train_ds = (
    tf.data.Dataset.from_generator(generator, output_types=tf.float32)
    .shuffle(cfg.buffer_size)
    .batch(cfg.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# %%
# NOTE(review): this Sequential definition is truncated in this source. The
# layer list is never closed (no `])`) before the next `def`, so the module
# does not parse as-is, and the remaining layers are not visible here.
convlstm = keras.Sequential([
    # presumably cfg.image_size is the per-sample (time, height, width,
    # channels) shape — confirm against cfg.
    layers.InputLayer(cfg.image_size),
    # Encoder: ConvLSTM2D(filters, kernel_size=3, strides=2, 'same'). Each
    # stage halves height and width while the filter count grows
    # 32 -> 64 -> 128. Bias is disabled, presumably because every stage is
    # followed by BatchNormalization.
    layers.ConvLSTM2D(32, 3, 2, 'same', return_sequences=True, use_bias=False),
    layers.BatchNormalization(),
    layers.ConvLSTM2D(64, 3, 2, 'same', return_sequences=True, use_bias=False),
    layers.BatchNormalization(),
    layers.ConvLSTM2D(128, 3, 2, 'same', return_sequences=True, use_bias=False),
    layers.BatchNormalization(),
    # 1x1, stride-1 ConvLSTM2D: mixes channels without changing resolution.
    layers.ConvLSTM2D(128, 1, 1, 'same', return_sequences=True, use_bias=False),
    layers.BatchNormalization(),
    # Decoder: strides [1, 2, 2] keep the time axis and double height/width.
    layers.Conv3DTranspose(64, 3, [1,2,2], 'same'),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Conv3DTranspose(64, 1, 1, 'same', activation='relu'),
    layers.Conv3DTranspose(32, 3, [1,2,2], 'same'),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Conv3DTranspose(32, 1, 1, 'same', activation='relu'),
def make_generator(image_size: int,
                   in_channels: int,
                   noise_channels: int,
                   out_channels: int,
                   n_timesteps: int,
                   batch_size: int = None,
                   feature_channels=128):
    """Build a recurrent encoder / ConvLSTM / decoder generator.

    The image sequence and a noise sequence are concatenated along the
    channel axis. The result is downsampled twice per frame (to 1/2, then 1/4
    resolution), passed through a ConvLSTM2D over time, and upsampled back
    to full resolution with skip connections from both encoder stages.

    Args:
        image_size: Height and width of every frame; must be a multiple of 4.
        in_channels: Channels of the ``input_image`` input.
        noise_channels: Channels of the ``input_noise`` input.
        out_channels: Channels of the predicted frames.
        n_timesteps: Sequence length of the inputs and the output.
        batch_size: Optional fixed batch size for the inputs (``None`` = any).
        feature_channels: Bottleneck width; must be a multiple of 8.

    Returns:
        A ``Model`` named ``'generator'``. It takes ``[input_image,
        input_noise]`` and returns a ``(batch, n_timesteps, image_size,
        image_size, out_channels)`` tensor from a linear
        ``'predicted_image'`` layer.

    Raises:
        AssertionError: if ``image_size`` is not a multiple of 4 or
            ``feature_channels`` is not a multiple of 8.
    """
    # Make sure we have nice multiples everywhere: the two stride-2 steps
    # below need image_size % 4 == 0, and the decoder uses
    # feature_channels // 2, // 4 and // 8.
    assert image_size % 4 == 0, "image_size must be a multiple of 4"
    assert feature_channels % 8 == 0, "feature_channels must be a multiple of 8"
    total_in_channels = in_channels + noise_channels
    img_shape = (image_size, image_size)
    tshape = (n_timesteps, ) + img_shape
    input_image = kl.Input(shape=tshape + (in_channels, ),
                           batch_size=batch_size,
                           name='input_image')
    input_noise = kl.Input(shape=tshape + (noise_channels, ),
                           batch_size=batch_size,
                           name='input_noise')

    # Concatenate inputs along the channel axis
    x = kl.Concatenate()([input_image, input_noise])

    # Add features and decrease image size - in 2 steps.
    # Step 1: pad 3 + 8x8 valid kernel + stride 2 maps an even size S to S/2.
    intermediate_features = min(total_in_channels * 8, feature_channels)
    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=3))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(intermediate_features, (8, 8),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, intermediate_features)
    res_2 = x  # Keep residuals for later

    # Step 2: pad 1 + 4x4 valid kernel + stride 2 maps S/2 to S/4.
    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=1))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (4, 4),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)
    res_4 = x  # Keep residuals for later

    # Recurrent unit over the time axis, at 1/4 resolution
    x = kl.ConvLSTM2D(feature_channels, (3, 3),
                      padding='same',
                      return_sequences=True)(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)

    # Decrease features (resolution unchanged at this step)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels // 2, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels // 2)

    # Re-introduce residuals from before (skip connection), then a 2x2,
    # stride-2 transposed conv back to 1/2 resolution. Integer division keeps
    # the filter count an int and matches the shape check below.
    x = kl.Concatenate()([x, res_4])
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2DTranspose(feature_channels // 4, (2, 2),
                               strides=2,
                               activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, feature_channels // 4)

    # Skip connection 2
    x = kl.Concatenate()([x, res_2])
    # Back to full resolution. Both branches below must end at
    # image_size x image_size (see their shape checks), and neither
    # branch's stride-1 convolution upsamples by itself. The upsampling
    # therefore has to happen here, before branching.
    x = kl.TimeDistributed(
        kl.UpSampling2D(size=(2, 2), interpolation='bilinear'))(x)
    if feature_channels // 8 >= out_channels:
        x = kl.TimeDistributed(
            kl.Conv2DTranspose(feature_channels // 8, (5, 5),
                               padding='same',
                               activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, feature_channels // 8)
    else:
        x = kl.TimeDistributed(
            kl.Conv2D(out_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, out_channels)
    x = kl.BatchNormalization()(x)
    # Linear output projection to out_channels per frame.
    x = kl.TimeDistributed(kl.Conv2D(out_channels, (3, 3),
                                     padding='same',
                                     activation='linear'),
                           name='predicted_image')(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size, image_size,
                              out_channels)
    return Model(inputs=[input_image, input_noise],
                 outputs=x,
                 name='generator')
def make_discriminator(low_res_size: int,
                       high_res_size: int,
                       low_res_channels: int,
                       high_res_channels: int,
                       n_timesteps: int,
                       batch_size: int = None,
                       feature_channels: int = 16):
    """Build a discriminator that scores (low-res, high-res) image sequences.

    Two branches are computed at the input resolution:

    * a ConvLSTM2D plus a per-frame convolution over the high-res sequence
      alone;
    * the same over the channel-wise concatenation of both inputs.

    The branches are concatenated and reduced by strided, spectrally
    normalized per-frame convolutions, with one shortcut connection. Each
    frame is then flattened and mapped to a single linear score, and the
    scores are averaged over time.

    Args:
        low_res_size: Height/width of the ``low_resolution_image`` frames.
        high_res_size: Height/width of the ``high_resolution_image`` frames.
            Must equal ``low_res_size`` (see Raises).
        low_res_channels: Channels of the low-res input.
        high_res_channels: Channels of the high-res input; also the filter
            count of the high-res ConvLSTM2D.
        n_timesteps: Sequence length of both inputs.
        batch_size: Optional fixed batch size (``None`` = any).
        feature_channels: Filters of each branch's convolutions; the merged
            tensor has ``2 * feature_channels`` channels.

    Returns:
        A ``Model`` named ``'discriminator'``. It takes ``[low_res, high_res]``
        and returns one unbounded (linear) score per sample from the
        ``'score'`` layer.

    Raises:
        NotImplementedError: if the two inputs differ in any dimension other
            than the last (channel) axis.

    NOTE(review): ``img_size``, ``channels`` and ``shortcut_convolution`` are
    helpers defined outside this view. Presumably they return the spatial
    size of x, its channel count, and a projection of the shortcut to x's
    shape, respectively — confirm.
    """
    low_res = kl.Input(shape=(n_timesteps, low_res_size, low_res_size,
                              low_res_channels),
                       batch_size=batch_size,
                       name='low_resolution_image')
    high_res = kl.Input(shape=(n_timesteps, high_res_size, high_res_size,
                               high_res_channels),
                        batch_size=batch_size,
                        name='high_resolution_image')
    # Compare every axis except the channel axis: (batch, time, height, width).
    # NOTE(review): the two message literals below are joined without a space.
    if tuple(low_res.shape)[:-1] != tuple(high_res.shape)[:-1]:
        raise NotImplementedError(
            "The discriminator assumes that the low res and high res images have the same size."
            "Perhaps you should upsample your low res image first?")
    # First branch: high res only
    hr = kl.ConvLSTM2D(high_res_channels, (3, 3),
                       padding='same',
                       return_sequences=True)(high_res)
    hr = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(hr)
    hr = kl.LayerNormalization()(hr)

    # Second branch: Mix both inputs (concatenated along the channel axis)
    mix = kl.Concatenate()([low_res, high_res])
    mix = kl.ConvLSTM2D(feature_channels, (3, 3),
                        padding='same',
                        return_sequences=True)(mix)
    mix = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(mix)
    mix = kl.LayerNormalization()(mix)

    # Merge everything together
    x = kl.Concatenate()([hr, mix])
    assert tuple(x.shape) == (batch_size, n_timesteps, low_res_size,
                              low_res_size, 2 * feature_channels)

    # Downsample with zero-padded 7x7, stride-3 convolutions that double the
    # channel count each step. The `name=` f-string is evaluated before the
    # layer is applied, so it records the size of the *padded* x.
    while img_size(x) >= 16:
        x = kl.TimeDistributed(kl.ZeroPadding2D())(x)
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (7, 7),
                      strides=3,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)

    # Same downsampling, wrapped by a shortcut connection.
    # NOTE(review): ZeroPadding2D() defaults to 1 pixel per side, so a size
    # of 4 entering this loop is padded to 6, which is smaller than the 7x7
    # kernel. That would make this valid convolution produce an empty output.
    # Confirm which low_res_size values are supported (16 leads here).
    shortcut = x
    while img_size(x) >= 4:
        x = kl.TimeDistributed(kl.ZeroPadding2D())(x)
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (7, 7),
                      strides=3,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)
    shortcut = shortcut_convolution(shortcut, x, channels(x))
    # Skip connection: add the (projected) shortcut back onto x.
    x = kl.add([x, shortcut])

    # Final unpadded 3x3, stride-2 convolutions while the size exceeds 2.
    while img_size(x) > 2:
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (3, 3),
                      strides=2,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)
    # One flat feature vector per timestep.
    x = kl.TimeDistributed(kl.Flatten())(x)
    assert tuple(x.shape)[:-1] == (batch_size, n_timesteps
                                   )  # Unknown number of channels
    # Per-timestep linear score, then averaged over the time axis.
    x = kl.TimeDistributed(kl.Dense(1, activation='linear'))(x)
    x = kl.GlobalAveragePooling1D(name='score')(x)

    return Model(inputs=[low_res, high_res], outputs=x, name='discriminator')
# Example #6
# 0
def generator():
    """Build a recurrent U-Net-style generator.

    Processing steps:

    * Frames of shape ``(L_node, W_node, Channel)`` are encoded by a strided
      Conv2D encoder (the shape comments assume 256x256x1 input).
    * The deepest features are regrouped into sequences of ``Num_sequence``
      frames for a ConvLSTM2D stack.
    * A Conv2DTranspose decoder, with skip connections from the encoder,
      produces one ``tanh`` output image per sequence.

    NOTE(review): the encoder's batch axis holds the individual frames.
    ``tf.reshape(e3, [-1, Num_sequence, ...])`` groups *consecutive* frames
    into one sequence, i.e. it assumes frame t of sample b sits at row
    ``b * Num_sequence + t``. The skip connections instead take the
    contiguous rows ``[Batch_size * (Num_sequence - 1), Batch_size *
    Num_sequence)``. Those rows are the last time step of every sample only
    if frames are stored at row ``t * Batch_size + b``. The two layouts agree
    only when ``Batch_size == 1`` — confirm the input ordering.

    Returns:
        A ``tf.keras.Model`` (auto-named) producing a ``(None, 256, 256, 1)``
        output.
    """
    inputs = tf.keras.Input(shape=[L_node, W_node,
                                   Channel])  # (None, 256, 256, 1)

    # Encoder.
    e0 = layers.Conv2D(32, 3, 1, 'same',
                       activation='relu')(inputs)  # (None, 256, 256, 32)

    e1 = layers.Conv2D(64, 5, 2, 'same')(e0)  # (None, 128, 128, 64)
    e1_1 = layers.LeakyReLU(0.2)(e1)
    e1_1 = layers.Conv2D(64, 3, 1, 'same',
                         activation='relu')(e1_1)  # (None, 128, 128, 64)

    e2 = layers.Conv2D(128, 5, 2, 'same')(e1_1)  # (None, 64, 64, 128)
    e2 = layers.BatchNormalization()(e2)
    e2_1 = layers.LeakyReLU(0.2)(e2)
    e2_1 = layers.Conv2D(128, 3, 1, 'same',
                         activation='relu')(e2_1)  # (None, 64, 64, 128)

    e3 = layers.Conv2D(256, 5, 2, 'same')(e2_1)  # (None, 32, 32, 256)
    e3 = layers.BatchNormalization()(e3)
    e3 = layers.LeakyReLU(0.2)(e3)
    e3 = layers.Conv2D(256, 3, 1, 'same',
                       activation='relu')(e3)  # (None, 32, 32, 256)
    # Regroup frames into sequences: (None, Num_sequence, 32, 32, 256).
    e3 = tf.reshape(e3, shape=[-1, Num_sequence, 32, 32, 256])

    # Recurrent bottleneck; the last layer keeps only the final state.
    c1 = layers.ConvLSTM2D(10, 3, 1, 'same', return_sequences=True)(e3)
    c1 = layers.ConvLSTM2D(20, 3, 1, 'same', return_sequences=True)(c1)
    c1 = layers.ConvLSTM2D(10, 3, 1, 'same', return_sequences=False)(c1)  # (None, 32, 32, 10)

    c2 = layers.Conv2D(256, 3, 1, 'same',
                       activation='relu')(c1)  # (None, 32, 32, 256)

    # Decoder with skip connections (see the layout note in the docstring).
    d1 = layers.Conv2DTranspose(128, 5, 2, 'same',
                                activation='relu')(c2)  # (None, 64, 64, 128)
    d1 = layers.BatchNormalization()(d1)
    d1 = layers.Dropout(0.5)(d1)
    d1 = layers.concatenate(
        [d1, e2[Batch_size * (Num_sequence - 1):Batch_size * Num_sequence, ]],
        3)  # (None, 64, 64, 256)
    d1 = layers.Activation('relu')(d1)
    d1 = layers.Conv2D(128, 3, 1, 'same', activation='relu')(d1)

    d2 = layers.Conv2DTranspose(64, 5, 2, 'same',
                                activation='relu')(d1)  # (None, 128, 128, 64)
    d2 = layers.BatchNormalization()(d2)
    d2 = layers.Dropout(0.5)(d2)
    d2 = layers.concatenate(
        [d2, e1[Batch_size * (Num_sequence - 1):Batch_size * Num_sequence, ]],
        3)  # (None, 128, 128, 128)
    d2 = layers.Activation('relu')(d2)
    d2 = layers.Conv2D(64, 3, 1, 'same', activation='relu')(d2)

    # NOTE(review): d1/d2 use as many filters as their skip input has channels
    # (128/64). By that pattern d3 would use 32 filters to match e0, but the
    # code uses 64 — confirm which is intended.
    d3 = layers.Conv2DTranspose(64, 5, 2, 'same',
                                activation='relu')(d2)  # (None, 256, 256, 64)
    d3 = layers.BatchNormalization()(d3)
    d3 = layers.Dropout(0.5)(d3)
    d3 = layers.concatenate(
        [d3, e0[Batch_size * (Num_sequence - 1):Batch_size * Num_sequence, ]],
        3)  # (None, 256, 256, 96)
    d3 = layers.Activation('relu')(d3)
    d3 = layers.Conv2D(1, 3, 1, 'same',
                       activation='tanh')(d3)  # (None, 256, 256, 1)

    model = tf.keras.Model(inputs, d3)
    return model
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pylab as plt
"""
## Build a model
We create a model which take as input movies of shape
`(n_frames, width, height, channels)` and returns a movie
of identical shape.
"""

# NOTE(review): this Sequential definition is truncated in this source. The
# fourth ConvLSTM2D call and the closing `])` of the layer list are cut off,
# so the remaining layers are not visible here.
seq = keras.Sequential([
    keras.Input(shape=(None, 40, 40,
                       1)),  # Variable-length sequence of 40x40x1 frames
    # Stacked ConvLSTM2D layers: 40 filters, 3x3 kernels, "same" padding
    # (frame size preserved), full sequence returned so the next layer sees
    # every timestep. Each is followed by BatchNormalization.
    layers.ConvLSTM2D(filters=40,
                      kernel_size=(3, 3),
                      padding="same",
                      return_sequences=True),
    layers.BatchNormalization(),
    layers.ConvLSTM2D(filters=40,
                      kernel_size=(3, 3),
                      padding="same",
                      return_sequences=True),
    layers.BatchNormalization(),
    layers.ConvLSTM2D(filters=40,
                      kernel_size=(3, 3),
                      padding="same",
                      return_sequences=True),
    layers.BatchNormalization(),
    layers.ConvLSTM2D(filters=40,
                      kernel_size=(3, 3),
                      padding="same",
# Example #8
# 0
def main():
    """Train an RGB + optical-flow video classifier and write predictions.

    The function performs these steps in order:

    * Reads ``data/train_sample_videos/metadata.json`` and splits the listed
      videos into train/test sets.
    * Builds two branches: a residual Conv3D network over RGB segments and a
      ConvLSTM2D network over flow segments.
    * Fuses the branches and trains with balanced class weights.
    * Predicts on the videos in ``data/test_videos`` and writes
      ``data/submission_<timestamp>.csv``.

    Labels are encoded as 0 for ``'REAL'`` and 1 for anything else. With
    ``--sample`` only the first 20 training and 8 validation videos are used.

    Raises:
        ValueError: if fewer predictions than validation files come back
            (e.g. if ``input_fn`` drops a final partial batch).
    """
    import math

    arg_list = None
    args = parseArgs(arg_list)

    def encode(label):
        # Binary encoding shared by the targets and the class-weight keys.
        return 0 if label == 'REAL' else 1

    # grab training data
    filepath = 'data/train_sample_videos'
    datapath = os.path.join(filepath, 'metadata.json')
    data = pd.read_json(datapath).T
    files = [os.path.join(filepath, f) for f in data.index]
    labels = data.label.values
    if args.sample:
        files = files[:20]
        labels = labels[:20]
    x_train, x_test, y_train, y_test = train_test_split(
        files, labels, test_size=float(args.test_split))

    # Balanced class weights. scikit-learn >= 1.0 accepts `classes` and `y`
    # only as keyword arguments.
    classes = np.unique(y_train)
    class_weights = compute_class_weight(
        class_weight='balanced', classes=classes, y=y_train)
    for k, v in zip(classes, class_weights):
        print(k, v)
    # Keras expects `class_weight` as a dict keyed by integer class index.
    # np.unique returns the string labels in sorted order, which need not
    # match the REAL->0 encoding, so each weight is keyed by its encoded
    # label rather than by position.
    class_weight_map = {encode(k): v for k, v in zip(classes, class_weights)}

    y_train = to_categorical([encode(y) for y in y_train], num_classes=2)
    y_test = to_categorical([encode(y) for y in y_test], num_classes=2)
    print(len(x_train), len(y_train), len(x_test), len(y_test))

    # validation data
    val_path = 'data/test_videos'
    val_files = [os.path.join(val_path, f) for f in os.listdir(val_path)]
    if args.sample:
        val_files = val_files[:8]
    print('number of validation files', len(val_files))

    # generate datasets
    batch_size = args.batch_size
    segment_size = args.segment_size
    rsz = (128, 128)
    train_data = input_fn(
        x_train,
        y_train,
        segment_size=segment_size,
        batch_size=batch_size,
        rsz=rsz)
    test_data = input_fn(
        x_test,
        y_test,
        segment_size=segment_size,
        batch_size=batch_size,
        rsz=rsz)
    val_data = input_fn(
        files=val_files,
        segment_size=segment_size,
        batch_size=batch_size,
        rsz=rsz)
    rgb_input = tf.keras.Input(
        shape=(segment_size, rsz[0], rsz[1], 3),
        name='rgb_input')
    # Flow is computed between consecutive frames, hence one step fewer.
    flow_input = tf.keras.Input(
        shape=(segment_size - 1, rsz[0], rsz[1], 2),
        name='flow_input')

    def conv3d(kernel_size):
        # Every RGB-branch convolution shares these settings; only the
        # kernel size varies.
        return layers.Conv3D(
            filters=8,
            kernel_size=kernel_size,
            strides=(1, 1, 1),
            padding='same',
            data_format='channels_last',
            activation='relu',
        )

    def conv_lstm(return_sequences):
        # Every flow-branch ConvLSTM2D shares these settings.
        return layers.ConvLSTM2D(
            filters=8,
            kernel_size=3,
            strides=1,
            padding='same',
            data_format='channels_last',
            return_sequences=return_sequences,
            dropout=0.5
        )

    # TODO: make OO
    # RGB MODEL
    # block 1: two convolutions, then halve time, height and width.
    x = conv3d(3)(rgb_input)
    x = conv3d(4)(x)
    block1_output = layers.MaxPool3D(
        pool_size=(2, 2, 2),
        strides=(2, 2, 2),
        padding='same'
    )(x)
    # blocks 2 and 3: two convolutions plus an identity residual connection.
    x = conv3d(3)(block1_output)
    x = conv3d(4)(x)
    block2_output = layers.add([x, block1_output])
    x = conv3d(3)(block2_output)
    x = conv3d(4)(x)
    block3_output = layers.add([x, block2_output])

    x = conv3d(3)(block3_output)
    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    rgb_outputs = layers.Dense(2, activation='softmax')(x)

    rgb_model = Model(inputs=rgb_input, outputs=rgb_outputs)
    rgb_model.summary()

    # FLOW MODEL
    x = conv_lstm(return_sequences=True)(flow_input)
    x = layers.BatchNormalization()(x)
    x = conv_lstm(return_sequences=True)(x)
    x = layers.BatchNormalization()(x)
    x = conv_lstm(return_sequences=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    # NOTE(review): no activation here, so flow_output is raw logits while
    # rgb_outputs are softmax probabilities. The average below mixes the two
    # scales — confirm this is intended.
    flow_output = layers.Dense(2)(x)
    flow_model = Model(inputs=flow_input, outputs=flow_output)
    flow_model.summary()

    # FINAL MODEL
    final_average = layers.average([rgb_outputs, flow_output])
    x = layers.Flatten()(final_average)
    final_output = layers.Dense(
        2, activation='softmax', name='final_output')(x)
    model = Model(
        inputs={"rgb_input": rgb_input, "flow_input": flow_input},
        outputs=final_output,
        name='my_model'
    )
    model.summary()

    # TRAIN
    dt = datetime.now().strftime('%Y%m%d_%H%M%S')
    opt = tf.keras.optimizers.Adam()
    callbacks = []
    if args.save_checkpoints:
        save_path = f'data/model_checkpoints/{dt}/ckpt'
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            filepath=save_path,
            save_best_only=False,
            save_weights_only=True
        ))
    model.compile(
        optimizer=opt,
        loss='categorical_crossentropy',
        metrics=['acc'])
    # At least one step per epoch: with fewer samples than batch_size the
    # floor division would give 0 and Keras would run an empty epoch.
    model.fit(
        x=train_data.repeat(),
        validation_data=test_data.repeat(),
        epochs=args.epochs,
        verbose=1,
        class_weight=class_weight_map,
        steps_per_epoch=max(1, len(x_train) // batch_size),
        validation_steps=max(1, len(x_test) // batch_size),
        callbacks=callbacks
    )

    # EVAL
    print('\n\n---------------------------------------------------------')
    print('predicting on validation data')
    start = time.time()
    # Round up so the final partial batch is predicted too (assumes input_fn
    # keeps it — confirm). Floor division left the last
    # len(val_files) % batch_size files without a prediction.
    preds = model.predict(
        val_data,
        verbose=1,
        steps=math.ceil(len(val_files) / batch_size)
    )
    print('prediction time: ', time.time() - start)
    # If val_data repeats, the last step can wrap around; keep one
    # prediction per file, in file order.
    preds = np.argmax(preds, axis=1)[:len(val_files)]
    if len(preds) != len(val_files):
        raise ValueError(
            f'got {len(preds)} predictions for {len(val_files)} validation files')
    df = pd.DataFrame({
        'filename': [os.path.basename(v) for v in val_files],
        'label': preds,
    })
    df.to_csv(f'data/submission_{dt}.csv', index=False)