def test_sequential_get_mxnet_model_info_no_compile():
    """Querying MXNet model info on an uncompiled Sequential model must fail.

    The model is built (Dense -> RepeatVector -> TimeDistributed(Dense)) but
    `compile` is never called; `K.get_mxnet_model_info` is expected to raise
    AssertionError in that state.
    """
    uncompiled = Sequential()
    # Layers are constructed in the same order as they are added, so the
    # auto-generated layer names match a one-add-per-line construction.
    layer_stack = (
        Dense(2, input_shape=(3, )),
        RepeatVector(3),
        TimeDistributed(Dense(3)),
    )
    for layer in layer_stack:
        uncompiled.add(layer)

    with pytest.raises(AssertionError):
        K.get_mxnet_model_info(uncompiled)
def test_functional_model_get_mxnet_model_info():
    """`K.get_mxnet_model_info` reports the single input of a functional model.

    A two-layer functional model is compiled and trained for one batch.
    The test then checks three things:
      * exactly one data name is returned, and it starts with '/input_';
      * exactly one data descriptor is returned, with the same name;
      * the descriptor's shape equals the shape of the batch fed to training.
    """
    inputs = Input(shape=(3, ))
    # Keep the intermediate tensor under its own name. The original reused
    # `x` here and then rebound it to the numpy batch below.
    hidden = Dense(2)(inputs)
    outputs = Dense(3)(hidden)
    model = Model(inputs, outputs)
    model.compile(loss=losses.MSE,
                  optimizer=optimizers.Adam(),
                  metrics=[metrics.categorical_accuracy])

    # One training step on a batch of a single sample.
    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)

    data_names, data_shapes = K.get_mxnet_model_info(model)

    # Only one input.
    assert len(data_names) == 1
    # With the functional API the input comes from the explicit Input layer.
    # The reported name is therefore expected to start with '/input_'
    # (the numeric suffix depends on how many layers were created before).
    assert data_names[0].startswith('/input_')
    # Exactly one data descriptor, named after that same input.
    assert len(data_shapes) == 1
    assert data_shapes[0].name == data_names[0]
    # The descriptor's shape matches the batch passed to train_on_batch: (1, 3).
    assert data_shapes[0].shape == (1, 3)
def test_sequential_get_mxnet_model_info():
    """Model info from the backend agrees with what `save_mxnet_model` returns.

    A Sequential model (Dense -> RepeatVector -> TimeDistributed(Dense)) is
    compiled with temporal sample weighting and trained on one batch. The data
    names/shapes from `K.get_mxnet_model_info` are checked, then compared with
    the values `save_mxnet_model(prefix='test', epoch=0)` returns.

    The files written by `save_mxnet_model` are removed in a `finally` clause.
    The original test deleted them only after every assertion passed, so any
    failure left them behind in the working directory.
    """
    model = Sequential()
    model.add(Dense(2, input_shape=(3, )))
    model.add(RepeatVector(3))
    model.add(TimeDistributed(Dense(3)))
    model.compile(loss=losses.MSE,
                  optimizer=optimizers.RMSprop(lr=0.0001),
                  metrics=[metrics.categorical_accuracy],
                  sample_weight_mode='temporal')

    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)

    data_names, data_shapes = K.get_mxnet_model_info(model)

    # Files the test expects save_mxnet_model to produce for prefix='test',
    # epoch=0. These are the same paths the original test removed.
    saved_files = ('test-symbol.json', 'test-0000.params')
    try:
        data_names_saved_model, data_shapes_saved_model = save_mxnet_model(
            model, prefix='test', epoch=0)

        # The original relied on os.remove raising if a file was missing;
        # make that expectation an explicit assertion instead.
        for path in saved_files:
            assert os.path.isfile(path)

        # Only one input.
        assert len(data_names) == 1
        # Example data_names = ['/dense_8_input1']: Sequential input names
        # are derived from the first Dense layer.
        assert data_names[0].startswith('/dense_')
        # Example data_shape = [DataDesc[/dense_8_input1,(1, 3),float32,NCHW]]
        assert len(data_shapes) == 1
        assert data_shapes[0].name == data_names[0]
        # The descriptor's shape matches the batch passed to train_on_batch: (1, 3).
        assert data_shapes[0].shape == (1, 3)

        # Compare with the values returned by the Save MXNet Model API;
        # they should be the same.
        assert len(data_names) == len(data_names_saved_model)
        assert data_names[0] == data_names_saved_model[0]
        assert len(data_shapes) == len(data_shapes_saved_model)
        assert data_shapes[0].name == data_shapes_saved_model[0].name
        assert data_shapes[0].shape == data_shapes_saved_model[0].shape
    finally:
        # Always clean up, even if an assertion failed or the save raised
        # after writing only some of the files.
        for path in saved_files:
            if os.path.exists(path):
                os.remove(path)