def testLayerStack(self):
  model_dim = 4
  num_heads = 2
  d_kv = 2
  d_ff = 8
  builder = gshard_builder.DenseBuilder.Params().Set(
      dtype=tf.float32,
      relative_attention_type='bias',
      model_dim=model_dim,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      attention_num_memory_heads=1,
      model_dim_reshape_segments=2,
      ff_dim=d_ff,
      attention_key_value_dim=d_kv).Instantiate()

  def _GetInputs():
    x = tf.constant([[[.1, .2, .3, .4], [.3, .4, .5, .6], [.5, .6, .1, .2]],
                     [[.7, .8, .4, .5], [.9, .1, .2, .3], [.0, .9, .3, .7]]],
                    dtype=tf.float32)
    seg_id = tf.constant([[1, 1, 1], [1, 1, 1]], dtype=tf.int32)
    pos_id = tf.constant([[0, 1, 2], [0, 1, 2]], dtype=tf.int32)
    # Reshape with model_dim_reshape_segments = 2.
    reshaped_x = tf.reshape(x, [2, 3, 2, -1])
    return reshaped_x, seg_id, pos_id

  def _GetOutputs(enc, dec):
    x, seg_id, pos_id = _GetInputs()
    enc_inputs = py_utils.NestedMap(
        vec=x, segment_id=seg_id, segment_pos=pos_id, aux_loss=tf.constant(0.0))
    enc_outs = enc.FPropDefaultTheta(enc_inputs)
    dec_inputs = py_utils.NestedMap(
        vec=x,
        segment_id=seg_id,
        segment_pos=pos_id,
        encoder_output=enc_outs.vec,
        encoder_segment_id=tf.zeros_like(seg_id),
        encoder_segment_pos=tf.zeros_like(pos_id),
        aux_loss=enc_outs.aux_loss)
    return dec.FPropDefaultTheta(dec_inputs).vec

  # Build a graph with RepeatLayer.
  g = tf.Graph()
  with g.as_default():
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder',
        sub_layers=[builder.DenseReluDense('ffw')],
        num=2,
        use_repeat_layer=True).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder',
        sub_layers=[builder.DenseReluDense('ffw', decoder=True)],
        num=2,
        use_repeat_layer=True).Instantiate()
    rep_out = _GetOutputs(enc, dec)

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    rep_out = rep_out.eval(session=sess)
    var_values = sess.run(tf.trainable_variables())

  # Build a graph without RepeatLayer.
  g = tf.Graph()
  with g.as_default():
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder', sub_layers=[builder.DenseReluDense('ffw')],
        num=2).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder',
        sub_layers=[builder.DenseReluDense('ffw', decoder=True)],
        num=2).Instantiate()
    dec_out = _GetOutputs(enc, dec)

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    tf_vars = [
        enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
        enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
        enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
        enc.vars.final_layer_norm.w.scale, dec.vars.layer_000.ln.w.scale,
        dec.vars.layer_000.ffw.w.wi, dec.vars.layer_000.ffw.w.wo,
        dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.ffw.w.wi,
        dec.vars.layer_001.ffw.w.wo, dec.vars.final_layer_norm.w.scale
    ]
    for val, var in zip(var_values, tf_vars):
      sess.run(tf.assign(var, val))
    dec_out = dec_out.eval(session=sess)
    self.assertAllClose(dec_out, rep_out)
def testParallelDecSelfAttentionRelativeBiasFFN(self):
  model_dim = 4
  num_heads = 2
  d_kv = 2
  d_ff = 8
  builder = gshard_builder.DenseBuilder.Params().Set(
      dtype=tf.float32,
      relative_attention_type='bias',
      model_dim=model_dim,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      attention_num_memory_heads=1,
      model_dim_reshape_segments=2,
      ff_dim=d_ff,
      attention_key_value_dim=d_kv).Instantiate()

  def _GetInputs():
    x = tf.constant([[[.1, .2, .3, .4], [.3, .4, .5, .6], [.5, .6, .1, .2]],
                     [[.7, .8, .4, .5], [.9, .1, .2, .3], [.0, .9, .3, .7]]],
                    dtype=tf.float32)
    seg_id = tf.constant([[1, 1, 1], [1, 1, 1]], dtype=tf.int32)
    pos_id = tf.constant([[0, 1, 2], [0, 1, 2]], dtype=tf.int32)
    # Reshape with model_dim_reshape_segments = 2.
    reshaped_x = tf.reshape(x, [2, 3, 2, -1])
    return reshaped_x, seg_id, pos_id

  # Build a graph with separate attention and ffn layers.
  # Naively compute the output by adding the outputs of the two directly.
  g = tf.Graph()
  with g.as_default():
    tf.random.set_seed(None)
    x, seg_id, pos_id = _GetInputs()
    atten = builder.DecSelfAttentionRelativeBias('atten').Instantiate()
    ffn = builder.DenseReluDenseGated('ffn', tf.nn.relu, True).Instantiate()
    y_atten, _ = atten.FPropDefaultTheta(x, seg_id, pos_id, tf.constant(0),
                                         tf.constant(0), tf.constant(0))
    y_ffn, _ = ffn.FPropDefaultTheta(x, seg_id, pos_id, tf.constant(0),
                                     tf.constant(0), tf.constant(0))
    y_exp = (y_atten + y_ffn) * (2.0**-0.5)

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    y_exp = y_exp.eval(session=sess)
    var_values = sess.run(tf.trainable_variables())

  # Build a graph with the dedicated parallel layer and load the variable
  # values. Expect the same output as the previous naive implementation.
  g = tf.Graph()
  with g.as_default():
    x, seg_id, pos_id = _GetInputs()
    parallel = builder.ParallelDecSelfAttentionRelativeBiasFFN(
        'parallel', tf.nn.relu, hidden_dim_reshape_segments=2).Instantiate()
    y_parallel, _ = parallel.FPropDefaultTheta(x, seg_id, pos_id,
                                               tf.constant(0), tf.constant(0),
                                               tf.constant(0))

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    tf_vars = [
        parallel.vars.w_atten.wq, parallel.vars.w_atten.wk,
        parallel.vars.w_atten.wv, parallel.vars.w_atten.wo,
        parallel.vars.wrb.wrb, parallel.vars.w_fflayer.wi_0,
        parallel.vars.w_fflayer.wi_1, parallel.vars.w_fflayer.wo
    ]
    for val, var in zip(var_values, tf_vars):
      sess.run(tf.assign(var, val))
    y_parallel = y_parallel.eval(session=sess)
    self.assertAllClose(y_exp, y_parallel)
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      w_ada = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name = tpu_embedding_table.GetVariableName(host_id) + '/Adagrad'
      tpu_embedding_table.CreateVariable(var_name, w_ada, trainable=False)
      accumulator_var = tpu_embedding_table.vars[var_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid copying them
        # to TPU (by the tf.cast in tpu_embedding_table.theta).
        # pylint: disable=protected-access
        del tpu_embedding_table._private_vars[var_name]
        del tpu_embedding_table._private_theta[var_name]
        # pylint: enable=protected-access

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
def testBatchNormLayer(self):
  p = base_model.SingleTaskModel.Params()
  p.task = self.TestParams(layers.BatchNormLayer.Params().Set(dim=1))
  p.task.train.ema_decay = 0.9
  p.task.train.ema_decay_moving_vars = True
  model = p.Instantiate()
  self.assertIsNotNone(model.ema)
  task = model._task
  task._train_op = tf.no_op()
  task.ApplyExponentialMovingAverage(model.ema)

  layer = task.encoder
  self.assertLen(layer.vars, 4)
  for var in layer.vars.Flatten():
    self.assertIsNotNone(model.ema.average(var), msg=var.name)
  beta = layer.vars.beta
  mean = layer.vars.moving_mean

  global_step = 100
  beta_1 = np.asarray([.2])
  mean_1 = np.asarray([.03])
  # With ema_decay = 0.9 and the EMA shadow variables starting at the
  # variables' zero initial values, one update gives
  # ema = 0.9 * 0 + 0.1 * value.
  beta_1_ema = beta_1 * .1
  mean_1_ema = mean_1 * .1
  with self.session() as sess:
    # Test EMA values.
    self.evaluate(tf.global_variables_initializer())
    self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
    self.evaluate(tf.assign(beta, beta_1))
    self.evaluate(tf.assign(mean, mean_1))
    self.evaluate(task._post_train_ops)
    self.assertAllClose(
        [beta_1, beta_1_ema, mean_1, mean_1_ema],
        self.evaluate(
            [beta, model.ema.average(beta), mean, model.ema.average(mean)]))

    # Test checkpointer.
    train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
    os.mkdir(train_dir)
    saver = checkpointer.Checkpointer(train_dir, model)
    saver.Save(sess, model.global_step)
    self.assertTrue(
        os.path.isfile(
            os.path.join(train_dir, 'ckpt-%08d.index' % global_step)))

  # Restore from ckpt in training mode.
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    self.assertIsNotNone(model.ema)
    task = model._task
    task._train_op = tf.no_op()
    task.ApplyExponentialMovingAverage(model.ema)
    layer = task.encoder
    for var in layer.vars.Flatten():
      self.assertIsNotNone(model.ema.average(var), msg=var.name)
    beta = layer.vars.beta
    mean = layer.vars.moving_mean

    saver = checkpointer.Checkpointer(train_dir, model)
    saver.RestoreIfNeeded(sess)
    self.assertAllClose(
        [beta_1, beta_1_ema, mean_1, mean_1_ema],
        self.evaluate(
            [beta, model.ema.average(beta), mean, model.ema.average(mean)]))

  # Restore from ckpt in eval mode.
  with self.session(graph=tf.Graph()) as sess, self.SetEval(True):
    model = p.Instantiate()
    self.assertIsNotNone(model.ema)
    task = model._task
    # task._train_op = tf.no_op()
    # task.ApplyExponentialMovingAverage(model.ema)
    layer = task.encoder
    # for var in layer.vars.Flatten():
    #   self.assertIsNotNone(model.ema.average(var), msg=var.name)
    beta = layer.vars.beta
    mean = layer.vars.moving_mean

    saver = checkpointer.Checkpointer(train_dir, model)
    saver.RestoreIfNeeded(sess)
    # Both beta and mean should use the EMA value.
    self.assertAllClose([beta_1_ema, mean_1_ema],
                        self.evaluate([beta, mean]))
def testLayerStackSummary(self):
  # In this test we verify that summaries created inside stack layers
  # are processed properly with and without RepeatLayer.
  model_dim = 4
  num_heads = 2
  d_kv = 2
  d_ff = 8
  num_experts = 2
  builder = gshard_builder.DenseBuilder.Params().Set(
      deterministic_dropout=True,
      dtype=tf.float32,
      relative_attention_type='bias',
      model_dim=model_dim,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      attention_num_memory_heads=1,
      model_dim_reshape_segments=None,
      ff_dim=d_ff,
      moe_hidden_dim=d_ff,
      e_dim=num_experts,
      c_dim=1,
      num_groups=num_experts,
      num_devices=num_experts,
      attention_key_value_dim=d_kv).Instantiate()

  def _GetOutputs(enc, dec):
    x, seg_id, pos_id = self._GetInputs()
    enc_inputs = py_utils.NestedMap(
        vec=x, segment_id=seg_id, segment_pos=pos_id, aux_loss=tf.constant(0.0))
    enc_outs = enc.FPropDefaultTheta(enc_inputs)
    dec_inputs = py_utils.NestedMap(
        vec=x,
        segment_id=seg_id,
        segment_pos=pos_id,
        encoder_output=enc_outs.vec,
        encoder_segment_id=tf.zeros_like(seg_id),
        encoder_segment_pos=tf.zeros_like(pos_id),
        aux_loss=enc_outs.aux_loss)
    return dec.FPropDefaultTheta(dec_inputs).vec

  # Build a graph with RepeatLayer unrolled.
  g = tf.Graph()
  with g.as_default(), tpu_summary.context(), cluster_factory.SetEval(
      mode=True):
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder',
        sub_layers=[builder.DenseReluDense('ffw')],
        num=2,
        use_repeat_layer=True).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder',
        sub_layers=[builder.MoE('moe', decoder=True)],
        num=2,
        use_repeat_layer=True).Instantiate()
    rep_unroll_out = _GetOutputs(enc, dec)
    rep_unroll_summary = tpu_summary.merge_all()

  expected_rep_unroll_summary = [
      'index_1/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
      'index_1/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating',
      'over_capacity_1/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_2/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'top1_expert/decoder_3/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
      'top1_expert/decoder_3/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating'
  ]
  self.assertEqual(set(expected_rep_unroll_summary), set(rep_unroll_summary))

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    rep_unroll_out, rep_unroll_summary = sess.run(
        [rep_unroll_out, rep_unroll_summary])
    var_values = sess.run(tf.trainable_variables())

  # Build a graph without RepeatLayer.
  g = tf.Graph()
  with g.as_default(), tpu_summary.context():
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder', sub_layers=[builder.DenseReluDense('ffw')],
        num=2).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder',
        sub_layers=[builder.MoE('moe', decoder=True)],
        num=2).Instantiate()
    dec_out = _GetOutputs(enc, dec)
    dec_summary = tpu_summary.merge_all()

  expected_dec_summary = [
      'index_1/decoder_1/layer_000/moe/ffw/compute_gating',
      'index_1/decoder_1/layer_001/moe/ffw/compute_gating',
      'over_capacity_1/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
      'over_capacity_2/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
      'top1_expert/decoder_1/layer_000/moe/ffw/compute_gating',
      'top1_expert/decoder_1/layer_001/moe/ffw/compute_gating'
  ]
  self.assertCountEqual(expected_dec_summary, dec_summary)

  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    tf_vars = [
        enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
        enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
        enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
        enc.vars.final_layer_norm.w.scale, dec.vars.layer_000.ln.w.scale,
        dec.vars.layer_000.moe.moe.wi, dec.vars.layer_000.moe.moe.wo,
        dec.vars.layer_000.moe.ffw.top_2_gating.w,
        dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.moe.moe.wi,
        dec.vars.layer_001.moe.moe.wo,
        dec.vars.layer_001.moe.ffw.top_2_gating.w,
        dec.vars.final_layer_norm.w.scale
    ]
    for val, var in zip(var_values, tf_vars):
      sess.run(tf.assign(var, val))
    dec_out, dec_summary = sess.run([dec_out, dec_summary])
    self.assertAllClose(dec_out, rep_unroll_out)
    for name, alt_name in zip(expected_dec_summary,
                              expected_rep_unroll_summary):
      self.assertAllClose(dec_summary[name], rep_unroll_summary[alt_name])
def variable_assign(var, new_value):
  return tf.assign(var, new_value, name=var.op.name + '_assign')
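# A minimal usage sketch (the 'step' variable is illustrative, not from this
# codebase): the returned op is named after the variable with an '_assign'
# suffix, which keeps assign ops easy to locate in the graph.
def _ExampleVariableAssign():
  step = tf.get_variable(
      'step', shape=[], dtype=tf.int64,
      initializer=tf.zeros_initializer(), trainable=False)
  return variable_assign(step, step + 1)  # Op named 'step_assign'.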
def PostTrainingStepUpdate(self):
  # We expect the training step to be done, so capture
  # the value of counter1 into counter2.
  return tf.assign(self.vars.counter2, self.vars.counter1)
def testFakeQuantizationScheduleTraining(self):
  p = quant_utils.FakeQuantizationSchedule.Params()
  p.clip_start_step = 5
  p.clip_end_step = 10
  p.quant_start_step = 15
  p.start_cap = 6.0
  p.end_cap = 1.0
  with self.session() as sess:
    cc_schedule = p.Instantiate()
    tf.global_variables_initializer().run()
    # Step 0: No clipping.
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0), (-100.0, 100.0))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not Quantized.

    # Step 5: Clipping active but not yet quantizing.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 5))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-6.0, 5.953125))  # 6 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not Quantized.

    # Step 7: Middle of clipping range.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 7))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-4.0, 3.96875))  # 4 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not Quantized.

    # Step 10: End of clipping range.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 10))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-1.0, 0.9921875))  # 1 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not Quantized.

    # Step 11: No more clipping but not yet quantizing.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 11))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-1.0, 0.9921875))  # 1 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not Quantized.

    # Steps 15-16: Quantizing at full clip.
    for step in (15, 16):
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), step))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0),
          (-1.0, 0.9921875))  # 1 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.125, 0.125))  # Quantized.
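# A minimal sketch (an assumed helper, not part of quant_utils) reconstructing
# the clip ranges the assertions above expect for steps at or after
# clip_start_step: the cap interpolates linearly from start_cap=6.0 at step 5
# to end_cap=1.0 at step 10 and holds there, and the positive bound is scaled
# by 127/128, the usual asymmetry of a symmetric signed 8-bit range.
def _ExpectedClipRange(step, start_step=5, end_step=10, start_cap=6.0,
                       end_cap=1.0):
  frac = min(max(step - start_step, 0), end_step - start_step) / float(
      end_step - start_step)
  cap = start_cap + (end_cap - start_cap) * frac
  return -cap, cap * 127.0 / 128.0

# _ExpectedClipRange(5) == (-6.0, 5.953125) and _ExpectedClipRange(7) ==
# (-4.0, 3.96875), matching the step-5 and step-7 assertions above.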
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      accumulator = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      accumulator_name = (
          tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
      tpu_embedding_table.CreateVariable(
          accumulator_name, accumulator, trainable=False)
      accumulator_var = tpu_embedding_table.vars[accumulator_name]

      linear = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_linear_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
      tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
      linear_var = tpu_embedding_table.vars[linear_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, accumulator_name)
        _RemovePrivateVar(tpu_embedding_table, linear_name)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  linears=linear_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator),
              tf.assign(linear_var, retrieved_linear))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
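# A minimal sketch of `_RemovePrivateVar`, which the Ftrl and Adam variants
# call but which is not defined in this excerpt. Judging from the inline
# version in the Adagrad optimizer above, it presumably deletes the slot
# variable from the table's private var/theta maps so that the tf.cast in
# tpu_embedding_table.theta does not copy it to TPU.
def _RemovePrivateVar(layer, var_name):
  # pylint: disable=protected-access
  del layer._private_vars[var_name]
  del layer._private_theta[var_name]
  # pylint: enable=protected-access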
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params
  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      m_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
      tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
      m_var = tpu_embedding_table.vars[var_name_m]

      v_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
      tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
      v_var = tpu_embedding_table.vars[var_name_v]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, var_name_m)
        _RemovePrivateVar(tpu_embedding_table, var_name_v)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_var,
                  momenta=m_var,
                  velocities=v_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_m, retrieved_v = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adam_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(m_var, retrieved_m),
              tf.assign(v_var, retrieved_v))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
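# A usage sketch under the usual TPU-embedding workflow assumption (the
# surrounding trainer code is not shown in this excerpt): run the load ops
# once after variable initialization to push the host-side tables and slots
# to the TPU, and run the retrieve ops before checkpointing to pull the
# trained values back into the TF variables.
#
#   load_ops, retrieve_ops = optimizer.CreateSlotVariablesAndOps(
#       table_vars, tpu_embedding_table)
#   sess.run(load_ops)        # After tf.global_variables_initializer().
#   ...                       # Training steps.
#   sess.run(retrieve_ops)    # Before saver.save(...).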