Example #1
0
def test_TimeDistributed_with_masking_layer():
    """TimeDistributed(Masking) yields a mask that a wrapped Dense carries on.

    Sample ``i`` (for ``i < 4``) has its timesteps ``i:`` zeroed; every other
    entry is a nonzero integer. The test expects both the Masking wrapper's
    mask and the mask produced by the following TimeDistributed(Dense) to
    equal ``np.any(model_input, axis=-1)``.
    """
    model = Sequential()
    model.add(
        wrappers.TimeDistributed(layers.Masking(mask_value=0.),
                                 input_shape=(None, 4)))
    model.add(wrappers.TimeDistributed(layers.Dense(5)))
    # Compile exactly once; a second compile with identical arguments right
    # before fit() was redundant.
    model.compile(optimizer='rmsprop', loss='mse')

    # randint(low=1) guarantees no accidental zeros outside the masked tail.
    model_input = np.random.randint(low=1, high=5, size=(10, 3, 4))
    for i in range(4):
        model_input[i, i:, :] = 0.
    model.fit(model_input,
              np.random.random((10, 3, 5)),
              epochs=1,
              batch_size=6)

    # Build the symbolic masks layer by layer, feeding each layer the mask
    # produced by the previous one.
    mask_outputs = [
        model.layers[0].compute_mask(model.input, compute_mask=True)
    ]
    mask_outputs += [
        model.layers[1].compute_mask(model.layers[1].input,
                                     mask_outputs[-1],
                                     compute_mask=True)
    ]
    func = K.function([model.input], mask_outputs)
    mask_outputs_val = func([model_input])

    expected_mask = np.any(model_input, axis=-1)
    assert np.array_equal(mask_outputs_val[0], expected_mask)
    assert np.array_equal(mask_outputs_val[1], expected_mask)
Example #2
0
def test_TimeDistributed_with_masked_embedding_and_unspecified_shape():
    """Mask propagation through nested TimeDistributed RNNs of unknown shape.

    A TimeDistributed Embedding with ``mask_zero=True`` feeds two
    TimeDistributed SimpleRNNs and a plain SimpleRNN. The test expects the
    embedding mask to be ``input > 0``. It expects the sequence-returning RNN
    to pass that mask through unchanged, the non-sequence RNN to reduce it
    with ``any`` over the last axis, and the final RNN to produce no mask.
    """
    model = Sequential()
    model.add(wrappers.TimeDistributed(
        layers.Embedding(5, 6, mask_zero=True), input_shape=(None, None)))
    # Output so far: (N, t_1, t_2, 6)
    model.add(wrappers.TimeDistributed(
        layers.SimpleRNN(7, return_sequences=True)))
    model.add(wrappers.TimeDistributed(
        layers.SimpleRNN(8, return_sequences=False)))
    model.add(layers.SimpleRNN(1, return_sequences=False))
    model.compile(optimizer='rmsprop', loss='mse')

    n_samples = 10
    model_input = np.random.randint(low=1, high=5, size=(n_samples, 3, 4),
                                    dtype='int32')
    # Zero a growing lower-right corner of the first four samples so each
    # of them gets a different mask pattern.
    for k in range(4):
        model_input[k, k:, k:] = 0
    targets = np.random.random((n_samples, 1))
    model.fit(model_input, targets, epochs=1, batch_size=n_samples)

    stack = model.layers
    mask_outputs = [stack[0].compute_mask(model.input)]
    for layer in stack[1:]:
        incoming_mask = mask_outputs[-1]
        mask_outputs.append(layer.compute_mask(layer.input, incoming_mask))

    # The last mask is expected to be None, so only the first three are
    # evaluated numerically.
    func = K.function([model.input], mask_outputs[:-1])
    mask_outputs_val = func([model_input])

    embedding_mask = model_input > 0
    first_rnn_mask = embedding_mask
    second_rnn_mask = np.any(first_rnn_mask, axis=-1)
    expected_masks = (embedding_mask, first_rnn_mask, second_rnn_mask)
    for idx, want in enumerate(expected_masks):
        assert np.array_equal(mask_outputs_val[idx], want)
    assert mask_outputs[-1] is None
Example #3
0
def lstm_readout_net_old(feature_map_in_seqs,
                         gaze_map_size,
                         drop_rate,
                         gaze_prior=None):
    """Legacy readout: per-frame 1x1 convs -> LSTM -> per-frame gaze logits.

    Args:
      feature_map_in_seqs: Sequence of feature maps consumed by
        ``TimeDistributed`` layers (presumably (batch, time, H, W, C) --
        confirm with callers).
      gaze_map_size: (height, width) of the output map. Must be a tuple,
        because it is concatenated with ``(1, )``.
      drop_rate: Rate for every Dropout layer and for the LSTM's input and
        recurrent dropout.
      gaze_prior: Optional array of prior probabilities with height * width
        elements. When given, ``log(max(prior, EPSILON))`` is added to every
        row of the logits.

    Returns:
      ``logits`` reshaped to ``[-1, height * width]`` when ``gaze_prior`` is
      None, otherwise ``(logits, pre_prior_logits)``.
    """
    n_pixels = gaze_map_size[0] * gaze_map_size[1]

    def per_frame(layer, tensor):
        # The inner layer is constructed by the caller before the wrapper,
        # matching the original construction order.
        return wps.TimeDistributed(layer)(tensor)

    x = feature_map_in_seqs
    conv_specs = ((16, 'readout_conv1'), (32, 'readout_conv2'),
                  (2, 'readout_conv3'))
    for n_filters, conv_name in conv_specs:
        x = per_frame(
            layers.Conv2D(n_filters, (1, 1), activation='relu',
                          name=conv_name), x)
        x = per_frame(layers.core.Dropout(drop_rate), x)

    # Flatten each frame so the LSTM receives one vector per timestep.
    x = per_frame(layers.core.Reshape((-1, )), x)
    x = layers.recurrent.LSTM(units=n_pixels,
                              dropout=drop_rate,
                              recurrent_dropout=drop_rate,
                              return_sequences=True)(x)
    x = per_frame(layers.core.Dense(n_pixels), x)
    x = per_frame(layers.core.Reshape(gaze_map_size + (1, )), x)
    x = per_frame(
        GaussianSmooth(kernel_size=GAUSSIAN_KERNEL_SIZE,
                       name='gaussian_smooth'), x)

    logits = tf.reshape(x, [-1, n_pixels])
    if gaze_prior is None:
        return logits

    # Keep the prediction before the prior is added.
    pre_prior_logits = logits
    floored_prior = np.maximum(gaze_prior,
                               EPSILON * np.ones(gaze_prior.shape))
    floored_prior = floored_prior.astype(np.float32)
    log_prior_row = tf.constant(np.reshape(np.log(floored_prior), (1, -1)))
    n_rows = tf.shape(pre_prior_logits)[0]
    # ones(n_rows, 1) @ row tiles the log-prior row once per logits row.
    tiled_log_prior = tf.matmul(tf.ones((n_rows, 1)), log_prior_row)
    tiled_log_prior = tf.reshape(tiled_log_prior, [-1, n_pixels])
    logits = tf.add(pre_prior_logits, tiled_log_prior)
    return logits, pre_prior_logits
Example #4
0
def test_TimeDistributed_learning_phase():
    """TimeDistributed must forward ``training=True`` to the wrapped Dropout.

    With a drop rate of 0.999 nearly every element is zeroed. Keras' Dropout
    rescales the kept elements by 1 / (1 - rate), i.e. about 1000x. The old
    check ``assert_allclose(0., y, atol=1e-2)`` therefore failed whenever any
    one of the 60 outputs survived, which happens about 6% of the time.
    Dropped elements are exactly zero. Requiring that the vast majority of
    outputs are zero tests the same property without the flakiness.
    """
    x = Input(shape=(3, 2))
    y = wrappers.TimeDistributed(core.Dropout(.999))(x, training=True)
    model = Model(x, y)
    y = model.predict(np.random.random((10, 3, 2)))
    # If training=True were ignored, Dropout would be the identity and almost
    # no output would be zero.
    assert np.mean(y == 0.) > 0.9
Example #5
0
def test_TimeDistributed_learning_phase():
    """TimeDistributed must forward ``training=True`` to the wrapped Dropout.

    ``np.random.seed`` fixes the input data. The dropout mask, however, is
    presumably drawn by the backend's RNG rather than NumPy's. A single
    surviving element, rescaled by about 1 / (1 - 0.999), pushed
    ``np.mean(y)`` over the 60 outputs far past the 0.1 tolerance, so the old
    check was flaky. Dropped elements are exactly zero, so requiring that
    most outputs are zero checks the same property robustly.
    """
    np.random.seed(1234)
    x = Input(shape=(3, 2))
    y = wrappers.TimeDistributed(layers.Dropout(.999))(x, training=True)
    model = Model(x, y)
    y = model.predict(np.random.random((10, 3, 2)))
    # If training=True were ignored, Dropout would be the identity and almost
    # no output would be zero.
    assert np.mean(y == 0.) > 0.9
Example #6
0
def test_regularizers():
    """An L1 weight regularizer on the wrapped Dense yields one model loss.

    Uses the Keras 1 ``W_regularizer`` argument.
    """
    model = Sequential()
    time_dense = wrappers.TimeDistributed(
        core.Dense(2, W_regularizer='l1'), input_shape=(3, 4))
    model.add(time_dense)
    model.add(core.Activation('relu'))
    model.compile(optimizer='rmsprop', loss='mse')
    regularization_losses = model.losses
    assert len(regularization_losses) == 1
Example #7
0
def test_regularizers():
    """Regularization losses of a wrapped Dense are visible through the wrapper.

    Checks a kernel (weight) regularizer at every level: inner layer, wrapper,
    unconditional losses and the model. Then checks an activity regularizer
    at the model level.
    """
    def build(**dense_kwargs):
        # Sequential first, then Dense, then the wrapper -- same creation
        # order as a hand-written model.
        net = Sequential()
        net.add(wrappers.TimeDistributed(layers.Dense(2, **dense_kwargs),
                                         input_shape=(3, 4)))
        net.add(layers.Activation('relu'))
        net.compile(optimizer='rmsprop', loss='mse')
        return net

    model = build(kernel_regularizer='l1')
    wrapper = model.layers[0]
    assert len(wrapper.layer.losses) == 1
    assert len(wrapper.losses) == 1
    assert len(wrapper.get_losses_for(None)) == 1
    assert len(model.losses) == 1

    model = build(activity_regularizer='l1')
    assert len(model.losses) == 1
Example #8
0
def big_conv_lstm_readout_net(feature_map_in_seqs, feature_map_size, drop_rate, gaze_prior=None):
    """Readout: per-frame 1x1 conv stack -> ConvLSTM2D -> smoothed gaze logits.

    Args:
      feature_map_in_seqs: 5-D tensor. Dims 0 and 1 are folded together for
        the per-frame convolutions, and dim 4 must be statically known (it is
        read with ``get_shape()[4]``). Presumably (batch, time, H, W, C).
      feature_map_size: (height, width) of each frame.
      drop_rate: Rate for every Dropout layer and for the ConvLSTM2D's input
        and recurrent dropout.
      gaze_prior: Optional array of prior probabilities with height * width
        elements. When given, ``log(max(prior, EPSILON))`` is added to every
        row of the logits.

    Returns:
      ``logits`` of shape ``[-1, height * width]`` when ``gaze_prior`` is
      None, otherwise ``(logits, pre_prior_logits)``.
    """
    batch_size = tf.shape(feature_map_in_seqs)[0]
    n_step = tf.shape(feature_map_in_seqs)[1]
    height, width = feature_map_size[0], feature_map_size[1]
    n_channel = int(feature_map_in_seqs.get_shape()[4])

    # Fold time into the batch axis so plain Conv2D layers act per frame.
    x = tf.reshape(feature_map_in_seqs,
                   [batch_size * n_step, height, width, n_channel])

    conv_specs = ((32, 'readout_conv1'), (16, 'readout_conv2'),
                  (8, 'readout_conv3'), (1, 'readout_conv4'))
    for n_filters, conv_name in conv_specs:
        x = layers.Conv2D(n_filters, (1, 1), activation='relu',
                          name=conv_name)(x)
        x = layers.core.Dropout(drop_rate)(x)

    # Unfold back into (batch, time, H', W', C') for the ConvLSTM.
    frame_shape = [int(dim) for dim in x.get_shape()[1:4]]
    x = tf.reshape(x, [batch_size, n_step] + frame_shape)

    conv_lstm = layers.ConvLSTM2D(filters=1,
                                  kernel_size=(3, 3),
                                  strides=(1, 1),
                                  padding='same',
                                  dropout=drop_rate,
                                  recurrent_dropout=drop_rate,
                                  return_sequences=True)
    x = conv_lstm(x)
    x = wps.TimeDistributed(layers.Conv2D(1, (1, 1), activation='linear'))(x)

    x = tf.reshape(x, [batch_size * n_step, height, width, 1])
    x = GaussianSmooth(kernel_size=GAUSSIAN_KERNEL_SIZE,
                       name='gaussian_smooth')(x)

    n_pixels = height * width
    logits = tf.reshape(x, [-1, n_pixels])
    if gaze_prior is None:
        return logits

    # Keep the prediction before the prior is added.
    pre_prior_logits = logits
    floored_prior = np.maximum(gaze_prior,
                               EPSILON * np.ones(gaze_prior.shape))
    floored_prior = floored_prior.astype(np.float32)
    log_prior_row = tf.constant(np.reshape(np.log(floored_prior), (1, -1)))
    n_rows = tf.shape(pre_prior_logits)[0]
    # ones(n_rows, 1) @ row tiles the log-prior row once per logits row.
    tiled_log_prior = tf.matmul(tf.ones((n_rows, 1)), log_prior_row)
    tiled_log_prior = tf.reshape(tiled_log_prior, [-1, n_pixels])
    logits = tf.add(pre_prior_logits, tiled_log_prior)
    return logits, pre_prior_logits
Example #9
0
def test_TimeDistributed_trainable():
    """Setting ``trainable`` on the wrapper toggles its trainable weights.

    Calling the wrapper once builds the inner BatchNormalization. The test
    expects two trainable weights, none while the wrapper is frozen, and two
    again after it is unfrozen.
    """
    x = Input(shape=(3, 2))
    layer = wrappers.TimeDistributed(layers.BatchNormalization())
    layer(x)  # build the wrapper and the inner layer's weights
    assert len(layer.trainable_weights) == 2
    layer.trainable = False
    assert len(layer.trainable_weights) == 0
    layer.trainable = True
    assert len(layer.trainable_weights) == 2
Example #10
0
def test_TimeDistributed():
    """End-to-end checks for ``wrappers.TimeDistributed`` (Keras 2 API).

    Covers:
      * Dense and Embedding, each compared with an equivalent model built
        with a different input-shape specification;
      * Conv2D with a JSON round trip;
      * stacked wrappers;
      * wrapping a Sequential model, both in a Sequential and through the
        functional API;
      * BatchNormalization, including its moving statistics and the
        wrapper's cached ``_input_map``.
    """
    def compile_mse(net):
        net.compile(optimizer='rmsprop', loss='mse')

    def fit_one_epoch(net, inputs, targets):
        net.fit(inputs, targets, epochs=1, batch_size=10)

    # --- Dense, compared against an explicit batch_input_shape model.
    model = Sequential()
    model.add(wrappers.TimeDistributed(layers.Dense(2), input_shape=(3, 4)))
    model.add(layers.Activation('relu'))
    compile_mse(model)
    fit_one_epoch(model, np.random.random((10, 3, 4)),
                  np.random.random((10, 3, 2)))
    model.get_config()  # must not raise

    probe = np.random.random((1, 3, 4))
    expected = model.predict(probe)
    dense_weights = model.layers[0].get_weights()

    reference = Sequential()
    reference.add(wrappers.TimeDistributed(layers.Dense(2),
                                           batch_input_shape=(1, 3, 4)))
    reference.add(layers.Activation('relu'))
    compile_mse(reference)
    reference.layers[0].set_weights(dense_weights)
    assert_allclose(expected, reference.predict(probe), atol=1e-05)

    # --- Embedding, compared against a model without batch_input_shape.
    model = Sequential()
    model.add(wrappers.TimeDistributed(layers.Embedding(5, 6),
                                       batch_input_shape=(10, 3, 4),
                                       dtype='int32'))
    compile_mse(model)
    fit_one_epoch(model, np.random.randint(5, size=(10, 3, 4), dtype='int32'),
                  np.random.random((10, 3, 4, 6)))

    tokens = np.random.randint(5, size=(10, 3, 4), dtype='int32')
    expected = model.predict(tokens)
    embedding_weights = model.layers[0].get_weights()

    reference = Sequential()
    reference.add(wrappers.TimeDistributed(layers.Embedding(5, 6),
                                           input_shape=(3, 4),
                                           dtype='int32'))
    compile_mse(reference)
    reference.layers[0].set_weights(embedding_weights)
    assert_allclose(expected, reference.predict(tokens), atol=1e-05)

    # --- Conv2D, then serialize to JSON and rebuild.
    model = Sequential()
    model.add(wrappers.TimeDistributed(layers.Conv2D(5, (2, 2), padding='same'),
                                       input_shape=(2, 4, 4, 3)))
    model.add(layers.Activation('relu'))
    compile_mse(model)
    model.train_on_batch(np.random.random((1, 2, 4, 4, 3)),
                         np.random.random((1, 2, 4, 4, 5)))
    model = model_from_json(model.to_json())
    model.summary()

    # --- Two stacked wrappers.
    model = Sequential()
    model.add(wrappers.TimeDistributed(layers.Dense(2), input_shape=(3, 4)))
    model.add(wrappers.TimeDistributed(layers.Dense(3)))
    model.add(layers.Activation('relu'))
    compile_mse(model)
    fit_one_epoch(model, np.random.random((10, 3, 4)),
                  np.random.random((10, 3, 3)))

    # --- Wrapping a whole Sequential model.
    inner = Sequential()
    inner.add(layers.Dense(3, input_dim=2))
    outer_model = Sequential()
    outer_model.add(wrappers.TimeDistributed(inner, input_shape=(3, 2)))
    compile_mse(outer_model)
    fit_one_epoch(outer_model, np.random.random((10, 3, 2)),
                  np.random.random((10, 3, 3)))

    # --- Same wrapped model through the functional API.
    x = Input(shape=(3, 2))
    y = wrappers.TimeDistributed(inner)(x)
    outer_model = Model(x, y)
    compile_mse(outer_model)
    fit_one_epoch(outer_model, np.random.random((10, 3, 2)),
                  np.random.random((10, 3, 3)))

    # --- BatchNormalization: moving stats start at (0, 1) and move on train.
    model = Sequential()
    model.add(wrappers.TimeDistributed(
        layers.BatchNormalization(center=True, scale=True),
        name='bn',
        input_shape=(10, 2)))
    compile_mse(model)
    td = model.layers[0]
    moving_mean, moving_var = td.get_weights()[2:4]
    assert np.array_equal(moving_mean, np.array([0, 0]))
    assert np.array_equal(moving_var, np.array([1, 1]))

    model.train_on_batch(np.random.normal(loc=2, scale=2, size=(1, 10, 2)),
                         np.broadcast_to(np.array([0, 1]), (1, 10, 2)))
    moving_mean, moving_var = td.get_weights()[2:4]
    assert not np.array_equal(moving_mean, np.array([0, 0]))
    assert not np.array_equal(moving_var, np.array([1, 1]))

    # The wrapper caches exactly one input -> reshaped-input mapping.
    uid = object_list_uid(model.inputs)
    assert len(td._input_map.keys()) == 1
    assert uid in td._input_map
    assert K.int_shape(td._input_map[uid]) == (None, 2)
Example #11
0
def test_TimeDistributed():
    """End-to-end checks for ``wrappers.TimeDistributed`` (Keras 1 API).

    Covers:
      * a wrapped Dense compared with ``TimeDistributedDense``, with and
        without ``batch_input_shape``;
      * Convolution2D with a JSON round trip;
      * stacked wrappers;
      * wrapping a Sequential model, both in a Sequential and through the
        functional API.
    """
    def compile_mse(net):
        net.compile(optimizer='rmsprop', loss='mse')

    def fit_one_epoch(net, inputs, targets):
        net.fit(inputs, targets, nb_epoch=1, batch_size=10)

    def reference_prediction(probe, weights, **shape_kwargs):
        # TimeDistributedDense + relu, initialised with the wrapper's weights.
        ref = Sequential()
        ref.add(core.TimeDistributedDense(2, weights=weights, **shape_kwargs))
        ref.add(core.Activation('relu'))
        compile_mse(ref)
        return ref.predict(probe)

    # --- Wrapped Dense.
    model = Sequential()
    model.add(wrappers.TimeDistributed(core.Dense(2), input_shape=(3, 4)))
    model.add(core.Activation('relu'))
    compile_mse(model)
    fit_one_epoch(model, np.random.random((10, 3, 4)),
                  np.random.random((10, 3, 2)))
    model.get_config()  # must not raise

    probe = np.random.random((1, 3, 4))
    expected = model.predict(probe)
    dense_weights = model.layers[0].get_weights()
    for shape_kwargs in ({'input_shape': (3, 4)},
                         {'batch_input_shape': (1, 3, 4)}):
        assert_allclose(expected,
                        reference_prediction(probe, dense_weights,
                                             **shape_kwargs),
                        atol=1e-05)

    # --- Convolution2D, then serialize to JSON and rebuild.
    model = Sequential()
    model.add(wrappers.TimeDistributed(
        convolutional.Convolution2D(5, 2, 2, border_mode='same'),
        input_shape=(2, 4, 4, 3)))
    model.add(core.Activation('relu'))
    compile_mse(model)
    model.train_on_batch(np.random.random((1, 2, 4, 4, 3)),
                         np.random.random((1, 2, 4, 4, 5)))
    model = model_from_json(model.to_json())
    model.summary()

    # --- Two stacked wrappers.
    model = Sequential()
    model.add(wrappers.TimeDistributed(core.Dense(2), input_shape=(3, 4)))
    model.add(wrappers.TimeDistributed(core.Dense(3)))
    model.add(core.Activation('relu'))
    compile_mse(model)
    fit_one_epoch(model, np.random.random((10, 3, 4)),
                  np.random.random((10, 3, 3)))

    # --- Wrapping a whole Sequential model.
    inner = Sequential()
    inner.add(core.Dense(3, input_dim=2))
    outer_model = Sequential()
    outer_model.add(wrappers.TimeDistributed(inner, input_shape=(3, 2)))
    compile_mse(outer_model)
    fit_one_epoch(outer_model, np.random.random((10, 3, 2)),
                  np.random.random((10, 3, 3)))

    # --- Same wrapped model through the functional API.
    x = Input(shape=(3, 2))
    y = wrappers.TimeDistributed(inner)(x)
    outer_model = Model(x, y)
    compile_mse(outer_model)
    fit_one_epoch(outer_model, np.random.random((10, 3, 2)),
                  np.random.random((10, 3, 3)))
Example #12
0
def thick_conv_lstm_readout_net(feature_map_in_seqs,
                                feature_map_size,
                                drop_rate,
                                gaze_prior=None,
                                output_embedding=False):
    """Readout: BN'd 1x1 conv stack -> ConvLSTM2D with learned initial state.

    Args:
      feature_map_in_seqs: 5-D tensor. Dims 0 and 1 are folded together for
        the per-frame convolutions, and dim 4 must be statically known (it is
        read with ``get_shape()[4]``). Presumably (batch, time, H, W, C).
      feature_map_size: (height, width) of each frame.
      drop_rate: Rate for every Dropout layer and for the ConvLSTM2D's input
        and recurrent dropout.
      gaze_prior: Optional array of prior probabilities with height * width
        elements. When given, ``log(max(prior, EPSILON))`` is added to every
        row of the logits.
      output_embedding: If True, also return the per-frame embedding.

    Returns:
      * ``(logits, embed, raw_logits)`` if ``output_embedding``;
      * otherwise ``logits`` when ``gaze_prior`` is None;
      * otherwise ``(logits, pre_prior_logits)``.
      ``logits`` has shape ``[-1, height * width]``. ``raw_logits`` and
      ``pre_prior_logits`` are the logits before the prior is added.
    """
    height, width = feature_map_size[0], feature_map_size[1]
    n_pixels = height * width
    batch_size = tf.shape(feature_map_in_seqs)[0]
    n_step = tf.shape(feature_map_in_seqs)[1]
    in_channels = int(feature_map_in_seqs.get_shape()[4])
    # Fold time into the batch axis so plain Conv2D layers act per frame.
    feature_map = tf.reshape(feature_map_in_seqs,
                             [batch_size * n_step, height, width, in_channels])

    x = layers.Conv2D(16, (1, 1), activation='relu',
                      name='readout_conv1')(feature_map)
    x = layers.BatchNormalization()(x)
    x = layers.core.Dropout(drop_rate)(x)
    x = layers.Conv2D(32, (1, 1), activation='relu', name='readout_conv2')(x)
    x = layers.BatchNormalization()(x)
    x = layers.core.Dropout(drop_rate)(x)
    x = layers.Conv2D(8, (1, 1), activation='relu', name='readout_conv3')(x)
    x = layers.BatchNormalization()(x)

    # Unfold back into a (batch, time, H', W', C') sequence for the ConvLSTM.
    frame_shape = [int(s) for s in x.get_shape()[1:4]]
    x = tf.reshape(x, [batch_size, n_step] + frame_shape)

    # Channel count of the ConvLSTM state. This used to reuse the name
    # ``n_channel``, which also held the input channel count.
    lstm_channels = 15

    def _initial_state(frames):
        # Derive an initial ConvLSTM state from the first frame. Layers are
        # created in the same order as before (Conv2D, input Dropout, output
        # Dropout), so auto-generated layer names stay unchanged.
        conv = layers.Conv2D(lstm_channels, (3, 3),
                             activation='tanh',
                             padding='same')
        state = conv(layers.core.Dropout(drop_rate)(frames[:, 0]))
        return layers.core.Dropout(drop_rate)(state)

    initial_c = _initial_state(x)
    initial_h = _initial_state(x)

    conv_lstm = layers.ConvLSTM2D(filters=lstm_channels,
                                  kernel_size=(3, 3),
                                  strides=(1, 1),
                                  padding='same',
                                  dropout=drop_rate,
                                  recurrent_dropout=drop_rate,
                                  return_sequences=True)
    x = conv_lstm([x, initial_c, initial_h])
    x = wps.TimeDistributed(
        layers.Conv2D(lstm_channels, (1, 1), activation='linear'))(x)
    x = tf.reshape(x, [batch_size * n_step, height, width, lstm_channels])
    x = layers.BatchNormalization()(x)
    embed = x

    # A 1x1 conv keeps (batch * time, H, W) and outputs one channel, so the
    # result can be flattened directly. The former reshape to that same
    # shape was a no-op.
    x = layers.Conv2D(1, (1, 1), activation='linear')(x)
    logits = tf.reshape(x, [-1, n_pixels])
    # Logits before any prior. Previously an identical second reshape.
    raw_logits = logits

    if gaze_prior is not None:
        # Keep the prediction before the prior is added.
        pre_prior_logits = logits
        floored_prior = np.maximum(gaze_prior,
                                   EPSILON * np.ones(gaze_prior.shape))
        floored_prior = floored_prior.astype(np.float32)
        log_prior_row = tf.constant(
            np.reshape(np.log(floored_prior), (1, -1)))
        # ones(n_rows, 1) @ row tiles the log-prior row once per logits row.
        log_prior_tensor = tf.matmul(
            tf.ones((tf.shape(pre_prior_logits)[0], 1)), log_prior_row)
        log_prior_tensor = tf.reshape(log_prior_tensor, [-1, n_pixels])
        logits = tf.add(pre_prior_logits, log_prior_tensor)

    if output_embedding:
        return logits, embed, raw_logits
    if gaze_prior is None:
        return logits
    return logits, pre_prior_logits
Example #13
0
def conv_lstm_planner(peripheral_feature_map_seqs, foveal_feature_seqs,
                      drop_rate, n_filters=5, n_units=512):
    """ConvLSTM planner over peripheral and/or foveal feature-map sequences.

    Args:
      peripheral_feature_map_seqs: 5-D tensor or None. Concatenated before
        ``foveal_feature_seqs`` along the last (channel) axis. Presumably
        (batch, time, H, W, C).
      foveal_feature_seqs: 5-D tensor or None, same layout.
      drop_rate: Rate for the Dropout layers and for the ConvLSTM2D's input
        and recurrent dropout.
      n_filters: ConvLSTM2D filter count (defaults to the previous
        hard-coded 5).
      n_units: Width of the per-frame Dense output (defaults to the previous
        hard-coded 512).

    Returns:
      ``(logits, peripheral_weights, foveal_weights)``:
        * ``logits`` has shape ``[batch * time, n_units]``;
        * the weights are the slices of the ConvLSTM2D kernel that act on
          each input's channels, or None for an input that was not given.

    Raises:
      ValueError: if both feature inputs are None.
    """
    if peripheral_feature_map_seqs is None and foveal_feature_seqs is None:
        raise ValueError('At least one of peripheral_feature_map_seqs and '
                         'foveal_feature_seqs must be provided.')

    # Combine feature maps (peripheral channels first, then foveal).
    if foveal_feature_seqs is None:
        feature_map_seqs = peripheral_feature_map_seqs
    elif peripheral_feature_map_seqs is None:
        feature_map_seqs = foveal_feature_seqs
    else:
        feature_map_seqs = tf.concat(
            [peripheral_feature_map_seqs, foveal_feature_seqs], axis=-1)

    batch_size = tf.shape(feature_map_seqs)[0]
    n_step = tf.shape(feature_map_seqs)[1]
    height, width = [int(s) for s in feature_map_seqs.get_shape()[2:4]]

    # Created before the initial-state layers, as before, so auto-generated
    # layer names are unchanged.
    conv_lstm = layers.ConvLSTM2D(filters=n_filters,
                                  kernel_size=(3, 3),
                                  strides=(1, 1),
                                  padding='same',
                                  dropout=drop_rate,
                                  recurrent_dropout=drop_rate,
                                  return_sequences=True)

    def _initial_state(frames):
        # Initial ConvLSTM state from the first frame. Layer creation order
        # (Conv2D, input Dropout, output Dropout) matches the original.
        conv = layers.Conv2D(n_filters, (3, 3), activation='tanh',
                             padding='same')
        state = conv(layers.core.Dropout(drop_rate)(frames[:, 0]))
        return layers.core.Dropout(drop_rate)(state)

    initial_c = _initial_state(feature_map_seqs)
    initial_h = _initial_state(feature_map_seqs)
    x = conv_lstm([feature_map_seqs, initial_c, initial_h])

    # Track the kernel slices that act on each input. The kernel's third axis
    # is the input-channel axis, ordered as in the concat above.
    kernel_weights = conv_lstm.weights[0]
    if peripheral_feature_map_seqs is None:
        peripheral_weights = None
        peripheral_n_channels = 0
    else:
        # Use the static channel count rather than tf.shape(...)[-1], so the
        # sliced weights keep a fully static shape.
        peripheral_n_channels = int(
            peripheral_feature_map_seqs.get_shape()[-1])
        peripheral_weights = kernel_weights[:, :, :peripheral_n_channels, :]

    if foveal_feature_seqs is None:
        foveal_weights = None
    else:
        foveal_weights = kernel_weights[:, :, peripheral_n_channels:, :]

    x = tf.reshape(x, [batch_size * n_step, height, width, n_filters])
    x = layers.BatchNormalization()(x)

    # Flatten each frame into one vector per timestep for the Dense head.
    x = tf.reshape(x, [batch_size, n_step, height * width * n_filters])
    x = wps.TimeDistributed(layers.Dense(units=n_units,
                                         activation='linear'))(x)

    logits = tf.reshape(x, [batch_size * n_step, n_units])

    return logits, peripheral_weights, foveal_weights
class LayerCorrectnessTest(keras_parameterized.TestCase):
    """Compares float32 layers with mixed-precision and distributed copies."""

    def setUp(self):
        super(LayerCorrectnessTest, self).setUp()
        # Split the first physical CPU into two logical devices so that
        # MirroredStrategy is exercised with multiple devices.
        cpus = tf.config.list_physical_devices('CPU')
        tf.config.set_logical_device_configuration(cpus[0], [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ])

    def _create_model_from_layer(self, layer, input_shapes):
        """Wraps ``layer`` in a compiled functional model.

        Each entry of ``input_shapes`` becomes an ``Input`` with that full
        batch shape. A single shape yields a single (non-list) input.
        """
        inputs = [layers.Input(batch_input_shape=s) for s in input_shapes]
        if len(inputs) == 1:
            inputs = inputs[0]
        y = layer(inputs)
        model = models.Model(inputs, y)
        model.compile('sgd', 'mse')
        return model

    @parameterized.named_parameters(
        ('LeakyReLU', advanced_activations.LeakyReLU, (2, 2)),
        ('PReLU', advanced_activations.PReLU, (2, 2)),
        ('ELU', advanced_activations.ELU, (2, 2)),
        ('ThresholdedReLU', advanced_activations.ThresholdedReLU, (2, 2)),
        ('Softmax', advanced_activations.Softmax, (2, 2)),
        ('ReLU', advanced_activations.ReLU, (2, 2)),
        ('Conv1D', lambda: convolutional.Conv1D(2, 2), (2, 2, 1)),
        ('Conv2D', lambda: convolutional.Conv2D(2, 2), (2, 2, 2, 1)),
        ('Conv3D', lambda: convolutional.Conv3D(2, 2), (2, 2, 2, 2, 1)),
        ('Conv2DTranspose', lambda: convolutional.Conv2DTranspose(2, 2),
         (2, 2, 2, 2)),
        ('SeparableConv2D', lambda: convolutional.SeparableConv2D(2, 2),
         (2, 2, 2, 1)),
        ('DepthwiseConv2D', lambda: convolutional.DepthwiseConv2D(2, 2),
         (2, 2, 2, 1)),
        ('UpSampling2D', convolutional.UpSampling2D, (2, 2, 2, 1)),
        ('ZeroPadding2D', convolutional.ZeroPadding2D, (2, 2, 2, 1)),
        ('Cropping2D', convolutional.Cropping2D, (2, 3, 3, 1)),
        ('ConvLSTM2D',
         lambda: convolutional_recurrent.ConvLSTM2D(4, kernel_size=(2, 2)),
         (4, 4, 4, 4, 4)),
        ('Dense', lambda: core.Dense(2), (2, 2)),
        ('Dropout', lambda: core.Dropout(0.5), (2, 2)),
        ('SpatialDropout2D', lambda: core.SpatialDropout2D(0.5), (2, 2, 2, 2)),
        ('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
        ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
        ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
        ('Attention', dense_attention.Attention, [(2, 2, 3), (2, 3, 3),
                                                  (2, 3, 3)]),
        ('AdditiveAttention', dense_attention.AdditiveAttention, [(2, 2, 3),
                                                                  (2, 3, 3),
                                                                  (2, 3, 3)]),
        ('Embedding', lambda: embeddings.Embedding(4, 4),
         (2, 4), 2e-3, 2e-3, np.random.randint(4, size=(2, 4))),
        ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2),
         (2, 2, 1)),
        ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
         (2, 2, 2, 1)),
        ('Add', merge.Add, [(2, 2), (2, 2)]),
        ('Subtract', merge.Subtract, [(2, 2), (2, 2)]),
        ('Multiply', merge.Multiply, [(2, 2), (2, 2)]),
        ('Average', merge.Average, [(2, 2), (2, 2)]),
        ('Maximum', merge.Maximum, [(2, 2), (2, 2)]),
        ('Minimum', merge.Minimum, [(2, 2), (2, 2)]),
        ('Concatenate', merge.Concatenate, [(2, 2), (2, 2)]),
        ('Dot', lambda: merge.Dot(1), [(2, 2), (2, 2)]),
        ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
        ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
        ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
        ('BatchNormalization', normalization_v2.BatchNormalization,
         (2, 2), 1e-2, 1e-2),
        ('LayerNormalization', normalization.LayerNormalization, (2, 2)),
        ('LayerNormalizationUnfused',
         lambda: normalization.LayerNormalization(axis=1), (2, 2, 2)),
        ('MaxPooling2D', pooling.MaxPooling2D, (2, 2, 2, 1)),
        ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
        ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
        ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D,
         (2, 2, 2, 1)),
        ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4),
         (4, 4, 4), 1e-2, 1e-2),
        ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
        ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
        ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
        ('LSTMV2', lambda: recurrent_v2.LSTM(units=4), (4, 4, 4)),
        ('TimeDistributed', lambda: wrappers.TimeDistributed(core.Dense(2)),
         (2, 2, 2)),
        ('Bidirectional',
         lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)),
         (2, 2, 2)),
        ('AttentionLayerCausal',
         lambda: dense_attention.Attention(causal=True), [(2, 2, 3), (2, 3, 3),
                                                          (2, 3, 3)]),
        ('AdditiveAttentionLayerCausal',
         lambda: dense_attention.AdditiveAttention(causal=True), [(2, 3, 4),
                                                                  (2, 3, 4),
                                                                  (2, 3, 4)]),
    )
    def test_layer(self,
                   f32_layer_fn,
                   input_shape,
                   rtol=2e-3,
                   atol=2e-3,
                   input_data=None):
        """Tests a layer by comparing the float32 and mixed precision weights.

        Three layers are run: a float32 layer, a mixed precision layer, and a
        distributed mixed precision layer. They are identical except for their
        dtypes and distribution strategies. The test asserts that their
        outputs after predict() and their weights after fit() are close.

        Args:
          f32_layer_fn: A function returning a float32 layer. The other two
            layers are created automatically from it.
          input_shape: The shape of the input to the layer, including the
            batch dimension, or a list of shapes if the layer takes multiple
            inputs.
          rtol: The relative tolerance to be asserted.
          atol: The absolute tolerance to be asserted.
          input_data: The input data (a Numpy array). If None, input data is
            generated randomly.
        """
        if (f32_layer_fn == convolutional.ZeroPadding2D and
                tf.test.is_built_with_rocm()):
            # Report this combination as skipped instead of silently passing.
            self.skipTest('ZeroPadding2D is not tested on ROCm builds.')
        if isinstance(input_shape[0], int):
            input_shapes = [input_shape]
        else:
            input_shapes = input_shape
        strategy = create_mirrored_strategy()
        f32_layer = f32_layer_fn()

        # Create the layers
        assert f32_layer.dtype == f32_layer._compute_dtype == 'float32'
        config = f32_layer.get_config()
        config['dtype'] = policy.Policy('mixed_float16')
        mp_layer = f32_layer.__class__.from_config(config)
        distributed_mp_layer = f32_layer.__class__.from_config(config)

        # Compute per_replica_input_shapes for the distributed model
        global_batch_size = input_shapes[0][0]
        assert global_batch_size % strategy.num_replicas_in_sync == 0, (
            'The number of replicas, %d, does not divide the global batch size of '
            '%d' % (strategy.num_replicas_in_sync, global_batch_size))
        per_replica_batch_size = (global_batch_size //
                                  strategy.num_replicas_in_sync)
        per_replica_input_shapes = [(per_replica_batch_size, ) + s[1:]
                                    for s in input_shapes]

        # Create the models
        f32_model = self._create_model_from_layer(f32_layer, input_shapes)
        mp_model = self._create_model_from_layer(mp_layer, input_shapes)
        with strategy.scope():
            distributed_mp_model = self._create_model_from_layer(
                distributed_mp_layer, per_replica_input_shapes)

        # Set all model weights to the same values
        f32_weights = f32_model.get_weights()
        mp_model.set_weights(f32_weights)
        distributed_mp_model.set_weights(f32_weights)

        # Generate input data
        if input_data is None:
            # Cast inputs to float16 so the error introduced by f16 layers
            # casting their inputs to float16 is not measured.
            input_data = [
                np.random.normal(size=s).astype('float16')
                for s in input_shapes
            ]
            if len(input_data) == 1:
                input_data = input_data[0]

        # Assert all models have close outputs.
        f32_output = f32_model.predict(input_data)
        mp_output = mp_model.predict(input_data)
        self.assertAllClose(mp_output, f32_output, rtol=rtol, atol=atol)
        self.assertAllClose(distributed_mp_model.predict(input_data),
                            f32_output,
                            rtol=rtol,
                            atol=atol)

        # Run fit() on models
        output = np.random.normal(
            size=f32_model.outputs[0].shape).astype('float16')
        for model in f32_model, mp_model, distributed_mp_model:
            model.fit(input_data, output, batch_size=global_batch_size)

        # Assert all models have close weights
        f32_weights = f32_model.get_weights()
        self.assertAllClose(mp_model.get_weights(),
                            f32_weights,
                            rtol=rtol,
                            atol=atol)
        self.assertAllClose(distributed_mp_model.get_weights(),
                            f32_weights,
                            rtol=rtol,
                            atol=atol)
# Latent sample built from (z_mean, z_log_var) by the `sampling` helper
# defined earlier in the script (not visible here; presumably the VAE
# reparameterization trick -- confirm).
z = Lambda(sampling, output_shape=(latent_dim, ),
           name='z')([z_mean, z_log_var])

# instantiate encoder model
# Outputs are [z_mean, z_log_var, z, h, c]. Here h and c are tensors bound
# earlier in the script (presumably the encoder LSTM's final states).
encoder = Model(inputs, [z_mean, z_log_var, z, h, c], name='encoder')
encoder.summary()

# build decoder model
# The latent vector is repeated maxlen times, once per output timestep.
latent_inputs = Input(shape=(latent_dim, ), name='z')
latent_repeat = RepeatVector(maxlen)(latent_inputs)
# NOTE: h and c are rebound here to the decoder's state *inputs*. The encoder
# Model above was already built with the previous tensors, so it is
# unaffected.
h = Input(shape=(intermediate_dim, ), name='encoder_state_h')
c = Input(shape=(intermediate_dim, ), name='encoder_state_c')
# The first decoder LSTM starts from the supplied [h, c]; its returned states
# are discarded, as are those of the second LSTM.
x, _, _ = LSTM(intermediate_dim, return_sequences=True,
               return_state=True)(latent_repeat, initial_state=[h, c])
x, _, _ = LSTM(embed_dim, return_sequences=True, return_state=True)(x)
outputs = wrappers.TimeDistributed(Dense(embed_dim))(x)

# instantiate decoder model
decoder = Model([latent_inputs, h, c], outputs, name='decoder')
decoder.summary()

# instantiate VRAE model
# encoder(inputs)[2:] is [z, h, c], matching the decoder's inputs
# [latent_inputs, h, c].
outputs = decoder(encoder(inputs)[2:])
vrae = Model(inputs, outputs, name='vrae')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load h5 model trained weights"
    parser.add_argument("-w", "--weights", help=help_)

    args = parser.parse_args()