Code Example #1
  def testLayerStack(self):
    model_dim = 4
    num_heads = 2
    d_kv = 2
    d_ff = 8
    builder = gshard_builder.DenseBuilder.Params().Set(
        dtype=tf.float32,
        relative_attention_type='bias',
        model_dim=model_dim,
        attention_num_heads=num_heads,
        attention_combine_dims=True,
        attention_num_memory_heads=1,
        model_dim_reshape_segments=2,
        ff_dim=d_ff,
        attention_key_value_dim=d_kv).Instantiate()

    def _GetInputs():
      x = tf.constant([[[.1, .2, .3, .4], [.3, .4, .5, .6], [.5, .6, .1, .2]],
                       [[.7, .8, .4, .5], [.9, .1, .2, .3], [.0, .9, .3, .7]]],
                      dtype=tf.float32)
      seg_id = tf.constant([[1, 1, 1], [1, 1, 1]], dtype=tf.int32)
      pos_id = tf.constant([[0, 1, 2], [0, 1, 2]], dtype=tf.int32)
      # Reshape with model_dim_reshape_segments = 2
      reshaped_x = tf.reshape(x, [2, 3, 2, -1])
      return reshaped_x, seg_id, pos_id

    def _GetOutputs(enc, dec):
      x, seg_id, pos_id = _GetInputs()
      enc_inputs = py_utils.NestedMap(
          vec=x,
          segment_id=seg_id,
          segment_pos=pos_id,
          aux_loss=tf.constant(0.0))
      enc_outs = enc.FPropDefaultTheta(enc_inputs)
      dec_inputs = py_utils.NestedMap(
          vec=x,
          segment_id=seg_id,
          segment_pos=pos_id,
          encoder_output=enc_outs.vec,
          encoder_segment_id=tf.zeros_like(seg_id),
          encoder_segment_pos=tf.zeros_like(pos_id),
          aux_loss=enc_outs.aux_loss)
      return dec.FPropDefaultTheta(dec_inputs).vec

    # Build a graph with RepeatLayer.
    g = tf.Graph()
    with g.as_default():
      tf.random.set_seed(None)
      enc = builder.EncoderLayerStack(
          'encoder',
          sub_layers=[builder.DenseReluDense('ffw')],
          num=2,
          use_repeat_layer=True).Instantiate()
      dec = builder.DecoderLayerStack(
          'decoder',
          sub_layers=[builder.DenseReluDense('ffw', decoder=True)],
          num=2,
          use_repeat_layer=True).Instantiate()
      rep_out = _GetOutputs(enc, dec)

    tf.Session.reset(target='')
    with tf.Session(graph=g) as sess:
      sess.run(tf.global_variables_initializer())
      rep_out = rep_out.eval(session=sess)
      var_values = sess.run(tf.trainable_variables())

    # Build a graph without RepeatLayer.
    g = tf.Graph()
    with g.as_default():
      tf.random.set_seed(None)
      enc = builder.EncoderLayerStack(
          'encoder', sub_layers=[builder.DenseReluDense('ffw')],
          num=2).Instantiate()
      dec = builder.DecoderLayerStack(
          'decoder',
          sub_layers=[builder.DenseReluDense('ffw', decoder=True)],
          num=2).Instantiate()
      dec_out = _GetOutputs(enc, dec)

    tf.Session.reset(target='')
    with tf.Session(graph=g) as sess:
      tf_vars = [
          enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
          enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
          enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
          enc.vars.final_layer_norm.w.scale, dec.vars.layer_000.ln.w.scale,
          dec.vars.layer_000.ffw.w.wi, dec.vars.layer_000.ffw.w.wo,
          dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.ffw.w.wi,
          dec.vars.layer_001.ffw.w.wo, dec.vars.final_layer_norm.w.scale
      ]
      for val, var in zip(var_values, tf_vars):
        sess.run(tf.assign(var, val))
      dec_out = dec_out.eval(session=sess)
      self.assertAllClose(dec_out, rep_out)
Code Example #2
File: gshard_builder_test.py Project: adis98/lingvo
  def testParallelDecSelfAttentionRelativeBiasFFN(self):
    model_dim = 4
    num_heads = 2
    d_kv = 2
    d_ff = 8
    builder = gshard_builder.DenseBuilder.Params().Set(
        dtype=tf.float32,
        relative_attention_type='bias',
        model_dim=model_dim,
        attention_num_heads=num_heads,
        attention_combine_dims=True,
        attention_num_memory_heads=1,
        model_dim_reshape_segments=2,
        ff_dim=d_ff,
        attention_key_value_dim=d_kv).Instantiate()

    def _GetInputs():
      x = tf.constant([[[.1, .2, .3, .4], [.3, .4, .5, .6], [.5, .6, .1, .2]],
                       [[.7, .8, .4, .5], [.9, .1, .2, .3], [.0, .9, .3, .7]]],
                      dtype=tf.float32)
      seg_id = tf.constant([[1, 1, 1], [1, 1, 1]], dtype=tf.int32)
      pos_id = tf.constant([[0, 1, 2], [0, 1, 2]], dtype=tf.int32)
      # Reshape with model_dim_reshape_segments = 2
      reshaped_x = tf.reshape(x, [2, 3, 2, -1])
      return reshaped_x, seg_id, pos_id

    # Build a graph with separate attention and ffn layers.
    # Naively compute the output by adding the outputs of the two directly.
    g = tf.Graph()
    with g.as_default():
      tf.random.set_seed(None)
      x, seg_id, pos_id = _GetInputs()
      atten = builder.DecSelfAttentionRelativeBias('atten').Instantiate()
      ffn = builder.DenseReluDenseGated('ffn', tf.nn.relu, True).Instantiate()
      y_atten, _ = atten.FPropDefaultTheta(x, seg_id, pos_id, tf.constant(0),
                                           tf.constant(0), tf.constant(0))
      y_ffn, _ = ffn.FPropDefaultTheta(x, seg_id, pos_id, tf.constant(0),
                                       tf.constant(0), tf.constant(0))
      y_exp = (y_atten + y_ffn) * (2.0**-0.5)
    tf.Session.reset(target='')
    with tf.Session(graph=g) as sess:
      sess.run(tf.global_variables_initializer())
      y_exp = y_exp.eval(session=sess)
      var_values = sess.run(tf.trainable_variables())

    # Build a graph with the dedicated parallel layer and load the variable values.
    # Expect the same output as the previous naive implementation.
    g = tf.Graph()
    with g.as_default():
      x, seg_id, pos_id = _GetInputs()
      parallel = builder.ParallelDecSelfAttentionRelativeBiasFFN(
          'parallel', tf.nn.relu, hidden_dim_reshape_segments=2).Instantiate()
      y_parallel, _ = parallel.FPropDefaultTheta(x, seg_id, pos_id,
                                                 tf.constant(0), tf.constant(0),
                                                 tf.constant(0))
    tf.Session.reset(target='')
    with tf.Session(graph=g) as sess:
      tf_vars = [
          parallel.vars.w_atten.wq, parallel.vars.w_atten.wk,
          parallel.vars.w_atten.wv, parallel.vars.w_atten.wo,
          parallel.vars.wrb.wrb, parallel.vars.w_fflayer.wi_0,
          parallel.vars.w_fflayer.wi_1, parallel.vars.w_fflayer.wo
      ]
      for val, var in zip(var_values, tf_vars):
        sess.run(tf.assign(var, val))
      y_parallel = y_parallel.eval(session=sess)
      self.assertAllClose(y_exp, y_parallel)
Code Example #3
    def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
        p = self.params

        load_op_list = []
        retrieve_op_list = []

        num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
        table_name = tpu_embedding_table.table_name
        slot_var_collections = [
            tpu_embedding_table.__class__.__name__ + '_vars'
        ]

        for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
            # The slot vars should be on the same device as the table var.
            device_name = tpu_embedding_table.GetDeviceName(host_id)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                w_ada = py_utils.WeightParams(
                    shape=table_var.shape.as_list(),
                    init=py_utils.WeightInit.Constant(p.initial_accumulator),
                    dtype=p.dtype,
                    collections=slot_var_collections)
                var_name = tpu_embedding_table.GetVariableName(
                    host_id) + '/Adagrad'
                tpu_embedding_table.CreateVariable(var_name,
                                                   w_ada,
                                                   trainable=False)
                accumulator_var = tpu_embedding_table.vars[var_name]

                # Only the Trainer needs these ops.
                if py_utils.use_tpu():
                    # Remove the slot vars from the variable list to avoid copying them
                    # to TPU (by the tf.cast in tpu_embedding_table.theta).
                    # pylint: disable=protected-access
                    del tpu_embedding_table._private_vars[var_name]
                    del tpu_embedding_table._private_theta[var_name]
                    # pylint: enable=protected-access

                    # TPU Embedding load/retrieve ops need to be in the outer graph scope.
                    with tf.init_scope():
                        tf.logging.info('creating load and retrieve ops.')
                        load_parameters_op = (
                            tpu_embedding_lib.tpu_ops.
                            load_tpu_embedding_adagrad_parameters(
                                parameters=table_var,
                                accumulators=accumulator_var,
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        load_op_list.append(load_parameters_op)

                        retrieved_table, retrieved_accumulator = (
                            tpu_embedding_lib.tpu_ops.
                            retrieve_tpu_embedding_adagrad_parameters(
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                            tf.assign(table_var, retrieved_table),
                            tf.assign(accumulator_var, retrieved_accumulator))
                        retrieve_op_list.append(retrieve_parameters_op)

        return load_op_list, retrieve_op_list
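For context, the Adagrad accumulator slot created above is given the same shape as the embedding table because Adagrad keeps one running sum of squared gradients per parameter. Below is a minimal NumPy sketch of that update rule (standard Adagrad arithmetic, not the TPU embedding op itself; the names are illustrative):

import numpy as np

def adagrad_step(table, accumulator, grad, lr=0.1):
  # Accumulate squared gradients elementwise, then scale the step by 1/sqrt(accumulator).
  accumulator = accumulator + grad ** 2
  table = table - lr * grad / np.sqrt(accumulator)
  return table, accumulator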
Code Example #4
    def testBatchNormLayer(self):
        p = base_model.SingleTaskModel.Params()
        p.task = self.TestParams(layers.BatchNormLayer.Params().Set(dim=1))
        p.task.train.ema_decay = 0.9
        p.task.train.ema_decay_moving_vars = True
        model = p.Instantiate()
        self.assertIsNotNone(model.ema)
        task = model._task
        task._train_op = tf.no_op()
        task.ApplyExponentialMovingAverage(model.ema)

        layer = task.encoder
        self.assertLen(layer.vars, 4)
        for var in layer.vars.Flatten():
            self.assertIsNotNone(model.ema.average(var), msg=var.name)
        beta = layer.vars.beta
        mean = layer.vars.moving_mean

        global_step = 100
        beta_1 = np.asarray([.2])
        mean_1 = np.asarray([.03])
        beta_1_ema = beta_1 * .1
        mean_1_ema = mean_1 * .1
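        # With ema_decay = 0.9 and the EMA shadow variables starting at the layer's
        # zero-initialized values, one EMA update gives 0.9 * 0 + 0.1 * value,
        # hence the .1 factors above.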
        with self.session() as sess:
            # Test EMA values.
            self.evaluate(tf.global_variables_initializer())
            self.evaluate(
                tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
            self.evaluate(tf.assign(beta, beta_1))
            self.evaluate(tf.assign(mean, mean_1))
            self.evaluate(task._post_train_ops)

            self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                                self.evaluate([
                                    beta,
                                    model.ema.average(beta), mean,
                                    model.ema.average(mean)
                                ]))

            # Test checkpointer.
            train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
            os.mkdir(train_dir)
            saver = checkpointer.Checkpointer(train_dir, model)
            saver.Save(sess, model.global_step)

            self.assertTrue(
                os.path.isfile(
                    os.path.join(train_dir, 'ckpt-%08d.index' % global_step)))

        # Restore from ckpt in training mode.
        with self.session(graph=tf.Graph()) as sess:
            model = p.Instantiate()
            self.assertIsNotNone(model.ema)
            task = model._task
            task._train_op = tf.no_op()
            task.ApplyExponentialMovingAverage(model.ema)
            layer = task.encoder
            for var in layer.vars.Flatten():
                self.assertIsNotNone(model.ema.average(var), msg=var.name)
            beta = layer.vars.beta
            mean = layer.vars.moving_mean

            saver = checkpointer.Checkpointer(train_dir, model)
            saver.RestoreIfNeeded(sess)

            self.assertAllClose([beta_1, beta_1_ema, mean_1, mean_1_ema],
                                self.evaluate([
                                    beta,
                                    model.ema.average(beta), mean,
                                    model.ema.average(mean)
                                ]))

        # Restore from ckpt in eval mode.
        with self.session(graph=tf.Graph()) as sess, self.SetEval(True):
            model = p.Instantiate()
            self.assertIsNotNone(model.ema)
            task = model._task
            # task._train_op = tf.no_op()
            # task.ApplyExponentialMovingAverage(model.ema)
            layer = task.encoder
            # for var in layer.vars.Flatten():
            #   self.assertIsNotNone(model.ema.average(var), msg=var.name)
            beta = layer.vars.beta
            mean = layer.vars.moving_mean

            saver = checkpointer.Checkpointer(train_dir, model)
            saver.RestoreIfNeeded(sess)

            # Both beta and mean should use the EMA value.
            self.assertAllClose([beta_1_ema, mean_1_ema],
                                self.evaluate([beta, mean]))
Code Example #5
    def testLayerStackSummary(self):
        # In this test we verify that summaries created inside stack layers
        # are processed properly with and without RepeatLayer.
        model_dim = 4
        num_heads = 2
        d_kv = 2
        d_ff = 8
        num_experts = 2
        builder = gshard_builder.DenseBuilder.Params().Set(
            deterministic_dropout=True,
            dtype=tf.float32,
            relative_attention_type='bias',
            model_dim=model_dim,
            attention_num_heads=num_heads,
            attention_combine_dims=True,
            attention_num_memory_heads=1,
            model_dim_reshape_segments=None,
            ff_dim=d_ff,
            moe_hidden_dim=d_ff,
            e_dim=num_experts,
            c_dim=1,
            num_groups=num_experts,
            num_devices=num_experts,
            attention_key_value_dim=d_kv).Instantiate()

        def _GetOutputs(enc, dec):
            x, seg_id, pos_id = self._GetInputs()
            enc_inputs = py_utils.NestedMap(vec=x,
                                            segment_id=seg_id,
                                            segment_pos=pos_id,
                                            aux_loss=tf.constant(0.0))
            enc_outs = enc.FPropDefaultTheta(enc_inputs)
            dec_inputs = py_utils.NestedMap(
                vec=x,
                segment_id=seg_id,
                segment_pos=pos_id,
                encoder_output=enc_outs.vec,
                encoder_segment_id=tf.zeros_like(seg_id),
                encoder_segment_pos=tf.zeros_like(pos_id),
                aux_loss=enc_outs.aux_loss)
            return dec.FPropDefaultTheta(dec_inputs).vec

        # Build a graph with RepeatLayer unrolled.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context(), cluster_factory.SetEval(
                mode=True):
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack(
                'encoder',
                sub_layers=[builder.DenseReluDense('ffw')],
                num=2,
                use_repeat_layer=True).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2,
                use_repeat_layer=True).Instantiate()
            rep_unroll_out = _GetOutputs(enc, dec)
            rep_unroll_summary = tpu_summary.merge_all()

        expected_rep_unroll_summary = [
            'index_1/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating',
            'over_capacity_1/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating'
        ]
        self.assertEqual(set(expected_rep_unroll_summary),
                         set(rep_unroll_summary))

        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            rep_unroll_out, rep_unroll_summary = sess.run(
                [rep_unroll_out, rep_unroll_summary])
            var_values = sess.run(tf.trainable_variables())
        # Build a graph without RepeatLayer.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context():
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack('encoder',
                                            sub_layers=[
                                                builder.DenseReluDense('ffw')
                                            ],
                                            num=2).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2).Instantiate()
            dec_out = _GetOutputs(enc, dec)
            dec_summary = tpu_summary.merge_all()

        expected_dec_summary = [
            'index_1/decoder_1/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_1/layer_001/moe/ffw/compute_gating',
            'over_capacity_1/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_1/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_1/layer_001/moe/ffw/compute_gating'
        ]
        self.assertCountEqual(expected_dec_summary, dec_summary)

        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            tf_vars = [
                enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
                enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
                enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
                enc.vars.final_layer_norm.w.scale,
                dec.vars.layer_000.ln.w.scale, dec.vars.layer_000.moe.moe.wi,
                dec.vars.layer_000.moe.moe.wo,
                dec.vars.layer_000.moe.ffw.top_2_gating.w,
                dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.moe.moe.wi,
                dec.vars.layer_001.moe.moe.wo,
                dec.vars.layer_001.moe.ffw.top_2_gating.w,
                dec.vars.final_layer_norm.w.scale
            ]
            for val, var in zip(var_values, tf_vars):
                sess.run(tf.assign(var, val))
            dec_out, dec_summary = sess.run([dec_out, dec_summary])
            self.assertAllClose(dec_out, rep_unroll_out)

            for name, alt_name in zip(expected_dec_summary,
                                      expected_rep_unroll_summary):
                self.assertAllClose(dec_summary[name],
                                    rep_unroll_summary[alt_name])
Code Example #6
def variable_assign(var, new_value):
    return tf.assign(var, new_value, name=var.op.name + '_assign')
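A minimal usage sketch for the variable_assign helper defined above (assuming the TF1-style graph/session API used throughout these examples; the variable name and value are made up for illustration):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

var = tf.get_variable('counter', shape=[], initializer=tf.zeros_initializer())
assign_op = variable_assign(var, tf.constant(3.0))  # op is named 'counter_assign'

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_op)
  print(sess.run(var))  # 3.0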
Code Example #7
File: base_model_test.py Project: tensorflow/lingvo
  def PostTrainingStepUpdate(self):
    # We expect the training step to be done, so capture
    # the value of counter1 into counter2.
    return tf.assign(self.vars.counter2, self.vars.counter1)
Code Example #8
File: quant_utils_test.py Project: lizheng-1/lingvo
  def testFakeQuantizationScheduleTraining(self):
    p = quant_utils.FakeQuantizationSchedule.Params()
    p.clip_start_step = 5
    p.clip_end_step = 10
    p.quant_start_step = 15
    p.start_cap = 6.0
    p.end_cap = 1.0
    with self.session() as sess:
      cc_schedule = p.Instantiate()
      tf.global_variables_initializer().run()
      # Step 0: No clipping.
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0), (-100.0, 100.0))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.123456, 0.123456))  # Not Quantized.

      # Step 5: Clipping active but not yet quantizing.
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 5))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0),
          (-6.0, 5.953125))  # 6 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.123456, 0.123456))  # Not Quantized.

      # Step 7: Middle of clipping range.
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 7))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0), (-4.0, 3.96875))  # 4 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.123456, 0.123456))  # Not Quantized.

      # Step 10: End of clipping range.
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 10))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0),
          (-1.0, 0.9921875))  # 1 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.123456, 0.123456))  # Not Quantized.

      # Step 11: No more clipping but not yet quantizing.
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 11))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0),
          (-1.0, 0.9921875))  # 1 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.123456, 0.123456))  # Not Quantized.

      # Step 15-16: Quantizing at full clip.
      for step in (15, 16):
        sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), step))
        self.assertAllClose(
            self._ClipExample(cc_schedule, 100.0),
            (-1.0, 0.9921875))  # 1 * 127/128
        self.assertAllClose(
            self._ClipExample(cc_schedule, 0.123456),
            (-0.125, 0.125))  # Quantized.
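For reference, the clip ranges asserted above follow a linear ramp of the cap from start_cap at clip_start_step to end_cap at clip_end_step, with the positive bound shrunk by 127/128 for symmetric 8-bit quantization. Here is a small sketch of that arithmetic (plain Python reproducing the expected numbers, not the FakeQuantizationSchedule API; it ignores the pass-through behavior before clip_start_step):

def expected_clip_range(step, clip_start=5, clip_end=10, start_cap=6.0, end_cap=1.0):
  # Linearly interpolate the cap across the clipping ramp, clamping at the ends.
  frac = min(max((step - clip_start) / float(clip_end - clip_start), 0.0), 1.0)
  cap = start_cap + frac * (end_cap - start_cap)
  # Symmetric 8-bit quantization reserves one bucket, so the max is cap * 127/128.
  return -cap, cap * 127.0 / 128.0

# e.g. expected_clip_range(7) -> (-4.0, 3.96875), matching the step-7 assertion above.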
Code Example #9
File: tpu_embedding_layers.py Project: Mddct/lingvo
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
    table_name = tpu_embedding_table.table_name
    slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

    for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
      # The slot vars should be on the same device as the table var.
      device_name = tpu_embedding_table.GetDeviceName(host_id)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        accumulator = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
            dtype=p.dtype,
            collections=slot_var_collections)
        accumulator_name = (
            tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
        tpu_embedding_table.CreateVariable(
            accumulator_name, accumulator, trainable=False)
        accumulator_var = tpu_embedding_table.vars[accumulator_name]

        linear = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(p.initial_linear_value),
            dtype=p.dtype,
            collections=slot_var_collections)
        linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
        tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
        linear_var = tpu_embedding_table.vars[linear_name]

        # Only the Trainer needs these ops.
        if py_utils.use_tpu():
          # Remove the slot vars from the variable list to avoid them being
          # copied to TPU.
          _RemovePrivateVar(tpu_embedding_table, accumulator_name)
          _RemovePrivateVar(tpu_embedding_table, linear_name)

          # TPU Embedding load/retrieve ops need to be in the outer graph scope.
          with tf.init_scope():
            tf.logging.info('creating load and retrieve ops.')
            load_parameters_op = (
                tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                    parameters=table_var,
                    accumulators=accumulator_var,
                    linears=linear_var,
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            load_op_list.append(load_parameters_op)

            retrieved_table, retrieved_accumulator, retrieved_linear = (
                tpu_embedding_lib.tpu_ops
                .retrieve_tpu_embedding_ftrl_parameters(
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                tf.assign(table_var, retrieved_table),
                tf.assign(accumulator_var, retrieved_accumulator),
                tf.assign(linear_var, retrieved_linear))
            retrieve_op_list.append(retrieve_parameters_op)

    return load_op_list, retrieve_op_list
Code Example #10
File: tpu_embedding_layers.py Project: Mddct/lingvo
  def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
    table_name = tpu_embedding_table.table_name
    slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

    for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
      # The slot vars should be on the same device as the table var.
      device_name = tpu_embedding_table.GetDeviceName(host_id)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        m_adam = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(0.0),
            dtype=p.dtype,
            collections=slot_var_collections)
        var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
        tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
        m_var = tpu_embedding_table.vars[var_name_m]

        v_adam = py_utils.WeightParams(
            shape=table_var.shape.as_list(),
            init=py_utils.WeightInit.Constant(0.0),
            dtype=p.dtype,
            collections=slot_var_collections)
        var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
        tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
        v_var = tpu_embedding_table.vars[var_name_v]

        # Only the Trainer needs these ops.
        if py_utils.use_tpu():
          # Remove the slot vars from the variable list to avoid them being
          # copied to TPU.
          _RemovePrivateVar(tpu_embedding_table, var_name_m)
          _RemovePrivateVar(tpu_embedding_table, var_name_v)

          # TPU Embedding load/retrieve ops need to be in the outer graph scope.
          with tf.init_scope():
            tf.logging.info('creating load and retrieve ops.')
            load_parameters_op = (
                tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                    parameters=table_var,
                    momenta=m_var,
                    velocities=v_var,
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            load_op_list.append(load_parameters_op)

            retrieved_table, retrieved_m, retrieved_v = (
                tpu_embedding_lib.tpu_ops
                .retrieve_tpu_embedding_adam_parameters(
                    table_name=table_name,
                    num_shards=num_tpu_hosts,
                    shard_id=host_id))
            retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                tf.assign(table_var, retrieved_table),
                tf.assign(m_var, retrieved_m), tf.assign(v_var, retrieved_v))
            retrieve_op_list.append(retrieve_parameters_op)

    return load_op_list, retrieve_op_list