def testTensorflowjsToKerasConversionSucceeds(self):
    """Round-trip a Keras H5 model through tfjs artifacts and back to H5.

    Saves a small Sequential model as H5 and converts it to tfjs layers-model
    artifacts. It then converts the resulting model.json back into a new H5
    file and checks that the reloaded model's JSON matches the original's.
    """
    # Use the same TF1-compatible session API and converter entry point as
    # the other tests in this file. `tf.Session` is absent in TF2.
    with tf.Graph().as_default(), tf.compat.v1.Session():
        sequential_model = keras.models.Sequential([
            keras.layers.Dense(
                3, input_shape=(2,), use_bias=True,
                kernel_initializer='ones', name='Dense1'),
            keras.layers.Dense(
                1, use_bias=False, kernel_initializer='ones', name='Dense2')])
        h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
        sequential_model.save(h5_path)
        converter.dispatch_keras_h5_to_tfjs_layers_model_conversion(
            h5_path, output_dir=self._tmp_dir)
        old_model_json = sequential_model.to_json()

    # Convert the tensorflowjs artifacts to a new H5 file.
    new_h5_path = os.path.join(self._tmp_dir, 'new.h5')
    converter.dispatch_tensorflowjs_to_keras_h5_conversion(
        os.path.join(self._tmp_dir, 'model.json'), new_h5_path)

    # Load the new H5 in a fresh graph and compare the model JSONs.
    with tf.Graph().as_default(), tf.compat.v1.Session():
        new_model = keras.models.load_model(new_h5_path)
        self.assertEqual(old_model_json, new_model.to_json())
def testTensorflowjsToKerasConversionFailsOnInvalidJsonFile(self):
    """A model.json input whose content is not valid JSON raises ValueError."""
    invalid_json_path = os.path.join(self._tmp_dir, 'fake.json')
    with open(invalid_json_path, 'wt') as json_file:
        json_file.write('__invalid_json_content__')

    output_h5_path = os.path.join(self._tmp_dir, 'model.h5')
    with self.assertRaisesRegexp(  # pylint: disable=deprecated-method
            ValueError, r'cannot read valid JSON content from'):
        converter.dispatch_tensorflowjs_to_keras_h5_conversion(
            invalid_json_path, output_h5_path)
def testTensorflowjsToKerasConversionFailsOnExistingDirOutputPath(self):
    """Using an existing directory as the H5 output path raises ValueError."""
    with tf.Graph().as_default(), tf.compat.v1.Session():
        dense_layers = [
            keras.layers.Dense(
                3, input_shape=(2,), use_bias=True,
                kernel_initializer='ones', name='Dense1'),
            keras.layers.Dense(
                1, use_bias=False, kernel_initializer='ones', name='Dense2'),
        ]
        model = keras.models.Sequential(dense_layers)
        keras_h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
        model.save(keras_h5_path)
        # Write the tfjs artifacts (including model.json) into the temp dir.
        converter.dispatch_keras_h5_to_tfjs_layers_model_conversion(
            keras_h5_path, output_dir=self._tmp_dir)

    # The temp dir itself is passed as the output path; it is a directory,
    # so the converter is expected to reject it.
    with self.assertRaisesRegexp(  # pylint: disable=deprecated-method
            ValueError, r'but received an existing directory'):
        model_json_path = os.path.join(self._tmp_dir, 'model.json')
        converter.dispatch_tensorflowjs_to_keras_h5_conversion(
            model_json_path, self._tmp_dir)
def testTensorflowjsToKerasConversionFailsOnDirInputPath(self):
    """Passing a directory instead of a model.json file raises ValueError."""
    output_h5_path = os.path.join(self._tmp_dir, 'new.h5')
    with self.assertRaisesRegexp(  # pylint: disable=deprecated-method
            ValueError, r'input path should be a model\.json file'):
        converter.dispatch_tensorflowjs_to_keras_h5_conversion(
            self._tmp_dir, output_h5_path)