Exemplo n.º 1
0
    def testEagerEMACheckpointCompatibility(self):
        """Checks that V1 and V2 eager checkpoints restore EMA values in eval.

        The model is stepped twice. After step 1 a V1 checkpoint is saved and
        the EMA variable values are snapshotted; after step 2 the same is done
        with a V2 checkpoint. Each checkpoint is then restored under
        `SetEval(True)` and every model variable that has an entry in the
        corresponding snapshot must equal the snapshotted EMA value.
        """
        self.assertTrue(tf.executing_eagerly())
        cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
        # Use non-zero learning rate so that the weights are updated
        for learner_idx in (0, 1):
            cfg.task.train.learner[learner_idx].learning_rate = 0.1

        v1_dir = os.path.join(self.get_temp_dir(), 'eager_v1')
        v2_dir = os.path.join(self.get_temp_dir(), 'eager_v2')
        mdl = cfg.Instantiate()

        @tf.function
        def _TrainStep():
            with py_utils.GradientTape(persistent=True):
                mdl.ConstructFPropBPropGraph()

        def _SnapshotEma(ema_obj):
            """Maps each key of the model/EMA pairing to the EMA value now."""
            pairs = _GetModelEMAVariablePairs(mdl, ema_obj)
            return {ref: ema_var.value() for ref, ema_var in pairs.items()}

        def _AssertModelMatches(snapshot):
            """Asserts each snapshotted variable of `mdl` equals its snapshot."""
            for var in mdl.variables:
                ref = var.ref()
                if ref in snapshot:
                    self.assertAllEqual(var, snapshot[ref])

        # Step 1, followed by a V1 checkpoint at step 1.
        _TrainStep()
        ckpt_v1 = checkpointer.EagerCheckpointerV1(v1_dir, mdl)
        ckpt_v1.Save(gsteps=1)

        ema = mdl.ema
        snapshot_step1 = _SnapshotEma(ema)

        # Step 2, followed by a V2 checkpoint at step 2.
        _TrainStep()
        ckpt_v2 = checkpointer.EagerCheckpointerV2(v2_dir, mdl)
        ckpt_v2.Save(gsteps=2)

        snapshot_step2 = _SnapshotEma(ema)

        # Restore the V1 checkpoint in eval mode and verify that the EMA
        # values from step 1 overwrote the model variables.
        with cluster_factory.SetEval(True):
            ckpt_v1.Restore()
        _AssertModelMatches(snapshot_step1)

        # Same check for the V2 checkpoint taken at step 2.
        with cluster_factory.SetEval(True):
            ckpt_v2.Restore()
        _AssertModelMatches(snapshot_step2)
Exemplo n.º 2
0
 def BuildTpuSubgraph(self):
     """Resets the step seed and compiles the decode loop in eval mode.

     Also records on `self.spmd` whether the task's input params request a
     partitioned infeed queue.
     """
     tf.logging.info('DecodeProgram BuildTpuSubGraph')
     py_utils.ResetStepSeed()
     input_params = self._task_params.input
     self.spmd = input_params.use_partitioned_infeed_queue
     eval_scope = cluster_factory.SetEval(True)
     with eval_scope:
         self._CompileDecodeLoop()
     return None
Exemplo n.º 3
0
 def testLeftContext(self, testonly_skip_norm_layers=False, norm_type='ln'):
   """Runs the stream-step helper in eval mode with a fixed small config.

   Args:
     testonly_skip_norm_layers: value the flag of the same name is
       overridden with for the duration of the test.
     norm_type: 'ln' or 'gn'; forwarded to the helper.
   """
   flag_override = flagsaver.flagsaver(
       testonly_skip_norm_layers=testonly_skip_norm_layers)
   with flag_override:
     with cluster_factory.SetEval(True):
       assert norm_type in ('ln', 'gn')
       self._TestStreamStepHelper(
           num_heads=2, input_dim=2, kernel=3, norm_type=norm_type)
Exemplo n.º 4
0
    def BuildTpuSubgraph(self):
        """Builds the sharded TPU decode computation in eval mode.

        Instantiates the task model with deferred variable creation, sets up
        the input enqueue ops, compiles `_DecodeFn` across
        `self.data_parallelism` shards, and stores the per-shard results on
        `self.metrics`, packed with the structure of `self.metrics_nm`.

        Returns:
          None.
        """
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeFn():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                return self.metrics_nm.Flatten()

            # `_DecodeFn` runs while the sharded computation is being built,
            # so the build must happen inside the SetEval(True) scope.
            # Previously it ran after the scope had exited, which meant the
            # decode graph was constructed outside eval mode.
            self._compile_op, batch_parallel_res = (
                tpu.split_compile_and_shard(
                    _DecodeFn,
                    num_shards=self.data_parallelism,
                    device_assignment=py_utils.GetTpuDeviceAssignment()))

        # Re-pack the flat sharded outputs with the structure captured by
        # `_DecodeFn` in `self.metrics_nm`.
        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Exemplo n.º 5
0
  def BuildTpuSubgraph(self):
    """Builds the sharded TPU eval computation with the cluster in eval mode.

    Returns:
      `self.tpu_ops`: a one-element list whose entry holds, for every sharded
      output, the value taken from the first replica.
    """
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with cluster_factory.SetEval(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      num_shards = self.data_parallelism
      with cluster_factory.SetImmediatelyInstantiateVariables(False):
        self._model = self._InstantiateTaskModel(self._task_params)
      self._task = self._model.GetTask()
      self._task.input.InstantiateVariables()
      self._task.input.CreateTpuEnqueueOps()
      self._init_input_ops = self._task.input.InitOps()

      # XLA treats the bound method self.TpuEvalLoop as taking one argument
      # (self); hand it a zero-argument wrapper instead.
      def _ZeroArgEvalLoop():
        return self.TpuEvalLoop()

      self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
          _ZeroArgEvalLoop,
          num_shards=num_shards,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

      # All replicas produce the same metric values, so keep replica 0 only.
      first_replica = [per_replica[0] for per_replica in sharded_outputs]
      self.tpu_ops = [first_replica]
      return self.tpu_ops
Exemplo n.º 6
0
    def test_stacked_conformer_layer(self, batch_size, seq_len, num_layers,
                                     kernel_size, input_dims, model_dims,
                                     atten_num_heads, dropout_prob):
        """Checks the output shape of a StackedConformer forward pass.

        Args:
          batch_size: number of sequences in the random input batch.
          seq_len: length of each input sequence.
          num_layers: number of conformer layers to stack.
          kernel_size: kernel size set on the per-layer conformer template.
          input_dims: feature dimension of the inputs.
          model_dims: model dimension; also the expected output feature size.
          atten_num_heads: attention head count set on the conformer template.
          dropout_prob: dropout probability set on the conformer template.
        """
        # Honor the `num_layers` test parameter. It was previously ignored in
        # favor of a hard-coded 2, so every parameterization exercised the
        # same stack depth.
        p = conformers.StackedConformer.Params().Set(name='conformer',
                                                     input_dims=input_dims,
                                                     model_dims=model_dims,
                                                     num_layers=num_layers)
        p.conformer_tpl.atten_num_heads = atten_num_heads
        p.conformer_tpl.kernel_size = kernel_size
        p.conformer_tpl.dropout_prob = dropout_prob

        stacked_conformer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_conformer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        # Random 0/1 paddings, one value per (batch, time) position.
        npy_paddings = np.random.randint(
            0, 2, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)

        context_p = base_layer.JaxContext.Params().Set(do_eval=True)

        with cluster_factory.SetEval(True):
            output = test_utils.apply(
                stacked_conformer,
                initial_vars,
                stacked_conformer.fprop,
                inputs,
                paddings,
                context_p=context_p,
            )

        self.assertEqual(output.shape, (batch_size, seq_len, model_dims))
Exemplo n.º 7
0
    def testImageModuleV2(self):
        """Checks ImageModuleV2 output shape and train/eval update-op behavior.

        A fake TF2 image module is exported to a temp dir so nothing is
        downloaded from tf-hub. The test asserts that the module's counter
        advances when features built in train mode are run, and stays put
        when features built under `SetEval(True)` are run.
        """
        # Export a fake image encoder instead of fetching a real one.
        export_dir = self.create_tempdir().full_path

        with self.session() as sess:
            fake_encoder = FakeTF2ImageModule(output_dim=42)
            sess.run(tf.global_variables_initializer())
            tf.saved_model.save(fake_encoder, export_dir)

        layer_p = tf_hub_layers.ImageModuleV2.Params().Set(
            name='foo', module_path=export_dir)
        layer = layer_p.Instantiate()
        images = tf.ones([2, 24, 24, 3])
        features = layer(images)
        self.assertEqual([2, 42], features.shape.as_list())

        with self.session() as sess:
            sess.run(tf.global_variables_initializer())
            # Train mode (lingvo's default): running the features should also
            # run the module's update ops, bumping the counter once per run.
            counter = layer.theta['foo/counter']
            self.assertEqual(0, counter.eval())
            for _ in range(2):
                sess.run(features)
            self.assertEqual(2, counter.eval())

            # Eval mode: the layer should call the module with
            # `training=False`, so the counter must not move.
            with cluster_factory.SetEval(True):
                features = layer(images)
                for _ in range(2):
                    sess.run(features)
            self.assertEqual(2, counter.eval())
Exemplo n.º 8
0
 def testCommon(self,
                testonly_skip_norm_layers=False,
                norm_type='ln',
                num_groups=2,
                stride=1,
                layer_order='conv_before_mhsa',
                has_lconv='depthwise',
                has_fflayer_start=True,
                right_context=0):
   """Runs the stream-step helper in eval mode for one layer configuration.

   Args:
     testonly_skip_norm_layers: value the flag of the same name is
       overridden with while the helper runs.
     norm_type: 'ln' or 'gn'.
     num_groups: forwarded to the helper.
     stride: accepted but not forwarded to the helper.
     layer_order: forwarded to the helper.
     has_lconv: forwarded to the helper.
     has_fflayer_start: forwarded to the helper.
     right_context: forwarded to the helper.
   """
   assert norm_type in ('ln', 'gn'), norm_type
   # NOTE(review): `stride` never reaches the helper, so parameterizations
   # that vary it all run with the helper's own default - confirm whether
   # `stride=stride` belongs in this dict.
   helper_kwargs = {
       'input_dim': 8,
       'kernel': 3,
       'layer_order': layer_order,
       'num_heads': 2,
       'left_context': 3,
       'right_context': right_context,
       'ffn_dim': 4,
       'norm_type': norm_type,
       'has_lconv': has_lconv,
       'has_fflayer_start': has_fflayer_start,
       'num_groups': num_groups,
       'tol': 1e-5,
   }
   with cluster_factory.SetEval(True):
     with flagsaver.flagsaver(
         testonly_skip_norm_layers=testonly_skip_norm_layers):
       self._TestStreamStepHelper(**helper_kwargs)
Exemplo n.º 9
0
 def _DecodeFn():
     """Instantiates the task model in eval mode and decodes one input batch.

     Returns:
       The flattened decode metrics; the structured form is kept on
       `self.metrics_nm`.
     """
     with py_utils.OpportunisticVariableReuseScope(True), \
             cluster_factory.SetEval(True):
         self._model = self._task_params.Instantiate()
         self._model_task = self._model.GetTask()
         batch = self._model_task.GetInputBatch()
         decoded = self._model_task.Decode(batch)
         self.metrics_nm = py_utils.NestedMap(decoded)
         return self.metrics_nm.Flatten()
Exemplo n.º 10
0
 def _DecodeFn():
     """Decode call to be compiled for TPU.

     Returns:
       The flattened decode metrics; the structured form is kept on
       `self.metrics_nm`.
     """
     with py_utils.OpportunisticVariableReuseScope(True), \
             cluster_factory.SetEval(True):
         self._decode_model.InstantiateVariables()
         dequeued = self._decode_task.input.TpuDequeueBatch()
         decoded = self._decode_task.Decode(dequeued)
     self.metrics_nm = py_utils.NestedMap(decoded)
     return self.metrics_nm.Flatten()
Exemplo n.º 11
0
    def testStreamStep(self, testonly_skip_norm_layers=False, norm_type='ln'):
        """Checks that per-frame StreamStep matches full-sequence FProp.

        A causal LConvLayer is run once over the whole random input with
        FProp and once frame by frame with StreamStep, threading the state
        through. Both outputs are masked by the paddings and compared.

        Args:
          testonly_skip_norm_layers: value the flag of the same name is
            overridden with during the test.
          norm_type: 'ln' for LayerNorm, 'gn' for cumulative GroupNorm.
        """
        flag_override = flagsaver.flagsaver(
            testonly_skip_norm_layers=testonly_skip_norm_layers)
        with flag_override, cluster_factory.SetEval(True):
            assert norm_type in ('ln', 'gn')
            batch, max_seqlen, input_dim, kernel = 2, 8, 2, 3
            p = conformer_layer.LConvLayer.CommonParams(
                input_dim=input_dim, is_causal=True, kernel_size=kernel)
            p.conv_norm_layer_tpl = (
                lingvo_layers.LayerNorm.Params() if norm_type == 'ln' else
                bn_layers.GroupNormLayer.Params().Set(num_groups=2,
                                                      cumulative=True))
            p.name = 'lconv'

            layer = p.Instantiate()
            init_op = tf.global_variables_initializer()

            # Unseeded on purpose; the drawn values are printed below.
            np.random.seed(None)
            input_shape = [batch, max_seqlen, input_dim]
            inputs = np.random.normal(0.1, 0.5, input_shape).astype(np.float32)
            print(f'np.sum(inputs): {np.sum(inputs)}')
            inputs = tf.convert_to_tensor(inputs)

            seqlen = np.random.randint(low=1,
                                       high=max_seqlen + 1,
                                       size=(batch, ),
                                       dtype=np.int32)
            print(repr(seqlen))
            seqlen = tf.convert_to_tensor(seqlen)
            paddings = py_utils.PaddingsFromLengths(seqlen, max_seqlen)

            # Reference: the whole sequence in one FProp call.
            base_outputs, _ = layer.FProp(layer.theta, inputs, paddings)
            base_outputs *= tf.expand_dims(1. - paddings, -1)

            # Streaming: one frame at a time, carrying the state forward.
            per_step_outputs = []
            state = layer.zero_state(batch)
            for t in range(max_seqlen):
                frame = inputs[:, t:(t + 1), :]
                frame_paddings = paddings[:, t:(t + 1)]
                step_out, _, state = layer.StreamStep(layer.theta, frame,
                                                      frame_paddings, state)
                per_step_outputs.append(step_out)
            # [b, t, d]
            outputs = tf.concat(per_step_outputs, axis=1)
            outputs *= tf.expand_dims(1. - paddings, -1)

            with self.session(use_gpu=False) as sess:
                sess.run(init_op)
                expected, actual = sess.run([base_outputs, outputs])
                print(repr(expected))
                print(repr(actual))
                print(f'np.sum(np.abs(expected)): {np.sum(np.abs(expected))}')
                print(f'np.sum(np.abs(actual)): {np.sum(np.abs(actual))}')
                self.assertAllClose(expected, actual)
    def testPreprocessor(self, use_eval_mode):
        """Checks the preprocessor's eval flag and output shape.

        Args:
          use_eval_mode: value passed to `SetEval`; the preprocessor's
            `do_eval` is asserted to equal it inside that scope.
        """
        params = image_preprocessor.ImagePreprocessor.Params()

        jpeg_sizes = [(24, 42), (17, 19)]
        encoded_images = _EncodeRandomJpegs(sizes=jpeg_sizes)
        preprocessor = params.Instantiate()
        with cluster_factory.SetEval(use_eval_mode):
            images = preprocessor(encoded_images)
            self.assertEqual(use_eval_mode, preprocessor.do_eval)
            # Expected: input batch shape, then the configured image size,
            # then 3 channels.
            expected_shape = (encoded_images.shape +
                              params.output_image_size + [3])
            self.assertAllEqual(expected_shape, images.shape)
Exemplo n.º 13
0
    def BuildTpuSubgraph(self):
        """Builds the sharded on-device eval loop with the cluster in eval mode.

        Returns:
          `self.tpu_ops`: a one-element list whose entry holds, for every
          sharded output, the value taken from the first replica.
        """
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with cluster_factory.SetEval(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            num_shards = self.data_parallelism
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._InstantiateTaskModel(self._task_params)
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Summed eval metrics.
                """
                with tf.name_scope('tpu_eval'):
                    with py_utils.OpportunisticVariableReuseScope(True):
                        self._model.InstantiateVariables()
                        self._model.ConstructFPropGraph()
                    step_metrics = self._eval_metrics.SetMetrics(
                        self._task.eval_metrics, args)
                    # Accumulate this step's metrics onto the running sums.
                    return [
                        current + previous
                        for current, previous in zip(step_metrics, args)
                    ]

            @tpu_function.on_device_training_loop
            def TpuEval():
                accumulated = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(accumulated)

            self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
                TpuEval,
                num_shards=num_shards,
                device_assignment=py_utils.GetTpuDeviceAssignment())

            self._task.input.CreateTpuEmbeddingEnqueueOps(
                mode_override='inference')

            # All replicas produce the same metric values; keep replica 0.
            first_replica = [per_replica[0] for per_replica in sharded_outputs]
            self.tpu_ops = [first_replica]

            return self.tpu_ops
Exemplo n.º 14
0
 def _DecodeFn():
     """Decode call to be compiled for TPU.

     Returns:
       The flattened decode metrics; the structured form is kept on
       `self.metrics_nm`.
     """
     with py_utils.OpportunisticVariableReuseScope(True), \
             cluster_factory.SetEval(True):
         self._model = self._task_params.Instantiate()
         self._model_task = self._model.GetTask()
         # Attach the already-built input generator as the task's 'input'.
         self._model_task.AddChild('input', self._input)
         dequeued = self._model_task.input_generator.TpuDequeueBatch()
         decoded = self._model_task.Decode(dequeued)
         self.metrics_nm = py_utils.NestedMap(decoded)
         return self.metrics_nm.Flatten()
Exemplo n.º 15
0
    def testDoEval(self):
        """Checks how a dropout layer reacts to each `SetEval` mode.

        The layer is applied to the same input under mode None, False and
        True. The test asserts the output differs from the input for None
        and False, and equals the input for True.
        """
        dropout_p = builder.Base.Params().Instantiate()._Dropout(
            'dropout', 0.5)

        graph = tf.Graph()
        with graph.as_default():
            layer = dropout_p.Instantiate()
            x = tf.random.normal(shape=[16, 16])
            # FProp three times each with different do_eval mode.
            per_mode_outputs = []
            for eval_mode in (None, False, True):
                with cluster_factory.SetEval(mode=eval_mode):
                    per_mode_outputs.append(layer.FPropDefaultTheta(x))

        with self.session(graph=graph) as sess:
            x_val, out_none, out_false, out_true = sess.run(
                [x] + per_mode_outputs)

        self.assertGreater(np.linalg.norm(x_val - out_none), 0)
        self.assertGreater(np.linalg.norm(x_val - out_false), 0)
        self.assertAllEqual(x_val, out_true)
Exemplo n.º 16
0
 def testStackingLayerWithRightContext(self):
     """Runs the right-context stacking helper in eval mode with a fixed seed."""
     tf.random.set_seed(2021)
     helper_kwargs = {
         'input_dim': 8,
         'kernel': 3,
         'num_heads': 2,
         'left_context': 6,
         'right_context': 3,
         'ffn_dim': 4,
         'stride': 2,
         'layer_order': 'mhsa_before_conv',
         'num_layers': 3,
     }
     eval_scope = cluster_factory.SetEval(True)
     with eval_scope:
         self._TestRightContextStackingLayersHelper(**helper_kwargs)
Exemplo n.º 17
0
 def _DecodeFn():
   """Decode call to be compiled for TPU.

   Returns:
     The flattened decode metrics; the structured form is kept on
     `self.metrics_nm`.
   """
   with py_utils.OpportunisticVariableReuseScope(True), \
       cluster_factory.SetEval(True):
     self._model = self._task_params.Instantiate()
     self._model_task = self._model.GetTask()
     # On TPU the batch comes from TPU feeds; otherwise the input batch is
     # split across the client's splits.
     if not py_utils.use_tpu():
       input_batch = self._model_task.input_generator.SplitInputBatch(
           self.cluster.num_splits_per_client)
     else:
       input_batch = self._model_task.input_generator.CreateTpuFeeds()
     decoded = self._model_task.Decode(input_batch)
     self.metrics_nm = py_utils.NestedMap(decoded)
     return self.metrics_nm.Flatten()
Exemplo n.º 18
0
            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Per-step eval metrics.
                """
                eval_scope = cluster_factory.SetEval(True)
                with eval_scope:
                    self._model = self._task_params.Instantiate()
                    self._model.ConstructFPropGraph()
                    task_metrics = self._model.GetTask().eval_metrics
                    return self._eval_metrics.SetMetrics(task_metrics, args)
Exemplo n.º 19
0
            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Summed eval metrics.
                """
                eval_scope = cluster_factory.SetEval(True)
                with eval_scope:
                    self._model = self._task_params.Instantiate()
                    self._task = self._model.GetTask()
                    # Attach the already-built input generator as 'input'.
                    self._task.AddChild('input', self._input)

                    self._model.ConstructFPropGraph()
                    step_metrics = self._eval_metrics.SetMetrics(
                        self._task.eval_metrics, args)
                    # Accumulate this step's metrics onto the running sums.
                    return [
                        current + previous
                        for current, previous in zip(step_metrics, args)
                    ]
Exemplo n.º 20
0
    def BuildTpuSubgraph(self):
        """Builds the sharded on-device decode loop in eval mode.

        Instantiates the task model with deferred variable creation, sets up
        the input enqueue ops, compiles a loop that repeats `_DecodeStep`
        `self._steps_per_loop` times across `self.data_parallelism` shards,
        and stores the dequeued outfeed tensors on `self.metrics`, packed
        with the structure of `self.metrics_nm`.
        """
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()
        device_assignment = py_utils.GetTpuDeviceAssignment()
        self.spmd = self._task_params.input.use_partitioned_infeed_queue
        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeStep():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                # With a partitioned infeed queue the outfeed is pinned to
                # core 0; otherwise no explicit device is requested.
                device = tpu.core(0) if self.spmd else ''
                with tf.device(device):
                    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                        self.metrics_nm.Flatten())
                    return [outfeed_enqueue]

            @tpu_function.on_device_training_loop
            def DecodeLoopFn():
                return tpu_training_loop.repeat(self._steps_per_loop,
                                                _DecodeStep,
                                                inputs=[])

            # `_DecodeStep` runs while the sharded loop is being built, so
            # the build must happen inside the SetEval(True) scope.
            # Previously it ran after the scope had exited, which meant the
            # decode graph was constructed outside eval mode.
            self._compile_op, self.decode_loop = (
                tpu.split_compile_and_shard(
                    DecodeLoopFn,
                    num_shards=self.data_parallelism,
                    device_assignment=device_assignment))

        # Get a list of outfeed ops.
        self.metrics = self._OutfeedDequeue()
        # Pack the list of outfeed ops with structure in self.metrics_nm.
        self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
        return
Exemplo n.º 21
0
    def testImageModule(self):
        """Checks ImageModule output shapes and train/eval update-op behavior.

        A fake TF1 image module is exported to a temp dir. The test asserts
        that running the train-mode features advances the module's count
        variable and that running the eval-mode features does not.
        """
        batch_size, image_size, feature_dim = 2, 42, 64

        export_dir = self.create_tempdir().full_path

        ExportFakeTF1ImageModule(input_image_height=image_size,
                                 input_image_width=image_size,
                                 output_feature_dim=feature_dim,
                                 export_path=export_dir)

        layer_p = tf_hub_layers.ImageModule.Params().Set(
            name='image_module', module_path=export_dir)
        layer = layer_p.Instantiate()
        images = tf.zeros([batch_size, image_size, image_size, 3],
                          dtype=tf.float32)
        training_mode_features = layer(images)
        with cluster_factory.SetEval(True):
            eval_mode_features = layer(images)

        expected_shape = [batch_size, feature_dim]
        self.assertAllEqual(expected_shape, training_mode_features.shape)
        self.assertAllEqual(expected_shape, eval_mode_features.shape)

        # Check that update ops are run when the layer output is used during
        # training.
        count_vars = [var for var in layer.variables if var.name == 'count:0']
        count_value = count_vars[0].read_value()

        with self.session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            self.assertEqual(0, sess.run(count_value))
            sess.run(training_mode_features)
            self.assertEqual(1, sess.run(count_value))

            # Eval-mode features must leave the count untouched.
            sess.run(eval_mode_features)
            self.assertEqual(1, sess.run(count_value))
Exemplo n.º 22
0
    def testLayerStackSummary(self):
        """Compares outputs and summaries of two encoder/decoder stack builds.

        Two graphs are built from the same builder: one with
        `use_repeat_layer=True` and one without. The test asserts the
        expected TPU summary names for each, copies the first graph's
        trainable variable values into the second, and then asserts that the
        decoder outputs and the paired summary values are close.
        """
        # In this test we verify that summaries created inside stack layers
        # are processed properly with and without RepeatedLayer
        model_dim = 4
        num_heads = 2
        d_kv = 2
        d_ff = 8
        num_experts = 2
        builder = gshard_builder.DenseBuilder.Params().Set(
            deterministic_dropout=True,
            dtype=tf.float32,
            relative_attention_type='bias',
            model_dim=model_dim,
            attention_num_heads=num_heads,
            attention_combine_dims=True,
            attention_num_memory_heads=1,
            model_dim_reshape_segments=None,
            ff_dim=d_ff,
            moe_hidden_dim=d_ff,
            e_dim=num_experts,
            c_dim=1,
            num_groups=num_experts,
            num_devices=num_experts,
            attention_key_value_dim=d_kv).Instantiate()

        def _GetOutputs(enc, dec):
            """Runs `enc` then `dec` on the test inputs; returns `dec`'s vec."""
            x, seg_id, pos_id = self._GetInputs()
            enc_inputs = py_utils.NestedMap(vec=x,
                                            segment_id=seg_id,
                                            segment_pos=pos_id,
                                            aux_loss=tf.constant(0.0))
            enc_outs = enc.FPropDefaultTheta(enc_inputs)
            # The decoder consumes the encoder's output vec and aux_loss; the
            # encoder segment ids/positions it sees are all zeros.
            dec_inputs = py_utils.NestedMap(
                vec=x,
                segment_id=seg_id,
                segment_pos=pos_id,
                encoder_output=enc_outs.vec,
                encoder_segment_id=tf.zeros_like(seg_id),
                encoder_segment_pos=tf.zeros_like(pos_id),
                aux_loss=enc_outs.aux_loss)
            return dec.FPropDefaultTheta(dec_inputs).vec

        # Build a graph with RepeatLayer unrolled.
        # NOTE(review): only this first graph is built under SetEval(mode=True);
        # the second graph below is not - confirm that is intentional.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context(), cluster_factory.SetEval(
                mode=True):
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack(
                'encoder',
                sub_layers=[builder.DenseReluDense('ffw')],
                num=2,
                use_repeat_layer=True).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2,
                use_repeat_layer=True).Instantiate()
            rep_unroll_out = _GetOutputs(enc, dec)
            rep_unroll_summary = tpu_summary.merge_all()

        expected_rep_unroll_summary = [
            'index_1/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating',
            'over_capacity_1_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating'
        ]
        self.assertCountEqual(expected_rep_unroll_summary, rep_unroll_summary)

        # Evaluate the first graph and keep its trainable variable values so
        # they can be copied into the second graph.
        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            rep_unroll_out, rep_unroll_summary = sess.run(
                [rep_unroll_out, rep_unroll_summary])
            var_values = sess.run(tf.trainable_variables())
        # Build a graph without RepeatLayer.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context():
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack('encoder',
                                            sub_layers=[
                                                builder.DenseReluDense('ffw')
                                            ],
                                            num=2).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2).Instantiate()
            dec_out = _GetOutputs(enc, dec)
            dec_summary = tpu_summary.merge_all()

        expected_dec_summary = [
            'index_1/decoder_1/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_1/layer_001/moe/ffw/compute_gating',
            'over_capacity_1_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_1/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_1/layer_001/moe/ffw/compute_gating'
        ]
        self.assertCountEqual(expected_dec_summary, dec_summary)

        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            # Variables of the second graph, listed in the order their values
            # are paired with `var_values` below.
            # NOTE(review): assumes tf.trainable_variables() in the first
            # graph yields values in exactly this order - TODO confirm.
            tf_vars = [
                enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
                enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
                enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
                enc.vars.final_layer_norm.w.scale,
                dec.vars.layer_000.ln.w.scale, dec.vars.layer_000.moe.moe.wi,
                dec.vars.layer_000.moe.moe.wo,
                dec.vars.layer_000.moe.ffw.top_2_gating.w,
                dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.moe.moe.wi,
                dec.vars.layer_001.moe.moe.wo,
                dec.vars.layer_001.moe.ffw.top_2_gating.w,
                dec.vars.final_layer_norm.w.scale
            ]
            # Every listed variable is assigned before the second graph is
            # run, so both graphs are evaluated with the same values.
            for val, var in zip(var_values, tf_vars):
                sess.run(tf.assign(var, val))
            dec_out, dec_summary = sess.run([dec_out, dec_summary])
            self.assertAllClose(dec_out, rep_unroll_out)

            # The two expected-name lists are parallel: entry i of one names
            # the counterpart of entry i of the other.
            for name, alt_name in zip(expected_dec_summary,
                                      expected_rep_unroll_summary):
                self.assertAllClose(dec_summary[name],
                                    rep_unroll_summary[alt_name])
Exemplo n.º 23
0
 def SetEval(self, mode):
     """Returns the result of `cluster_factory.SetEval` for `mode`."""
     eval_scope = cluster_factory.SetEval(mode=mode)
     return eval_scope
Exemplo n.º 24
0
    def test_conformer_layer(self, batch_size, seq_len, kernel_size,
                             input_dims, model_dims, atten_num_heads,
                             dropout_prob):
        """Checks that the JAX Conformer matches the TF ConformerLayer.

        Both layers are run on the same random inputs and paddings, with the
        TF layer using the JAX layer's variables converted to the TF layout,
        and the outputs are compared with atol=1e-5.

        Args:
          batch_size: number of sequences in the random input batch.
          seq_len: length of each input sequence.
          kernel_size: convolution kernel size for both layers.
          input_dims: feature dimension of the inputs.
          model_dims: model dimension of the layers.
          atten_num_heads: attention head count for both layers.
          dropout_prob: dropout probability for both layers.
        """
        # Lingvo TF layers only use dropout on FF and Attention layers
        jax_p = conformers.Conformer.Params().Set(
            name='jax_conformer_layer',
            input_dims=input_dims,
            conv_residual_dropout=0.0,
            atten_residual_dropout=dropout_prob,
            ffn_residual_dropout=dropout_prob,
            atten_dropout=dropout_prob,
            ffn_relu_dropout=dropout_prob,
            kernel_size=kernel_size,
            model_dims=model_dims,
            atten_num_heads=atten_num_heads)
        conformer = jax_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = conformer.instantiate_variables(prng_key)
        input_shape = [batch_size, seq_len, input_dims]
        npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32')
        inputs = jnp.asarray(npy_inputs)

        # Draw a length per sequence; positions at or past it are padding.
        length = np.random.randint(seq_len // 2, seq_len, (batch_size, ))
        positions = np.tile(np.arange(seq_len), [batch_size, 1])
        is_padding = positions >= np.expand_dims(length, -1)
        npy_paddings = is_padding.astype('float32')
        paddings = jnp.asarray(npy_paddings)

        context_p = base_layer.JaxContext.Params().Set(do_eval=True)

        output = test_utils.apply(conformer,
                                  initial_vars,
                                  conformer.fprop,
                                  inputs,
                                  paddings,
                                  context_p=context_p)
        # Test whether tf Conformer layer returns the same output
        # Modify initial_vars to use TF compatible params
        tf_initial_vars = test_utils.replace_jax_conformer_layer_vars_to_tf(
            initial_vars)

        tf_p = conformer_layer.ConformerLayer.CommonParams(
            input_dim=input_dims,
            dropout_prob=dropout_prob,
            atten_num_heads=atten_num_heads,
            kernel_size=kernel_size,
            fflayer_hidden_dim=model_dims * jax_p.ffn_dim_multiplier,
            use_relative_atten=False,
            fflayer_residual_weight=0.5).Set(name='tf_conformer')
        tf_p.trans_atten_tpl = tf_p.trans_atten_tpl.Set(hidden_dim=model_dims)

        tf_conformer = tf_p.Instantiate()
        with cluster_factory.SetEval(True):
            tf_batch = py_utils.NestedMap(
                features=tf.constant(inputs, dtype=tf.float32),
                paddings=tf.constant(npy_paddings, dtype=tf.float32))
            tf_output = tf_conformer.FProp(tf_initial_vars, tf_batch)
        np_output = to_np(output)
        tf_np_output = to_np(tf_output.features)
        self.assertAllClose(tf_np_output, np_output, atol=1e-5)
Exemplo n.º 25
0
def main(unused_argv):
    """Builds a `FeatureNeighborhoodModelDecoder` and decodes in eval mode.

    Args:
      unused_argv: Ignored.
    """
    model_decoder = FeatureNeighborhoodModelDecoder()
    # The decoder is constructed first; only `decode()` runs inside the
    # eval-mode cluster scope.
    eval_scope = cluster_factory.SetEval(mode=True)
    with eval_scope:
        model_decoder.decode()
Exemplo n.º 26
0
 def BuildTpuSubgraph(self):
     """Resets the step seed and calls `_CompileDecodeFn` in eval mode.

     Returns:
       None.
     """
     tf.logging.info('DecodeProgram BuildTpuSubGraph')
     py_utils.ResetStepSeed()
     eval_scope = cluster_factory.SetEval(True)
     with eval_scope:
         self._CompileDecodeFn()
Exemplo n.º 27
0
    def BuildTpuSubgraph(self):
        """Builds a combined train-then-decode computation and shards it.

        Instantiates the train model and the decode model (the latter inside
        `cluster_factory.SetEval(True)`), creates TPU enqueue ops for both
        inputs, and passes `TrainAndDecode` to
        `tpu.split_compile_and_shard`.

        Attributes assigned on `self`: `_eval_metrics`, `_train_model`,
        `_train_task`, `_model`, `_decode_model`, `_decode_task`,
        `_compile_op`, `metrics_nm` (from inside `_DecodeFn`) and `metrics`.

        Returns:
          None.
        """
        # Optional MLPerf logging of the run's hyperparameters.
        if self._ml_perf_log:
            mlp_log.mlperf_print('global_batch_size',
                                 self._ml_perf.global_batch_size)
            mlp_log.mlperf_print('max_sequence_length',
                                 self._ml_perf.max_sequence_length)
            mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
            mlp_log.mlperf_print('opt_base_learning_rate',
                                 self._ml_perf.base_learning_rate)
            mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                                 self._ml_perf.warmup_steps)

        self._eval_metrics = metrics.TpuEvalMetrics()
        data_parallelism = self.data_parallelism
        # Variable creation is deferred here; `InstantiateVariables()` is
        # called explicitly inside `TpuTrainStep` below.
        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._train_model = self._train_task_params.Instantiate()
        self._train_task = self._train_model.GetTask()
        self._train_task.input.InstantiateVariables()
        self._train_task.input.CreateTpuEnqueueOps()
        self._model = self._train_model

        def TpuTrainStep():
            """Train a shard of a batch on a single TPU core.

            Do not calculate loss metrics.

            Returns:
              [train_op].
            """
            with py_utils.OpportunisticVariableReuseScope(True):
                self._train_model.InstantiateVariables()
                self._train_model.ConstructFPropBPropGraph()
            return [self._train_task.train_op]

        def TpuTrain():
            """Returns the result of repeating `TpuTrainStep` in a TPU loop."""
            loop_result = tpu_training_loop.repeat(self._train_steps_per_loop,
                                                   TpuTrainStep,
                                                   inputs=[],
                                                   name='train_loop')
            return loop_result

        py_utils.ResetStepSeed()

        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._decode_model = self._decode_task_params.Instantiate()
            self._decode_task = self._decode_model.GetTask()
            self._decode_task.input.InstantiateVariables()
            self._decode_task.input.CreateTpuEnqueueOps()

            # NOTE(review): `_DecodeFn` is only *defined* inside this
            # `SetEval(True)` scope. It is invoked through `TrainAndDecode`
            # by the `split_compile_and_shard` call below, after this `with`
            # block has exited -- confirm the decode graph is built in the
            # intended mode.
            def _DecodeFn():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._decode_model.InstantiateVariables()
                    input_batch = self._decode_task.input.TpuDequeueBatch()
                    metrics_dict = self._decode_task.Decode(input_batch)
                # Kept on `self` so the nested structure can be re-packed
                # from the flat results after sharding.
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                return self.metrics_nm.Flatten()

        @tpu_function.on_device_training_loop
        def TrainAndDecode():
            """Runs `_DecodeFn` with a control dependency on `TpuTrain()`."""
            with tf.control_dependencies([TpuTrain()]):
                return _DecodeFn()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TrainAndDecode,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        # `self.metrics_nm` was populated as a side effect of `_DecodeFn`;
        # reuse its structure to pack the flat per-shard results.
        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Exemplo n.º 28
0
    def testStreamStep(self,
                       testonly_skip_norm_layers=False,
                       norm_type='ln',
                       num_groups=2,
                       stride=1,
                       layer_order='conv_before_mhsa',
                       has_lconv='depthwise',
                       has_fflayer_start=True,
                       right_context=0):
        """Compares step-wise `StreamStep` outputs against a full `FProp`.

        A causal `ConformerLayer` is configured from the arguments, run once
        over the whole random input with `FProp`, and then run `stride`
        frames at a time with `StreamStep`. Both results are multiplied by
        `1 - paddings` before being compared with `assertAllClose`.

        Args:
          testonly_skip_norm_layers: Value given to the flag of the same name
            for the duration of the test; also selects the tolerance.
          norm_type: 'ln' or 'gn'; picks the lconv norm layer template.
          num_groups: `num_groups` for the group norm template ('gn' only).
          stride: Number of frames fed to each `StreamStep` call; asserted
            to divide the sequence length.
          layer_order: Passed to `CommonParams`; 'mhsa' sets the kernel size
            to None.
          has_lconv: Falsy removes `lconv_tpl`; 'conv2d' replaces the
            depthwise conv template with `CausalConv2DLayerWithPadding`;
            anything else is asserted to be 'depthwise'.
          has_fflayer_start: If falsy, `fflayer_start_tpl` is set to None.
          right_context: Passed as `atten_right_context`; also the number of
            leading streamed frames dropped before the comparison.
        """
        assert norm_type in ('ln', 'gn'), norm_type
        flag_scope = flagsaver.flagsaver(
            testonly_skip_norm_layers=testonly_skip_norm_layers)
        with flag_scope, cluster_factory.SetEval(True):
            batch = 2
            max_seqlen = 16
            input_dim = 8
            assert max_seqlen % stride == 0
            # A pure-attention layer order has no convolution kernel.
            kernel = None if layer_order == 'mhsa' else 3

            num_heads = 2
            left_context = 3
            ffn_dim = 4
            layer_p = conformer_layer.ConformerLayer.CommonParams(
                input_dim=input_dim,
                is_causal=True,
                atten_num_heads=num_heads,
                atten_left_context=left_context,
                atten_right_context=right_context,
                use_relative_atten=False,
                fflayer_hidden_dim=ffn_dim,
                kernel_size=kernel,
                layer_order=layer_order)

            if norm_type == 'ln':
                norm_tpl = lingvo_layers.LayerNorm.Params()
            else:
                norm_tpl = bn_layers.GroupNormLayer.Params().Set(
                    num_groups=num_groups, cumulative=True)
            layer_p.lconv_tpl.conv_norm_layer_tpl = norm_tpl

            if not has_lconv:
                layer_p.lconv_tpl = None
            elif has_lconv == 'conv2d':
                causal_conv2d = (
                    conv_layers_with_time_padding.CausalConv2DLayerWithPadding)
                layer_p.lconv_tpl.depthwise_conv_tpl = causal_conv2d.Params()
            else:
                assert has_lconv == 'depthwise'

            if not has_fflayer_start:
                layer_p.fflayer_start_tpl = None
            layer_p.name = 'conformer'

            layer = layer_p.Instantiate()
            init_op = tf.global_variables_initializer()

            # Fresh entropy each run; the printed sums help reproduce failures.
            np.random.seed(None)
            np_inputs = 5 * np.random.normal(
                0.1, 0.5, [batch, max_seqlen, input_dim]).astype(np.float32)
            print(f'np.sum(inputs): {np.sum(np_inputs)}')
            inputs = tf.convert_to_tensor(np_inputs)

            np_seqlen = np.random.randint(low=1,
                                          high=max_seqlen + 1,
                                          size=(batch, ),
                                          dtype=np.int32)
            print(f'seqlen: {np_seqlen}')
            paddings = py_utils.PaddingsFromLengths(
                tf.convert_to_tensor(np_seqlen), max_seqlen)

            # Reference: one pass over the full sequence.
            full_pass = layer.FProp(
                layer.theta,
                py_utils.NestedMap(features=inputs, paddings=paddings))
            base_outputs = (full_pass.features *
                            tf.expand_dims(1. - paddings, -1))

            # Streaming: feed `stride` frames per step, then extra fully
            # padded steps to cover `right_context`.
            stream_chunks = []
            state = layer.zero_state(batch)
            num_input_steps = max_seqlen // stride
            num_flush_steps = int(math.ceil(right_context / stride))
            for step in range(num_input_steps + num_flush_steps):
                if step >= num_input_steps:
                    step_inputs = tf.zeros_like(inputs[:, 0:stride])
                    step_paddings = tf.ones_like(paddings[:, 0:stride])
                else:
                    step_inputs = inputs[:, stride * step:stride * (step + 1)]
                    step_paddings = paddings[:,
                                             stride * step:stride * (step + 1)]
                step_out, _, state = layer.StreamStep(layer.theta,
                                                      step_inputs,
                                                      step_paddings, state)
                stream_chunks.append(step_out)

            streamed = tf.concat(stream_chunks, axis=1)
            # Drop the first `right_context` frames, then trim to the
            # original sequence length.
            streamed = streamed[:, right_context:]
            streamed = streamed[:, :max_seqlen]
            streamed = streamed * tf.reshape(1. - paddings,
                                             [batch, max_seqlen, 1])

            with self.session(use_gpu=False) as sess:
                sess.run(init_op)
                expected, actual = sess.run([base_outputs, streamed])
                print(repr(expected))
                print(repr(actual))
                print(f'np.sum(np.abs(expected)): {np.sum(np.abs(expected))}')
                print(f'np.sum(np.abs(actual)): {np.sum(np.abs(actual))}')
                if testonly_skip_norm_layers:
                    tol = 3.e-6
                else:
                    tol = 2.e-5
                self.assertAllClose(expected, actual, atol=tol, rtol=tol)
Exemplo n.º 29
0
  def testStreamStep(self,
                     testonly_skip_norm_layers=False,
                     norm_type='ln',
                     num_groups=2,
                     stride=1,
                     layer_order='conv_before_mhsa',
                     has_lconv=True,
                     has_fflayer_start=True):
    """Compares step-wise `StreamStep` outputs against a full `FProp`.

    A causal `ConformerLayer` (right context 0) is configured from the
    arguments, run once over the whole random input with `FProp`, and then
    run `stride` frames at a time with `StreamStep`. Both results are
    multiplied by `1 - paddings` and compared with `assertAllClose`.

    Args:
      testonly_skip_norm_layers: Value given to the flag of the same name for
        the duration of the test; also selects the tolerance.
      norm_type: 'ln' or 'gn'; picks the lconv norm layer template.
      num_groups: `num_groups` for the group norm template ('gn' only).
      stride: Number of frames fed to each `StreamStep` call.
      layer_order: Passed to `CommonParams`; 'mhsa' sets the kernel size to
        None.
      has_lconv: If falsy, `lconv_tpl` is set to None.
      has_fflayer_start: If falsy, `fflayer_start_tpl` is set to None.
    """
    assert norm_type in ('ln', 'gn'), norm_type
    flag_scope = flagsaver.flagsaver(
        testonly_skip_norm_layers=testonly_skip_norm_layers)
    with flag_scope, cluster_factory.SetEval(True):
      batch = 2
      max_seqlen = 16
      input_dim = 8
      # A pure-attention layer order has no convolution kernel.
      kernel = None if layer_order == 'mhsa' else 3
      num_heads = 2
      left_context = 3
      ffn_dim = 4
      layer_p = conformer_layer.ConformerLayer.CommonParams(
          input_dim=input_dim,
          is_causal=True,
          atten_num_heads=num_heads,
          atten_left_context=left_context,
          atten_right_context=0,
          use_relative_atten=False,
          fflayer_hidden_dim=ffn_dim,
          kernel_size=kernel,
          layer_order=layer_order)

      if norm_type == 'ln':
        norm_tpl = layers.LayerNorm.Params()
      else:
        norm_tpl = bn_layers.GroupNormLayer.Params().Set(
            num_groups=num_groups, cumulative=True)
      layer_p.lconv_tpl.conv_norm_layer_tpl = norm_tpl

      if not has_lconv:
        layer_p.lconv_tpl = None
      if not has_fflayer_start:
        layer_p.fflayer_start_tpl = None
      layer_p.name = 'conformer'

      layer = layer_p.Instantiate()
      init_op = tf.global_variables_initializer()

      # Fresh entropy each run; the printed values help reproduce failures.
      np.random.seed(None)
      np_inputs = 5 * np.random.normal(
          0.1, 0.5, [batch, max_seqlen, input_dim]).astype(np.float32)
      print(f'np.sum(inputs): {np.sum(np_inputs)}')
      inputs = tf.convert_to_tensor(np_inputs)

      np_seqlen = np.random.randint(
          low=1, high=max_seqlen + 1, size=(batch,), dtype=np.int32)
      print(repr(np_seqlen))
      paddings = py_utils.PaddingsFromLengths(
          tf.convert_to_tensor(np_seqlen), max_seqlen)

      # Reference: one pass over the full sequence.
      full_pass = layer.FProp(
          layer.theta, py_utils.NestedMap(features=inputs, paddings=paddings))
      base_outputs = full_pass.features * tf.expand_dims(1. - paddings, -1)

      # Streaming: feed `stride` frames per step, carrying state forward.
      stream_chunks = []
      state = layer.zero_state(batch)
      for start in range(0, max_seqlen, stride):
        stop = start + stride
        step_out, _, state = layer.StreamStep(layer.theta,
                                              inputs[:, start:stop, :],
                                              paddings[:, start:stop], state)
        stream_chunks.append(step_out)
      # [b, t, d]
      streamed = tf.concat(stream_chunks, axis=1)
      streamed = streamed * tf.expand_dims(1. - paddings, -1)

      with self.session(use_gpu=False) as sess:
        sess.run(init_op)
        expected, actual = sess.run([base_outputs, streamed])
        print(repr(expected))
        print(repr(actual))
        print(f'np.sum(np.abs(expected)): {np.sum(np.abs(expected))}')
        print(f'np.sum(np.abs(actual)): {np.sum(np.abs(actual))}')
        if testonly_skip_norm_layers:
          tol = 2.e-6
        else:
          tol = 2.e-5
        self.assertAllClose(expected, actual, atol=tol, rtol=tol)