Example #1
 def testGetParamsOverrideWithInitCheckpointPath(self):
     # Without override, default value is None.
     cfg = model_registry.GetParams('test.DummyModel', 'Train')
     self.assertIsNone(cfg.task.train.init_from_checkpoint_override)
     # Override ckpt path from empty to flag.
     FLAGS.model_params_override = (
         'task.train.init_from_checkpoint_override:/new/ckpt/path')
     cfg1 = model_registry.GetParams('test.DummyModel', 'Train')
     self.assertEqual(cfg1.task.train.init_from_checkpoint_override,
                      '/new/ckpt/path')
     # Unset checkpoint path.
     FLAGS.model_params_override = (
         'task.train.init_from_checkpoint_override:')
     cfg2 = model_registry.GetParams('test.DummyModelWithInitRules',
                                     'Train')
     self.assertEqual(cfg2.task.train.init_from_checkpoint_override, '')
Example #2
  def testGetParamsCanOverrideWithFlags(self):
    cfg = model_registry.GetParams('test.DummyModel', 'Train')

    FLAGS.model_params_override = (
        'train.max_steps: 10;  train.ema_decay: 0.9\n'
        'train.init_from_checkpoint_rules : {"ckpt": (["abc", "def"], [])}\n')
    cfg2 = model_registry.GetParams('test.DummyModel', 'Train')

    self.assertNotEqual(cfg.train.max_steps, 10)
    self.assertEqual(cfg2.train.max_steps, 10)
    self.assertNotEqual(cfg.train.ema_decay, 0.9)
    self.assertEqual(cfg2.train.ema_decay, 0.9)
    self.assertNotEqual(cfg.train.init_from_checkpoint_rules,
                        {'ckpt': (['abc', 'def'], [])})
    self.assertEqual(cfg2.train.init_from_checkpoint_rules,
                     {'ckpt': (['abc', 'def'], [])})
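
The two tests above drive the `model_params_override` flag: its value is a list of `path.to.param: value` pairs separated by semicolons or newlines, applied on top of the registered config when `model_registry.GetParams` is called. A minimal standalone sketch of the same mechanism, assuming lingvo defines the flag when `model_registry` is imported, that `test.DummyModel` is registered as in the tests, and that flags have already been parsed (as they are inside a test):

from absl import flags
from lingvo import model_registry

FLAGS = flags.FLAGS

# Two overrides in one flag value, separated by ';'.
FLAGS.model_params_override = 'train.max_steps: 10; train.ema_decay: 0.9'
cfg = model_registry.GetParams('test.DummyModel', 'Train')
assert cfg.train.max_steps == 10
assert cfg.train.ema_decay == 0.9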
Example #3
 def testExport(self):
     """Test basic export."""
     params = model_registry.GetParams('test.LinearModelParams', 'Test')
     inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
         params, subgraph_filter=['default'])
     self.assertIn('default', inference_graph.subgraphs)
     self.assertEqual(1, len(inference_graph.asset_file_def))
Example #4
    def testEagerEMACheckpointCompatibility(self):
        self.assertTrue(tf.executing_eagerly())
        cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
        # Use non-zero learning rates so that the weights are updated.
        cfg.task.train.learner[0].learning_rate = 0.1
        cfg.task.train.learner[1].learning_rate = 0.1

        eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
        eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
        mdl = cfg.Instantiate()

        @tf.function
        def _Update():
            with py_utils.GradientTape(persistent=True):
                mdl.ConstructFPropBPropGraph()

        # Step 1
        _Update()
        # Save V1 checkpoints at step 1.
        ckpt_v1 = checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl)
        ckpt_v1.Save(gsteps=1)

        ema = mdl.ema
        model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
        model_to_ema_map_snapshot_step1 = {
            k: v.value()
            for k, v in model_to_ema_map.items()
        }

        # Step 2
        _Update()
        # Save V2 checkpoints at step 2.
        ckpt_v2 = checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl)
        ckpt_v2.Save(gsteps=2)

        model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
        model_to_ema_map_snapshot_step2 = {
            k: v.value()
            for k, v in model_to_ema_map.items()
        }

        with cluster_factory.SetEval(True):
            # Restores variables to values saved in `eager_v1_logdir`
            ckpt_v1.Restore()
        # Verify that the EMA variables from V1 checkpoints at step 1 successfully
        # overwrite the model variables.
        for v in mdl.variables:
            if v.ref() in model_to_ema_map_snapshot_step1:
                self.assertAllEqual(v,
                                    model_to_ema_map_snapshot_step1[v.ref()])

        with cluster_factory.SetEval(True):
            # Restores variables to values saved in `eager_v2_logdir`
            ckpt_v2.Restore()
        # Verify that the EMA variables from V2 checkpoints at step 2 successfully
        # overwrite the model variables.
        for v in mdl.variables:
            if v.ref() in model_to_ema_map_snapshot_step2:
                self.assertAllEqual(v,
                                    model_to_ema_map_snapshot_step2[v.ref()])
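
`_GetModelEMAVariablePairs` is a helper that is not shown in this snippet. A hypothetical sketch of it, assuming `mdl.ema` is a `tf.train.ExponentialMovingAverage` and that the keys are variable refs, which is what the membership checks above require:

def _GetModelEMAVariablePairs(mdl, ema):
    """Maps each model variable's ref to its EMA shadow variable, if any."""
    pairs = {}
    for v in mdl.variables:
        shadow = ema.average(v)  # None when `v` has no EMA shadow.
        if shadow is not None:
            pairs[v.ref()] = shadow
    return pairs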
Example #5
  def testEagerMultiLearnerCheckpointCompatibility(self):
    self.assertTrue(tf.executing_eagerly())
    cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
    mdl = cfg.Instantiate()
    with py_utils.GradientTape(persistent=True):
      mdl.ConstructFPropBPropGraph()

    eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
    eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
    checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl).Save(gsteps=0)
    checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl).Save(gsteps=0)
    eager_v1_keys = _GetCheckpointKeys(
        os.path.join(eager_v1_logdir, 'ckpt_V1', 'ckpt-00000000'))
    eager_v2_keys = _GetCheckpointKeys(
        os.path.join(eager_v2_logdir, 'ckpt_V2', 'ckpt-0'))
    # Expecting two more variables in V2 checkpoints:
    # _CHECKPOINTABLE_OBJECT_GRAPH
    # save_counter
    self.assertEqual(len(eager_v1_keys) + 2, len(eager_v2_keys))  # pylint:disable=g-generic-assert

    py_utils.SetEagerMode(False)
    self.assertFalse(tf.executing_eagerly())
    graph_logdir = os.path.join(self.get_temp_dir(), 'graph')
    os.mkdir(graph_logdir)
    with self.session(graph=tf.Graph()) as sess:
      mdl = cfg.Instantiate()
      for lrn in mdl.GetTask().learners:
        lrn.optimizer.params.clear_variable_scope = False
      mdl.ConstructFPropBPropGraph()
      sess.run(tf.global_variables_initializer())
      checkpointer.Checkpointer(graph_logdir, mdl).Save(sess)
    graph_keys = _GetCheckpointKeys(os.path.join(graph_logdir, 'ckpt'))
    self.assertEqual(eager_v1_keys, graph_keys)
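
`_GetCheckpointKeys` is likewise an unseen helper; the comparison above only needs the set of variable keys stored in each checkpoint. One plausible sketch using the standard checkpoint reader API (the real helper may differ):

def _GetCheckpointKeys(save_path):
  # Returns the set of variable keys stored in the checkpoint at `save_path`.
  reader = tf.train.load_checkpoint(save_path)
  return set(reader.get_variable_to_shape_map().keys())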
Example #6
def get_model_params_as_text(model_path):
    try:
        cfg = model_registry.GetParams(model_path, "Train")
        return cfg.ToText()
    except LookupError:
        # Try reading as file.
        return tf.io.gfile.GFile(model_path).read()
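
A hypothetical use of the helper above: a registered model name yields the config as text, while anything else is read as a params file (both the model name and the file path below are illustrative):

# Registered model: dump its hyperparameters as text.
print(get_model_params_as_text('image.mnist.LeNet5'))

# Not a registered name: fall back to reading the file contents.
print(get_model_params_as_text('/tmp/lenet5_params.txt'))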
Example #7
    def testExportWithRandomSeeds(self):
        """Test the effect of setting random seeds on export."""
        params = model_registry.GetParams('test.LinearModelParams', 'Test')
        # Default -- use random_seed = None.
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params, subgraph_filter=['default'])
        pred = predictor.Predictor(inference_graph)
        [no_op_seed_1] = pred.Run(['output'], input=3)
        [no_op_seed_2] = pred.Run(['output'], input=3)
        self.assertNotEqual(no_op_seed_1, no_op_seed_2)
        pred = predictor.Predictor(inference_graph)
        [no_op_seed_3] = pred.Run(['output'], input=3)
        self.assertNotEqual(no_op_seed_1, no_op_seed_3)

        # Use a fixed random_seed.
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params, subgraph_filter=['default'], random_seed=1234)
        pred = predictor.Predictor(inference_graph)
        [fixed_op_seed_1] = pred.Run(['output'], input=3)
        [fixed_op_seed_2] = pred.Run(['output'], input=3)
        self.assertEqual(fixed_op_seed_1, fixed_op_seed_2)
        pred = predictor.Predictor(inference_graph)
        [fixed_op_seed_3] = pred.Run(['output'], input=3)
        self.assertEqual(fixed_op_seed_1, fixed_op_seed_3)

        # A different seed gives different results.
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params, subgraph_filter=['default'], random_seed=1235)
        pred = predictor.Predictor(inference_graph)
        [fixed_op_seed_4] = pred.Run(['output'], input=3)
        self.assertNotEqual(fixed_op_seed_1, fixed_op_seed_4)
Example #8
    def __init__(self, sess, batch_size):
        self.sess = sess
        self.batch_size = batch_size

        tf.set_random_seed(1234)
        params = model_registry.GetParams('asr.librispeech.Librispeech960Wpm',
                                          'Test')
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name='qq_input')
            self.tgt_tf = tf.placeholder(tf.string)
            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name='qq_sample_rate')
            self.maxlen = tf.placeholder(np.int32)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.input_tf, self.sample_rate_tf)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)
            self.decoded = task.Decode(self.inputs)
Example #9
    def _GetSimpleTestConfig(self):
        model_name = 'image.mnist.LeNet5'
        cfg = model_registry.GetParams(model_name, 'Train')
        cfg.cluster.task = 0
        cfg.cluster.mode = 'sync'
        cfg.cluster.job = 'trainer_client'
        cfg.cluster.worker.name = '/job:localhost'
        cfg.cluster.worker.replicas = 1
        cfg.cluster.worker.gpus_per_replica = 0
        cfg.cluster.ps.name = '/job:localhost'
        cfg.cluster.ps.replicas = 1
        cfg.cluster.ps.gpus_per_replica = 0
        cfg.cluster.reporting_job = FLAGS.vizier_reporting_job

        # Generate 2 inputs.
        cfg.input.ckpt = FakeMnistData(self.get_temp_dir(),
                                       train_size=2,
                                       test_size=2)
        cfg.input.num_samples = 2
        cfg.input.batch_size = 2
        cfg.train.max_steps = 2
        ema_decay = 0.9999
        cfg.task.train.ema_decay = ema_decay
        cfg.train.ema_decay = ema_decay
        return cfg
Example #10
 def _TestGraph(self):
     params = model_registry.GetParams('test.LinearModelParams', 'Test')
     inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
         params)
     graph = tf.Graph()
     with graph.as_default():
         tf.import_graph_def(inference_graph.graph_def, name='')
     return graph, inference_graph
Example #11
 def testDecodeRuns(self):
     g = tf.Graph()
     with g.as_default():
         p = model_registry.GetParams('test.MnistV2', 'Test')
         model = p.Instantiate()
         task = model.GetTask()
         input_batch = task.GetInputBatch()[0]
         result = task.Decode(input_batch)
         self.assertIn('correct_top1', result)
Example #12
 def testExportFreezeDefault(self):
     """Test exporting frozen graph."""
     params = model_registry.GetParams('test.LinearModelParams', 'Test')
     inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
         params, freeze_defaults=True, subgraph_filter=['default'])
     self.assertIn('default', inference_graph.subgraphs)
     # Test graphs are well-formed and importable.
     with tf.Graph().as_default():
         tf.import_graph_def(inference_graph.graph_def)
Example #13
 def testExport(self):
     """Test basic export."""
     params = model_registry.GetParams('test.LinearModelParams', 'Test')
     inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
         params, subgraph_filter=['default'], export_graph_collections=True)
     self.assertIn('default', inference_graph.subgraphs)
     self.assertEqual(1, len(inference_graph.asset_file_def))
     # Check the GLOBAL_VARIABLES graph collection which is needed for
     # eager to lift variables from a GraphDef.
     self.assertIn('variables', inference_graph.collection_def)
Example #14
    def testGetParams(self):
        cfg = model_registry.GetParams('test.DummyModel', 'Test')
        self.assertIsNotNone(cfg)
        self.assertEqual(DummyModel().Test(), cfg.input)
        cfg.input = None
        # Registered version adds model source info but direct does not.
        cfg.model = None
        self.assertEqual(DummyModel().Model(), cfg)

        cfg = model_registry.GetParams('test.DummyModel', 'Dataset')
        self.assertIsNotNone(cfg)
        self.assertEqual(DummyModel().Task_Dataset(), cfg.task)

        with self.assertRaises(LookupError):
            # Not yet registered.
            cfg = model_registry.GetParams('something.does.not.exist', 'Test')

        with self.assertRaises(base_model_params.DatasetError):
            cfg = model_registry.GetParams('test.DummyModel', 'UnknownDataset')
Example #15
 def testTpuBfloat16OverrideExport(self):
     """Test that we can export with tf.bfloat16 dtype."""
     params = model_registry.GetParams('test.LinearModelTpuParams', 'Test')
     inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
         params,
         subgraph_filter=['tpu'],
         device_options=inference_graph_exporter.InferenceDeviceOptions(
             device='tpu',
             retain_device_placement=True,
             var_options='ON_DEVICE',
             gen_init_op=True,
             dtype_override=tf.bfloat16))
     self.assertIn('tpu', inference_graph.subgraphs)
Example #16
    def testExportModelParamsWithInferenceGraph(self):
        params = model_registry.GetParams('test.DummyModelParams', 'Test')
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params)

        # Should populate subgraphs.
        self.assertIn('default', inference_graph.subgraphs)
        subgraph = inference_graph.subgraphs['default']
        self.assertIn('feed1', subgraph.feeds)
        self.assertIn('fetch1', subgraph.fetches)

        self.assertEqual(subgraph.feeds['feed1'], 'inference/feed1_node:0')
        self.assertEqual(subgraph.fetches['fetch1'], 'inference/fetch1_node:0')
        self.assertEqual(subgraph.fetches['fetch_op'], 'inference/fetch1_node')
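
The subgraph proto above only maps logical feed/fetch names to node names in the exported GraphDef. A rough sketch of consuming those names with a plain TF1-style session, assuming the `inference_graph` from this test and that the fetch needs no prior initialization; for real use, `predictor.Predictor` (see the random-seed example above) handles graph import, init, and restore:

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(inference_graph.graph_def, name='')
with tf.Session(graph=graph) as sess:
    feeds = inference_graph.subgraphs['default'].feeds      # {'feed1': 'inference/feed1_node:0'}
    fetches = inference_graph.subgraphs['default'].fetches  # {'fetch1': 'inference/fetch1_node:0', ...}
    # The feed value is purely illustrative; real dtype/shape depend on the model.
    result = sess.run(fetches['fetch1'], feed_dict={feeds['feed1']: 0.0})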
Example #17
 def testExportModelDoesNotAffectFlagsOnException(self):
   initial_flags = {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS}
   params = model_registry.GetParams('test.DummyLegacyModelParams', 'Test')
   with self.assertRaises(NotImplementedError):
     inference_graph_exporter.InferenceGraphExporter.Export(
         params,
         device_options=inference_graph_exporter.InferenceDeviceOptions(
             device='tpu',
             retain_device_placement=False,
             var_options=None,
             gen_init_op=True,
             dtype_override=None))
   self.assertDictEqual(initial_flags,
                        {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS})
Example #18
 def testMnistLeNet5(self):
   g = tf.Graph()
   with g.as_default():
     tf.set_random_seed(1618)
     p = model_registry.GetParams('image.mnist.LeNet5', 'Test')
     p.random_seed = 73234288
     p.input.ckpt = self.data_path
     p.task.params_init = py_utils.WeightInit.Uniform(0.1, seed=73234288)
     with cluster_factory.ForTestingWorker(mode='sync', job='trainer_client'):
       model = p.cls(p)
       model.ConstructFPropBPropGraph()
   with self.session(graph=g) as sess:
     sess.run(tf.global_variables_initializer())
     CompareToGoldenSingleFloat(self, 2.302583, self._runOneStep(model, sess))
     CompareToGoldenSingleFloat(self, 2.302405, self._runOneStep(model, sess))
Example #19
 def _get_model(self):
     """Fetch model for decoding."""
     name = FLAGS.model
     self.is_transformer = "Transformer" in FLAGS.model
     p = model_registry.GetParams(
         "feature_neighborhood_model_config." + name, "Train")
     p.is_inference = True
     p.input.file_pattern = FLAGS.feature_neighborhood_test_path
     # Send it round twice so we can get the last few examples.
     p.input.repeat_count = 2
     p.cluster.require_sequential_input_order = True
     if self.is_transformer:
         p.task.beam_size = FLAGS.beam_size
     mdl = p.Instantiate()
     mdl.ConstructFPropGraph()
     return mdl.GetTask()
Example #20
 def testMnistV2(self):
     g = tf.Graph()
     with g.as_default():
         tf.random.set_seed(1618)
         p = model_registry.GetParams('test.MnistV2', 'Test')
         p.random_seed = 73234288
         p.input.ckpt = self.data_path
         p.task.params_init = py_utils.WeightInit.Uniform(0.1,
                                                          seed=73234288)
         with cluster_factory.ForTestingWorker(mode='sync',
                                               job='trainer_client'):
             model = p.Instantiate()
             model.ConstructFPropBPropGraph()
     with self.session(graph=g):
         self.evaluate(tf.global_variables_initializer())
         CompareToGoldenSingleFloat(self, 2.303070, self._runOneStep(model))
         CompareToGoldenSingleFloat(self, 2.297364, self._runOneStep(model))
Example #21
    def _GetTestConfig(self):
        model_name = 'image.mnist.LeNet5'
        cfg = model_registry.GetParams(model_name, 'Dev')
        cfg.cluster.task = 0
        cfg.cluster.mode = 'sync'
        cfg.cluster.job = 'trainer_client'
        cfg.cluster.worker.name = '/job:local'
        cfg.cluster.worker.replicas = 1
        cfg.cluster.worker.gpus_per_replica = 0
        cfg.cluster.ps.name = '/job:local'
        cfg.cluster.ps.replicas = 1
        cfg.cluster.ps.gpus_per_replica = 0

        # Generate 2 inputs.
        self._tmpdir, cfg.input.ckpt = FakeMnistData(train_size=0, test_size=2)
        cfg.input.num_samples = 2
        cfg.train.max_steps = 2
        cfg.train.ema_decay = 0.9999
        return cfg
Example #22
def main(unused_argv):
    override_flags()
    if FLAGS.disable_logging:
        tf.get_logger().setLevel('CRITICAL')

    model_params = model_registry.GetParams(FLAGS.model, None)
    tf.logging.info('Found model %s', FLAGS.model)

    batch_size = model_params.task.batch_size

    decoder = GShardLMDecodeBatch()
    decoder.init_vocab(model_params)
    decoder.preload_data(batch_size)

    decoder.reset_tpu_cluster()
    decoder.reset_session()

    try:
        decoder.init_graph(model_params)
        decoder.run_init_sequence()
        decoder.batch_decode()
    finally:
        decoder.reset_tpu_cluster()
Example #23
    def _create_graph(self):
        if self._sess is not None:
            return

        with self._cluster:
            cfg = model_registry.GetParams(self._model_name, self._split)
            cfg.input.batch_size = 1
            # Turn off label filtering so the database contains
            # all objects.
            cfg.input.extractors.labels.filter_labels = None

            # Disable preprocessors if they are not required.
            if not self._run_preprocessors:
                cfg.input.preprocessors_order = []

            graph = tf.Graph()
            with graph.as_default():
                inp = cfg.input.Instantiate()
                self._elem = tf.placeholder(tf.string)
                bucket, batch = inp.ExtractUsingExtractors(self._elem)
                self._filtered_data = _GetFilteredBoundingBoxData(batch)
                self._bucket = bucket
        self._sess = tf.Session(graph=graph)
Example #24
    def _GetTestConfig(self):
        model_name = 'image.mnist.LeNet5'
        # This is how a registered model's parameters are obtained.
        cfg = model_registry.GetParams(model_name, 'Train')
        cfg.cluster.task = 0
        cfg.cluster.mode = 'sync'
        cfg.cluster.job = 'trainer_client'
        cfg.cluster.worker.name = '/job:localhost'
        cfg.cluster.worker.replicas = 1
        cfg.cluster.worker.gpus_per_replica = 0
        cfg.cluster.ps.name = '/job:localhost'
        cfg.cluster.ps.replicas = 1
        cfg.cluster.ps.gpus_per_replica = 0

        # Generate 2 inputs.
        cfg.input.ckpt = FakeMnistData(self.get_temp_dir(),
                                       train_size=2,
                                       test_size=2)
        cfg.input.num_samples = 2
        cfg.input.batch_size = 2
        cfg.train.max_steps = 2
        cfg.train.ema_decay = 0.9999
        return cfg
Example #25
  def testOverrideVarsFromCheckpointWithIgnoreRules(self):

    with self.session(use_gpu=False) as sess:
      tf.set_random_seed(8372749040)
      cfg = model_registry.GetParams('image.mnist.LeNet5', 'Train')
      with cluster_factory.ForTestingWorker(mode='sync', job='trainer_client'):
        cfg.cls(cfg)
      tf.global_variables_initializer().run()
      self.assertAllClose(
          # These are initialized values before overriding with checkpoint.
          self._GetLeNetVarsFirstVal(sess),
          [-0.005945, -0.036722, 0.0])
      checkpoint_path = test_helper.test_src_dir_path(
          'core/testdata/lenet_test_model')
      variable_loading_rules = [('lenet5/conv0/w/var', 'lenet5/conv0/w/var'),
                                ('lenet5/conv1/w/var', 'lenet5/conv1/w/var')]
      variable_ignore_rules = ['lenet5/conv1/w/var']
      py_utils._OverrideVarsFromCheckpoint(
          sess, tf.all_variables(), checkpoint_path, variable_loading_rules,
          variable_ignore_rules)
      self.assertAllClose(
          # Now only conv0 weights have been overridden.
          self._GetLeNetVarsFirstVal(sess),
          [0.043092, -0.036722, 0.0])
Example #26
 def testInference(self):
     with self.session() as sess:
         tf.random.set_seed(1618)
         p = model_registry.GetParams('test.MnistV2', 'Test')
         p.random_seed = 73234288
         p.input.ckpt = self.data_path
         p.task.params_init = py_utils.WeightInit.Uniform(0.1,
                                                          seed=73234288)
         model = p.Instantiate()
         subgraphs = model.GetTask().Inference()
         self.assertCountEqual(['default'], list(subgraphs.keys()))
         fetches, feeds = subgraphs['default']
         self.assertCountEqual(['normalized_image'], list(feeds.keys()))
         self.assertCountEqual(['logits', 'probs', 'prediction'],
                               list(fetches.keys()))
         self.evaluate(tf.global_variables_initializer())
         fetch_results = sess.run(
             fetches,
             {feeds['normalized_image']: np.zeros(p.input.data_shape)})
         self.assertAllEqual([p.task.softmax.num_classes],
                             fetch_results['logits'].shape)
         self.assertAllEqual([p.task.softmax.num_classes],
                             fetch_results['probs'].shape)
         self.assertAllEqual([], fetch_results['prediction'].shape)
Example #27
 def testGetModelParams(self):
     p = model_registry.GetParams('test.DummyModel', 'Train')
     self.assertTrue(issubclass(p.cls, base_model.SingleTaskModel))
Example #28
    def __init__(self,
                 sess,
                 batch_size=1,
                 lr_step1=100,
                 lr_step2=0.1,
                 num_iter_step1=1000,
                 num_iter_step2=4000,
                 th=None,
                 psd_max_ori=None):

        self.sess = sess
        self.num_iter_step1 = num_iter_step1
        self.num_iter_step2 = num_iter_step2
        self.batch_size = batch_size
        self.lr_step1 = lr_step1
        #self.lr_step2 = lr_step2

        tf.set_random_seed(1234)
        params = model_registry.GetParams('asr.librispeech.Librispeech960Wpm',
                                          'Test')
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)
            self.delta_large = tf.Variable(np.zeros((batch_size, 223200),
                                                    dtype=np.float32),
                                           name='qq_delta')

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name='qq_input')
            self.tgt_tf = tf.placeholder(tf.string)
            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name='qq_sample_rate')
            self.th = tf.placeholder(tf.float32,
                                     shape=[batch_size, None, None],
                                     name='qq_th')
            self.psd_max_ori = tf.placeholder(tf.float32,
                                              shape=[batch_size],
                                              name='qq_psd')
            self.mask = tf.placeholder(dtype=np.float32,
                                       shape=[batch_size, None],
                                       name='qq_mask')
            self.mask_freq = tf.placeholder(dtype=np.float32,
                                            shape=[batch_size, None, 80])
            #noise = tf.random_normal(self.new_input.shape, stddev=2)
            self.noise = tf.placeholder(np.float32,
                                        shape=[batch_size, None],
                                        name="qq_noise")
            self.maxlen = tf.placeholder(np.int32)
            self.lr_step2 = tf.placeholder(np.float32)

            # variable
            self.rescale = tf.Variable(np.ones((batch_size, 1),
                                               dtype=np.float32),
                                       name='qq_rescale')
            self.alpha = tf.Variable(np.ones(
                (batch_size), dtype=np.float32) * 0.05,
                                     name='qq_alpha')

            # extract the delta
            self.delta = tf.slice(tf.identity(self.delta_large), [0, 0],
                                  [batch_size, self.maxlen])
            self.apply_delta = tf.clip_by_value(self.delta, -2000,
                                                2000) * self.rescale
            self.new_input = self.apply_delta * self.mask + self.input_tf
            #pass_in = tf.clip_by_value(self.new_input, -2**15, 2**15-1)
            self.pass_in = tf.clip_by_value(self.new_input + self.noise,
                                            -2**15, 2**15 - 1)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.pass_in, self.sample_rate_tf,
                                            self.mask_freq)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size, self.mask_freq)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)
            # self.celoss with the shape (batch_size)
            self.celoss = tf.get_collection("per_loss")[0]
            self.decoded = task.Decode(self.inputs)

        # compute the loss for masking threshold
        self.loss_th_list = []
        self.transform = Transform()
        for i in range(self.batch_size):
            logits_delta = self.transform((self.apply_delta[i, :]),
                                          (self.psd_max_ori)[i])
            loss_th = tf.reduce_mean(tf.nn.relu(logits_delta - (self.th)[i]))
            loss_th = tf.expand_dims(loss_th, axis=0)
            self.loss_th_list.append(loss_th)
        self.loss_th = tf.concat(self.loss_th_list, axis=0)

        self.optimizer1 = tf.train.AdamOptimizer(self.lr_step1)
        self.optimizer2 = tf.train.AdamOptimizer(self.lr_step2)

        grad1, var1 = self.optimizer1.compute_gradients(
            self.celoss, [self.delta_large])[0]
        grad21, var21 = self.optimizer2.compute_gradients(
            self.celoss, [self.delta_large])[0]
        grad22, var22 = self.optimizer2.compute_gradients(
            self.alpha * self.loss_th, [self.delta_large])[0]

        self.train1 = self.optimizer1.apply_gradients([(tf.sign(grad1), var1)])
        self.train21 = self.optimizer2.apply_gradients([(grad21, var21)])
        self.train22 = self.optimizer2.apply_gradients([(grad22, var22)])
        self.train2 = tf.group(self.train21, self.train22)
Example #29
def main(argv):
    data = np.loadtxt(FLAGS.input, dtype=str, delimiter=",")
    # calculate the number of loops to run the test
    num = len(data[0])
    batch_size = FLAGS.batch_size
    num_loops = num // batch_size
    assert num % batch_size == 0

    with tf.device("/gpu:0"):
        tf.set_random_seed(1234)
        tfconf = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=tfconf) as sess:
            params = model_registry.GetParams(
                'asr.librispeech.Librispeech960Wpm', 'Test')
            params.cluster.worker.gpus_per_replica = 1
            cluster = cluster_factory.Cluster(params.cluster)
            with cluster, tf.device(cluster.GetPlacer()):
                params.vn.global_vn = False
                params.random_seed = 1234
                params.is_eval = True
                model = params.cls(params)
                task = model.GetTask()
                saver = tf.train.Saver()
                saver.restore(sess, FLAGS.checkpoint)

                # define the placeholders
                input_tf = tf.placeholder(tf.float32, shape=[batch_size, None])
                tgt_tf = tf.placeholder(tf.string)
                sample_rate_tf = tf.placeholder(tf.int32)
                mask_tf = tf.placeholder(tf.float32,
                                         shape=[batch_size, None, 80])

                # generate the features and inputs
                features = create_features(input_tf, sample_rate_tf, mask_tf)
                shape = tf.shape(features)
                inputs = create_inputs(model, features, tgt_tf, batch_size,
                                       mask_tf)

                # loss
                metrics = task.FPropDefaultTheta(inputs)
                loss = tf.get_collection("per_loss")[0]

                # prediction
                decoded_outputs = task.Decode(inputs)
                dec_metrics_dict = task.CreateDecoderMetrics()

                correct = 0
                for l in range(num_loops):
                    data_sub = data[:, l * batch_size:(l + 1) * batch_size]
                    audios_np, sample_rate, tgt_np, mask_freq = Read_input(
                        data_sub, batch_size)
                    feed_dict = {
                        input_tf: audios_np,
                        sample_rate_tf: sample_rate,
                        tgt_tf: tgt_np,
                        mask_tf: mask_freq
                    }

                    losses = sess.run(loss, feed_dict)
                    predictions = sess.run(decoded_outputs, feed_dict)

                    task.PostProcessDecodeOut(predictions, dec_metrics_dict)
                    wer_value = dec_metrics_dict['wer'].value * 100.

                    for i in range(batch_size):
                        print("pred:{}".format(predictions['topk_decoded'][i,
                                                                           0]))
                        print("targ:{}".format(tgt_np[i].lower()))
                        print("true: {}".format(data_sub[1, i].lower()))

                        if predictions['topk_decoded'][i,
                                                       0] == tgt_np[i].lower():
                            correct += 1
                            print("------------------------------")
                            print("example {} succeeds".format(i))

                    print("Now, the WER is: {0:.2f}%".format(wer_value))
                print("num of examples succeed: {}".format(correct))
                print("success rate: {}%".format(correct / float(num) * 100))
Example #30
    def __init__(
        self,
        sess,
        batch_size=1,
        lr_stage1=100,
        lr_stage2=0.1,
        num_iter_stage1=1000,
        num_iter_stage2=4000,
        th=None,
        psd_max_ori=None,
    ):

        self.sess = sess
        self.num_iter_stage1 = num_iter_stage1
        self.num_iter_stage2 = num_iter_stage2
        self.batch_size = batch_size
        self.lr_stage1 = lr_stage1
        self.lr_stage2 = lr_stage2

        tf.set_random_seed(1234)
        params = model_registry.GetParams("asr.librispeech.Librispeech960Wpm",
                                          "Test")
        params.random_seed = 1234
        params.is_eval = True
        params.cluster.worker.gpus_per_replica = 1
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf.device(cluster.GetPlacer()):
            model = params.cls(params)
            self.delta_large = tf.Variable(
                np.zeros((batch_size, FLAGS.max_length_dataset),
                         dtype=np.float32),
                name="qq_delta",
            )

            # placeholders
            self.input_tf = tf.placeholder(tf.float32,
                                           shape=[batch_size, None],
                                           name="qq_input")
            self.tgt_tf = tf.placeholder(tf.string)
            self.rir = tf.placeholder(tf.float32)

            self.sample_rate_tf = tf.placeholder(tf.int32,
                                                 name="qq_sample_rate")
            self.mask = tf.placeholder(dtype=np.float32,
                                       shape=[batch_size, None],
                                       name="qq_mask")
            self.mask_freq = tf.placeholder(dtype=np.float32,
                                            shape=[batch_size, None, 80])
            self.noise = tf.placeholder(np.float32,
                                        shape=[batch_size, None],
                                        name="qq_noise")
            self.maxlen = tf.placeholder(np.int32)
            self.lr = tf.placeholder(np.float32)
            self.lengths = tf.placeholder(
                np.int32,
                shape=[
                    batch_size,
                ],
            )

            # variable
            self.rescale = tf.Variable(
                np.ones(
                    (batch_size, 1), dtype=np.float32) * FLAGS.initial_bound,
                name="qq_rescale",
            )

            # extract the delta
            self.delta = tf.slice(tf.identity(self.delta_large), [0, 0],
                                  [batch_size, self.maxlen])
            self.apply_delta = tf.clip_by_value(self.delta, -self.rescale,
                                                self.rescale)
            self.before_rir = tf.clip_by_value(
                self.apply_delta * self.mask + self.input_tf, -(2**15),
                2**15 - 1)
            self.new_input = (create_speech_rir(
                self.before_rir,
                self.rir,
                self.lengths,
                self.maxlen,
                self.batch_size,
            ) * self.mask)
            self.pass_in = tf.clip_by_value(self.new_input + self.noise,
                                            -(2**15), 2**15 - 1)

            # generate the inputs that are needed for the lingvo model
            self.features = create_features(self.pass_in, self.sample_rate_tf,
                                            self.mask_freq)
            self.inputs = create_inputs(model, self.features, self.tgt_tf,
                                        self.batch_size, self.mask_freq)

            task = model.GetTask()
            metrics = task.FPropDefaultTheta(self.inputs)

            # self.celoss with the shape (batch_size)
            self.celoss = tf.get_collection("per_loss")[0]
            self.decoded = task.Decode(self.inputs)

        self.optimizer1 = tf.train.AdamOptimizer(self.lr)
        grad1, var1 = self.optimizer1.compute_gradients(
            self.celoss, [self.delta_large])[0]
        self.train1 = self.optimizer1.apply_gradients([(tf.sign(grad1), var1)])