예제 #1
0
def test_sequential_get_mxnet_model_info_no_compile():
    """get_mxnet_model_info must reject a Sequential model that was never compiled.

    The model is built from the same layer stack as the compiled Sequential
    test below, but ``compile()`` is deliberately never called, and the
    backend helper is expected to raise ``AssertionError``.
    """
    uncompiled = Sequential()
    # Layers are constructed left to right (Dense, RepeatVector, inner Dense,
    # TimeDistributed), the same construction order as one add() per line.
    for layer in (Dense(2, input_shape=(3, )),
                  RepeatVector(3),
                  TimeDistributed(Dense(3))):
        uncompiled.add(layer)

    with pytest.raises(AssertionError):
        K.get_mxnet_model_info(uncompiled)
예제 #2
0
def test_functional_model_get_mxnet_model_info():
    """get_mxnet_model_info on a compiled and trained functional ``Model``.

    The model has a single ``Input`` of shape ``(3,)``. After one
    ``train_on_batch`` step with a batch of shape ``(1, 3)``, the helper is
    expected to report exactly one data name and one data shape. The shape
    must equal the batch shape that was fed in.
    """
    inputs = Input(shape=(3, ))
    # Named `hidden` (not `x`) so the symbolic tensor is not confused with the
    # numpy training batch created below.
    hidden = Dense(2)(inputs)
    outputs = Dense(3)(hidden)

    model = Model(inputs, outputs)
    model.compile(loss=losses.MSE,
                  optimizer=optimizers.Adam(),
                  metrics=[metrics.categorical_accuracy])
    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)

    data_names, data_shapes = K.get_mxnet_model_info(model)

    # Only one input.
    assert len(data_names) == 1
    # For a functional model the data name presumably comes from the Input
    # layer (unlike Sequential, where it starts with '/dense_'). It begins
    # with '/input_'; the numeric suffix depends on layer numbering.
    assert data_names[0].startswith('/input_')

    # Expect one DataDesc whose name matches the data name above, e.g.
    # [DataDesc[/input_...,(1, 3),float32,NCHW]] (suffix varies).
    assert len(data_shapes) == 1
    assert data_shapes[0].name == data_names[0]
    # The shape is that of the batch passed to train_on_batch: (1, 3).
    assert data_shapes[0].shape == (1, 3)
예제 #3
0
def test_sequential_get_mxnet_model_info():
    """get_mxnet_model_info on a trained Sequential model, cross-checked against save_mxnet_model.

    The model is compiled with temporal sample weighting and trained for one
    batch. The data names and shapes from ``K.get_mxnet_model_info`` must
    describe a single ``(1, 3)`` input, and must match what
    ``save_mxnet_model`` returns for the same model.

    ``save_mxnet_model(prefix='test', epoch=0)`` writes ``test-symbol.json``
    and ``test-0000.params`` into the current directory. They are always
    removed afterwards, even when an assertion fails, so stale files never
    leak into later runs.
    """
    model = Sequential()
    model.add(Dense(2, input_shape=(3, )))
    model.add(RepeatVector(3))
    model.add(TimeDistributed(Dense(3)))
    model.compile(loss=losses.MSE,
                  optimizer=optimizers.RMSprop(lr=0.0001),
                  metrics=[metrics.categorical_accuracy],
                  sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    # RepeatVector(3) + TimeDistributed(Dense(3)) yields (batch, 3, 3) outputs.
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)

    data_names, data_shapes = K.get_mxnet_model_info(model)

    symbol_file = 'test-symbol.json'
    params_file = 'test-0000.params'
    try:
        data_names_saved_model, data_shapes_saved_model = save_mxnet_model(
            model, prefix='test', epoch=0)

        # The original test relied on os.remove() failing if these were
        # missing. Keep that check explicit now that cleanup tolerates absence.
        assert os.path.isfile(symbol_file)
        assert os.path.isfile(params_file)

        # Only one input.
        assert len(data_names) == 1
        # Example data_names = ['/dense_8_input1']
        assert data_names[0].startswith('/dense_')

        # Example data_shape = [DataDesc[/dense_8_input1,(1, 3),float32,NCHW]]
        assert len(data_shapes) == 1
        assert data_shapes[0].name == data_names[0]
        # In this example, x is passed as input with shape (1, 3).
        assert data_shapes[0].shape == (1, 3)

        # The Save MXNet Model API must report the same names and shapes.
        assert len(data_names) == len(data_names_saved_model)
        assert data_names[0] == data_names_saved_model[0]

        assert len(data_shapes) == len(data_shapes_saved_model)
        assert data_shapes[0].name == data_shapes_saved_model[0].name
        assert data_shapes[0].shape == data_shapes_saved_model[0].shape
    finally:
        # Clean up checkpoint files even if an assertion above failed. Guard
        # the removal so a missing file does not mask the original error.
        for path in (symbol_file, params_file):
            if os.path.exists(path):
                os.remove(path)