Example 1
 def testSharedEncBiasWeights(self):
     model_dim = 4
     key_value_dim = 2
     num_heads = 2
     g = tf.Graph()
     with g.as_default(), self.SetEval(True):
         _ = py_utils.GetOrCreateGlobalStepVar()  # for DeterministicDropout
         builder = FakeMoEBuilder.Params().Set(
             num_devices=FLAGS.num_partitions,
             dropout_rate=0,
             model_dim=model_dim,
             attention_key_value_dim=key_value_dim,
             attention_num_heads=num_heads)
         builder = builder.Instantiate()
         p = builder._Seq('model', builder.FakeLayer('layer0'),
                          builder.FakeLayer('layer1'))
         layer = p.Instantiate()
         all_vars = tf.trainable_variables()
         tf.logging.info(all_vars)
         self.assertEqual(1, len(all_vars))
     with tf.Session(graph=g) as sess, self.SetEval(True):
         x = tf.ones([model_dim])
         y = layer.FPropDefaultTheta(x)
         sess.run(tf.global_variables_initializer())
         y_val = sess.run(y)
         self.assertAllEqual([3.] * model_dim, y_val)
Example 2
def ReadNpArrays(file_prefix, nmap):
  """Reads from a tf checkpoint to fill in values of a NesteMap.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap of numpy dtypes.

  Returns:
    A NestedMap with numpy arrays compatible with nmap.
  """
  g = tf.Graph()
  with g.as_default():
    reads = []
    for name, dtype in nmap.FlattenItems():
      reads.append(
          io_ops.restore_v2(
              prefix=file_prefix,
              tensor_names=[name],
              shape_and_slices=[""],
              dtypes=[dtype])[0])

  with tf.Session(graph=g) as sess:
    vals = sess.run(reads)

  return nmap.Pack(vals)
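
A minimal usage sketch, assuming a checkpoint whose tensor names match the NestedMap keys (for instance one written by the keras-dataset script in Example 9); the path below is hypothetical. The nmap argument acts as a dtype template, and the result mirrors its structure with numpy arrays in place of the dtypes:

template = py_utils.NestedMap(x_train=tf.uint8, y_train=tf.uint8)
arrays = ReadNpArrays('/tmp/mnist/mnist', template)
print(arrays.x_train.shape)  # numpy array restored from the checkpoint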
Example 3
  def _CreateNewSession(self):
    """Updates self._sess with a new session."""
    config = self._session_config
    if not config:
      config = py_utils.SessionConfig()
    sess = tf.Session(self._tf_master, graph=self._graph, config=config)

    try:
      sess.run(self._graph.get_operation_by_name("init_all_tables"))
    except KeyError:
      tf.logging.info("Could not find tables initializer in graph.")
    if self._device_type == "tpu":
      sess.run(self._graph.get_operation_by_name("tpu_init_op"))
    if self._checkpoint:
      self._saver.restore(sess, self._checkpoint)
    else:
      try:
        init_op = self._graph.get_operation_by_name("init_all_variables")
        sess.run(init_op)
      except KeyError:
        tf.logging.warning(
            "No checkpoint provided and the graph has no default "
            "variable_init op.")
    tf.logging.info("Created new predictor session.")
    self._sess = sess
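
The pattern above relies on Graph.get_operation_by_name raising KeyError for ops the graph does not define, so each initializer is optional. A minimal standalone sketch of the same protocol, assuming a TF1-style graph that publishes an 'init_all_variables' op but no table initializer:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  v = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
  # Publish the variable initializer under the name the predictor looks up.
  tf.group(tf.global_variables_initializer(), name='init_all_variables')

with tf.Session(graph=g) as sess:
  try:
    sess.run(g.get_operation_by_name('init_all_tables'))
  except KeyError:
    pass  # This graph defines no table initializer.
  sess.run(g.get_operation_by_name('init_all_variables'))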
Example 4
 def testEmpty(self):
     with tf.Session():
         ids, strs = self._enc.Encode('')
         self.assertEqual(
             b'',
             tf.strings.reduce_join(strs, separator=' ').eval())
         self.assertEqual(b'', self._enc.Decode(ids).eval())
Example 5
 def testWithBackslash(self):
     with tf.Session():
         ids, strs = self._enc.Encode('\\')
         self.assertEqual(
             u'▁ \\'.encode('utf-8'),
             tf.strings.reduce_join(strs, separator=' ').eval())
         self.assertEqual(b'\\', self._enc.Decode(ids).eval())
Example 6
def WriteNpArrays(file_prefix, nmap):
  """Writes a NestedMap of numpy arrays into a TF checkpoint.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap of numpy arrays.
  """
  g = tf.Graph()
  with g.as_default():

    def Wrap(val):
      dtype = tf.as_dtype(val.dtype)
      assert dtype != tf.string  # tf.string is not supported by py_func.
      return tf.py_func(lambda: val, [], dtype)

    names, values = [], []
    for k, v in nmap.FlattenItems():
      names.append(k)
      assert isinstance(v, np.ndarray)
      values.append(Wrap(v))

    save = io_ops.save_v2(
        prefix=file_prefix,
        tensor_names=names,
        tensors=values,
        shape_and_slices=[""] * len(names))

  with tf.Session(graph=g) as sess:
    sess.run(save)
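
Together with ReadNpArrays from Example 2, this forms a simple round trip. A sketch with a hypothetical checkpoint prefix:

nmap = py_utils.NestedMap(w=np.ones([2, 3], np.float32))
WriteNpArrays('/tmp/test_ckpt', nmap)
restored = ReadNpArrays('/tmp/test_ckpt', py_utils.NestedMap(w=tf.float32))
assert np.allclose(restored.w, nmap.w)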
Example 7
    def testCurriculumDataSourceTransitionsCorrectlyWithMixingDataSource(self):
        sources = [
            datasource.WithinBatchMixingDataSource.Params().Set(
                file_patterns=['file1', 'file2'], weights=[1, 5]),
            datasource.WithinBatchMixingDataSource.Params().Set(
                file_patterns=['file3', 'file4'], weights=[2, 3])
        ]
        boundary = 5
        ds_params = datasource.CurriculumDataSource.Params().Set(
            datasource_params=sources, boundaries=[boundary])
        ds = ds_params.Instantiate()

        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)
        with tf.Session() as sess:
            # Advance the global step to the next curriculum stage
            global_step = py_utils.GetOrCreateGlobalStepVar()
            tf.global_variables_initializer().run()
            set_global_step = tf.assign(global_step,
                                        boundary,
                                        name='advance_step')
            sess.run(set_global_step)

            ret.data = sess.run([ret.data])

        self.assertCountEqual(sorted(ret.keys()),
                              ['bprop_variable_filters', 'data'])
        self.assertAllEqual(ret.data, [[b'file3,file4']])
        self.assertCountEqual(ret.bprop_variable_filters, [''])
Example 8
  def testPoolingWithUnknownShapeInput(self):
    """Tests GlobalPooling layer with unknown shape tensor."""

    def remove_shape(tensor):
      shape = tf.placeholder(tf.int32, name='removed_shape')
      return tf.reshape(tensor, shape)

    g = tf.Graph()
    with g.as_default(), tf.Session(graph=g) as _:
      tf.random.set_seed(24332)
      input_shape = [3, 5, 2, 4]
      inputs = np.random.random(input_shape) - 0.5
      expected_avg_output = np.mean(inputs, axis=(1, 2), keepdims=True)
      input_tensor = tf.convert_to_tensor(inputs, dtype=tf.float32)
      # initial shape is [3, 5, 2, 4]
      self.assertEqual(py_utils.GetShape(input_tensor), input_shape)
      # remove shape using a tf Defun and verify dynamic tensor shape.
      input_tensor = remove_shape(input_tensor)
      self.assertIsInstance(py_utils.GetShape(input_tensor), tf.Tensor)
      self.assertIsNone(input_tensor.shape.rank)
      self._testHelper(
          'AVG',
          input_tensor,
          None,
          expected_avg_output,
          None,
          feed_dict={'removed_shape:0': input_shape})
Example 9
def main(argv):
    del argv  # Unused.

    dataset = getattr(tf.keras.datasets, FLAGS.dataset)
    (x_train, y_train), (x_test, y_test) = dataset.load_data()

    def wrap(val):
        dtype = tf.as_dtype(val.dtype)
        assert dtype != tf.string  # tf.string is not supported by py_func.
        return tf.py_func(lambda: val, [], dtype)

    out_prefix = FLAGS.out or os.path.join("/tmp", FLAGS.dataset,
                                           FLAGS.dataset)
    tf.logging.info("Save %s dataset to %s ckpt." %
                    (FLAGS.dataset, out_prefix))

    with tf.Session() as sess:
        sess.run(
            io_ops.save_v2(
                prefix=out_prefix,
                tensor_names=["x_train", "y_train", "x_test", "y_test"],
                shape_and_slices=[""] * 4,
                tensors=[
                    wrap(x_train),
                    wrap(y_train),
                    wrap(x_test),
                    wrap(y_test)
                ]))
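
A possible read-back of this checkpoint via io_ops.restore_v2, reusing out_prefix and assuming the arrays were saved with their original dtypes (uint8 for the image arrays of the standard keras datasets):

with tf.Session() as sess:
    x_train = sess.run(
        io_ops.restore_v2(
            prefix=out_prefix,
            tensor_names=["x_train"],
            shape_and_slices=[""],
            dtypes=[tf.uint8])[0])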
Example 10
 def _GetSession(self, **kwargs):
     if py_utils.IsEagerMode():
         raise ValueError('_GetSession is not supported in eager mode.')
     graph = kwargs.pop('graph', self._graph)
     return tf.Session(self._tf_master,
                       graph=graph,
                       config=py_utils.SessionConfig(**kwargs))
Example 11
    def testCrossBatchMixingDataSourceSucceedsWithListFilesAndWeights(self):
        files = ['path_to_file', 'path_to_file']
        datasources = [
            datasource.SimpleDataSource.Params().Set(file_pattern=f)
            for f in files
        ]
        weights = [1, 4]
        ds_params = datasource.CrossBatchMixingDataSource.Params().Set(
            sub=datasources, weights=weights)
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

        with tf.Session():
            ret.data = self.evaluate([ret.data])

        self.assertCountEqual(sorted(ret.keys()), [
            'bprop_variable_filters', 'data', 'selected_bprop',
            'source_selected'
        ])
        # CrossBatchMixing operates on the python side of the tf op so a single
        # element will be returned by _MockDataSourceFromFilePattern
        self.assertAllEqual(ret.data, [[b'path_to_file']])
        self.assertCountEqual(ret.bprop_variable_filters, [''] * len(files))
        self.assertAllEqual(ret.selected_bprop.shape, [2])
        self.assertAllEqual(ret.source_selected.shape, [1, 2])
Example 12
    def testCurriculumDataSourceTransitionsCorrectlyWithSimpleDataSource(self):
        sources = [
            datasource.SimpleDataSource.Params().Set(file_pattern='file1'),
            datasource.SimpleDataSource.Params().Set(file_pattern='file2'),
        ]
        boundary = 5
        ds_params = datasource.CurriculumDataSource.Params().Set(
            sub=sources, boundaries=[boundary])
        ds = ds_params.Instantiate()

        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)
        with tf.Session():
            # Advance the global step to the next curriculum stage
            global_step = py_utils.GetOrCreateGlobalStepVar()
            self.evaluate(tf.global_variables_initializer())
            set_global_step = tf.assign(global_step,
                                        boundary,
                                        name='advance_step')
            self.evaluate(set_global_step)

            ret.data = self.evaluate([ret.data])

        self.assertCountEqual(sorted(ret.keys()),
                              ['bprop_variable_filters', 'data'])
        self.assertAllEqual(ret.data, [[b'file2']])
        self.assertCountEqual(ret.bprop_variable_filters, [''])
Example 13
 def testWithUnicode(self):
     with tf.Session():
         ids, strs = self._enc.Encode('føö')
         self.assertEqual(
             u'▁ f ø ö'.encode('utf-8'),
             tf.strings.reduce_join(strs, separator=' ').eval())
         self.assertEqual(u'føö'.encode('utf-8'),
                          self._enc.Decode(ids).eval())
Example 14
 def testMergeProb(self):
   voc = self._CreateVocab()
   enc = wpm_encoder.WpmEncoder(voc, merge_prob=0.)
   with tf.Session():
     ids, strs = enc.Encode('Ditto')
     self.assertEqual(u'▁ D i t t o'.encode('utf-8'),
                      tf.strings.reduce_join(strs, separator=' ').eval())
      self.assertEqual(b'Ditto', enc.Decode(ids).eval())
Example 15
    def testSimpleDataSourceSucceedsWithFileType(self):
        ds_params = datasource.SimpleDataSource.Params().Set(
            file_pattern='pattern1,pattern2', file_type='tfrecord')
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)
        with tf.Session() as sess:
            ret.data = sess.run([ret.data])

        self.assertAllEqual(ret.data, [[b'tfrecord:pattern1,pattern2']])
Example 16
    def testSimpleDataSourceSucceedsWithStringInput(self):
        ds_params = datasource.SimpleDataSource.Params().Set(
            file_pattern='path_to_file')
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)
        with tf.Session() as sess:
            ret.data = sess.run([ret.data])

        self.assertCountEqual(sorted(ret.keys()), ['data'])
        self.assertAllEqual(ret.data, [[b'path_to_file']])
Example 17
 def testDitto(self):
   with tf.Session():
     ids, strs = self._enc.Encode('Ditto')
     self.assertEqual(u'▁ D itt o'.encode('utf-8'),
                      tf.strings.reduce_join(strs, separator=' ').eval())
     self.assertEqual(b'Ditto', self._enc.Decode(ids).eval())
     ids, strs = self._enc.Encode('Ditto Ditto')
     self.assertEqual(u'▁ D itt o ▁ D itt o'.encode('utf-8'),
                      tf.strings.reduce_join(strs, separator=' ').eval())
     self.assertEqual(b'Ditto Ditto', self._enc.Decode(ids).eval())
Example 18
    def testSimpleDataSourceSucceedsWithListInput(self):
        files = ['file1', 'file2']
        ds_params = datasource.SimpleDataSource.Params().Set(
            file_pattern=files)
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)
        with tf.Session():
            ret.data = self.evaluate([ret.data])

        self.assertAllEqual(ret.data, [[b'file1,file2']])
Example 19
    def testPrefixDataSourceSucceedsWithDirectory(self):
        ds_params = datasource.PrefixedDataSource.Params().Set(
            file_pattern='filename-*.tfrecord', file_pattern_prefix='/dir/')
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

        with tf.Session():
            ret.data = self.evaluate([ret.data])

        self.assertAllEqual(ret.data, [[b'/dir/filename-*.tfrecord']])
Example 20
  def testSimpleDataSourceSucceedsWithListInput(self):
    files = ['file1', 'file2']
    ds_params = datasource.SimpleDataSource.Params().Set(file_pattern=files)
    ds = ds_params.Instantiate()
    ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())
    with tf.Session():
      batch = ds.GetNext()
      ret = ds.GetMeta()
      ret.data = self.evaluate(batch.data)

    self.assertEqual(ret.data, b'file1,file2')
Example 21
 def testSimpleDataSourceFileInputSucceedsWithListInput(self):
   ds_params = datasource.SimpleDataSource.Params().Set(
       file_pattern=self.inputs, file_type='tfrecord')
   ds = ds_params.Instantiate()
   ds.SetInputGenerator(TestFileInputGenerator.Params().Instantiate())
   with tf.Session():
     batch = ds.GetNext()
     data = []
     for _ in range(2):
       data.append(self.evaluate(batch.data).tolist()[0])
     self.assertCountEqual(data, [0, 1])
Example 22
def _CreateAsrFeatures():
    # First pass: extract transcription files.
    if False:  # os.path.exists(FLAGS.transcripts_filepath):
        trans = _LoadTranscriptionsFromFile()
    else:
        tf.logging.info('Running first pass on the fly')
        trans = _ReadTranscriptionsFromCSV()
    total_utts = len(trans)
    tf.logging.info('Total transcripts: %d', len(trans))
    tf_bytes = tf.placeholder(dtype=tf.string)
    log_mel = audio_lib.ExtractLogMelFeatures(tf_bytes)
    # Second pass: transcode the flac.
    file_obj = tf.io.gfile.GFile(FLAGS.input_tarball, mode='rb')
    tar = tarfile.open(fileobj=file_obj, mode='r:gz')
    n = 0
    recordio_writers = _OpenSubShards()
    tfconf = tf.config_pb2.ConfigProto()
    tfconf.gpu_options.allow_growth = True
    with tf.Session(config=tfconf) as sess:
        for tarinfo in tar:
            # We can actually decode essentially any audio format, but we
            # want to avoid non-audio data. Thus, this condition.
            if not (tarinfo.name.endswith('.flac')
                    or tarinfo.name.endswith('.wav')
                    or tarinfo.name.endswith('.mp3')):
                continue
            n += 1
            if n % FLAGS.num_shards != FLAGS.shard_id:
                continue
            f = tar.extractfile(tarinfo)
            fmt = tarinfo.name.split('.')[-1]
            uttid = tarinfo.name
            audio_bytes = f.read()
            f.close()
            try:
                wav_bytes = audio_lib.DecodeToWav(audio_bytes, fmt)
                frames = sess.run(log_mel, feed_dict={tf_bytes: wav_bytes})
            except Exception as e:  # pylint: disable=broad-except
                trans.pop(uttid)
                tf.logging.info(f'{uttid} FAILED featurization: {e}')
                continue
            assert uttid in trans, uttid
            num_chars = len(trans[uttid])
            tf.logging.info('utt[%d]: %s [%d frames, %d chars]', n, uttid,
                            frames.shape[1], num_chars)
            ex = _MakeTfExample(uttid, frames, trans[uttid])
            outf = _SelectRandomShard(recordio_writers)
            outf.write(ex.SerializeToString())
        tar.close()
    file_obj.close()
    _CloseSubShards(recordio_writers)
    tf.logging.info(f'Processed {len(trans)} / {total_utts}')
Example 23
    def testPrefixDataSourceSucceedsWithFileType(self):
        ds_params = datasource.PrefixedDataSource.Params().Set(
            file_pattern='filename-*.tfrecord',
            file_type='tfrecord',
            file_pattern_prefix='dir')
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

        with tf.Session() as sess:
            ret.data = sess.run([ret.data])

        self.assertAllEqual(ret.data, [[b'tfrecord:dir/filename-*.tfrecord']])
Example 24
    def testPrefixDataSourceSucceedsWithDirectory(self):
        ds_params = datasource.PrefixedDataSourceWrapper.Params().Set(
            base_datasource=datasource.SimpleDataSource.Params().Set(
                file_pattern='filename-*.tfrecord', file_type=None),
            file_pattern_prefix='/dir/')
        ds = ds_params.Instantiate()
        ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

        with tf.Session() as sess:
            ret.data = sess.run([ret.data])

        self.assertAllEqual(ret.data, [[b'/dir/filename-*.tfrecord']])
Example 25
    def _ComputeFinalMetrics(self,
                             classids=None,
                             difficulty=None,
                             distance=None,
                             num_points=None,
                             rotation=None):
        """Compute precision-recall curves as well as average precision.

    Args:
      classids: A list of N int32.
      difficulty: String in [easy, moderate, hard]. If None is specified, all
        difficulty levels are permitted.
      distance: int32 specifying a binned Euclidean distance of the ground truth
        bounding box. If None is specified, all distances are selected.
      num_points: int32 specifying a binned number of laser points within the
        ground truth bounding box. If None is specified, all boxes are selected.
      rotation: int32 specifying a binned rotation within the ground truth
        bounding box. If None is specified, all boxes are selected.

    Returns:
      A dict whose entries are each a list of C (number of classes) dicts
      mapping metric names to individual results:

      - scalars: A list of C dicts mapping metric names to scalar values.
      - curves: A list of C dicts mapping metric names to np.float32 arrays
        of shape [NumberOfPrecisionRecallPoints()+1, 2]. In the last
        dimension, 0 indexes precision and 1 indexes recall.
    """
        tf.logging.info('Computing final Waymo metrics.')
        assert classids is not None, 'classids must be supplied.'
        feed_dict = {}
        g = tf.Graph()
        scalar_fetches = []
        curve_fetches = []
        with g.as_default():
            for classid in classids:
                data = self._GetData(classid,
                                     difficulty=difficulty,
                                     distance=distance,
                                     num_points=num_points,
                                     rotation=rotation)
                metrics = self._BuildMetric(data, classid)
                scalar_fetches += [metrics.scalar_metrics]
                curve_fetches += [metrics.curve_metrics]
                feed_dict.update(metrics.feed_dict)

        with tf.Session(graph=g) as sess:
            results = sess.run([scalar_fetches, curve_fetches],
                               feed_dict=feed_dict)
        tf.logging.info('Finished computing final Waymo metrics.')
        return {'scalars': results[0], 'curves': results[1]}
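
A hypothetical consumer of the returned structure (the metric name below is illustrative, not from the source):

# Inside a method of the same metrics class:
results = self._ComputeFinalMetrics(classids=[1, 2], difficulty='easy')
curve = results['curves'][0]['some_curve_metric']  # np.float32 [N+1, 2]
precision, recall = curve[:, 0], curve[:, 1]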
Example 26
  def testChainingDataSourceSucceedsWithListInput(self):
    files = ['path_to_file1', 'path_to_file2']
    ds_params = datasource.ChainingDataSource.Params().Set(file_patterns=files)
    ds = ds_params.Instantiate()
    ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

    with tf.Session() as sess:
      ret.data = sess.run([ret.data])

    self.assertCountEqual(
        sorted(ret.keys()), ['bprop_variable_filters', 'data'])
    self.assertAllEqual(ret.data, [[b'path_to_file1,path_to_file2']])
    self.assertCountEqual(ret.bprop_variable_filters, [''] * len(files))
Example 27
 def testTensorPartitioner(self):
     with tf.Session():
         w1 = tf.get_variable('w1', [255, 255], tf.float32)
         self.evaluate(tf.global_variables_initializer())
         partition_info = distributed_shampoo.PartitionConfig(200, 128)
         grad = tf.constant(w1.eval())
         metadata = distributed_shampoo.TensorPartitioner.partition_metadata(
             w1, partition_info)
         partitioned_grad = distributed_shampoo.TensorPartitioner.partition_tensor(
             w1, partition_info)
         reformed_grad = distributed_shampoo.TensorPartitioner.reform_tensor(
             partitioned_grad, metadata.num_splits_per_dim)
         self.assertAllCloseAccordingToType(reformed_grad, grad)
Example 28
  def testPrefixDataSourceSucceedsWithMultiplePatterns(self):
    ds_params = datasource.PrefixedDataSource.Params().Set(
        file_pattern='filename-*.tfrecord,other/file/pattern/*',
        file_type=None,
        file_pattern_prefix='/dir/')
    ds = ds_params.Instantiate()
    ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

    with tf.Session() as sess:
      ret.data = sess.run([ret.data])

    self.assertAllEqual(
        ret.data, [[b'/dir/filename-*.tfrecord,/dir/other/file/pattern/*']])
Example 29
    def _ComputeFinalMetrics(self,
                             classids=None,
                             difficulty=None,
                             distance=None,
                             num_points=None,
                             rotation=None):
        """Compute precision-recall curves as well as average precision.

    Args:
      classids: A list of N int32.
      difficulty: String in [easy, moderate, hard]. If None is specified, all
        difficulty levels are permitted.
      distance: int32 specifying a binned Euclidean distance of the ground truth
        bounding box. If None is specified, all distances are selected.
      num_points: int32 specifying a binned number of laser points within the
        ground truth bounding box. If None is specified, all boxes are selected.
      rotation: int32 specifying a binned rotation within the ground truth
        bounding box. If None is specified, all boxes are selected.

    Returns:
      (dict, dict):

      - scalar_metrics: A list of C (number of classes) dicts mapping metric
        names to scalar values.
      - curve_metrics: A list of C dicts mapping metric names to np.float32
        arrays of shape [NumberOfPrecisionRecallPoints()+1, 2]. In the last
        dimension, 0 indexes precision and 1 indexes recall.
    """
        assert classids is not None, 'classids must be supplied.'
        feed_dict = {}
        g = tf.Graph()
        scalar_fetches = []
        curve_fetches = []
        with g.as_default():
            for classid in classids:
                data = self._GetData(classid,
                                     difficulty=difficulty,
                                     distance=distance,
                                     num_points=num_points,
                                     rotation=rotation)
                scalars, curves, class_feed_dict = self._BuildMetric(
                    data, classid)
                scalar_fetches += [scalars]
                curve_fetches += [curves]
                feed_dict.update(class_feed_dict)

        with tf.Session(graph=g) as sess:
            results = sess.run([scalar_fetches, curve_fetches],
                               feed_dict=feed_dict)

        return results[0], results[1]
Example 30
    def testUniTransformerParallelFPropEntmax(self):
        length_dim = 4
        graph = tf.Graph()
        params = gshard_builder.UniTransformer.Params().Set(
            gated_gelu=False,
            gated_ffn_activation=tf.nn.relu,
            positional_embedding=False,
            dtype=tf.float32,
            name='transformer',
            parallel_ffn=True,
            hidden_dim_reshape_segments=2,
            conv_kernel_size=2,
            builder=gshard_builder.RecurrentDenseBuilderParallelDecode.Params(
            ).Set(
                device_mesh_shape=[1, 1],
                device_mesh=None,
                relative_attention_num_buckets=32,
                relative_attention_type='bias',
                relative_attention_max_distance=128,
                dtype=tf.float32,
                num_devices=1,  # we call .Split num_devices on axis 0 (batch)
                relative_attention_use_universal_1d_position=True,
                model_dim=32,
                model_dim_reshape_segments=2,
                attention_num_memory_heads=1,
                proj_weight_hdim=2,
                attention_num_heads=8,
                ff_dim=128,
                attention_key_value_dim=8,
                attention_combine_dims=True),
            batch_size=32,
            sequence_length=length_dim,
            num_transformer_layers=2,
            aux_loss_coef=0.0,
            loss_denominator=None,
            label_smoothing=0,
            vocab_size=128,
            max_length=length_dim,
            use_entmax=True)
        with graph.as_default():
            py_utils.GetOrCreateGlobalStepVar()
            params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0)
            tf.random.set_seed(24332)
            model = params.Instantiate()

        with tf.Session(graph=graph) as sess:
            input_batch = self._PreLoadInput()
            loss = model.FPropDefaultTheta(input_batch)[0]['loss'][0]
            sess.run(tf.global_variables_initializer())
            loss_eval = sess.run(loss)
            test_utils.CompareToGoldenSingleFloat(self, 5.146667, loss_eval)