def testXavier3D(self):
  with self.session(use_gpu=False, graph=tf.Graph()):
    tf.set_random_seed(1618)
    methods = [py_utils.WeightInit.Xavier]
    dtypes = [tf.float32, tf.float16, tf.complex64]
    shapes = [[1, 1, 2]]
    all_vars = []
    for i, (m, dt, sp) in enumerate(
        itertools.product(methods, dtypes, shapes)):
      pc = py_utils.WeightParams(sp, m(), dt)
      all_vars.append(py_utils.CreateVariable('var_%d' % i, pc)[0])

    v1_v_expted = [[[1.357139, -1.23832]]]

    tf.global_variables_initializer().run()
    v1_v = all_vars[0].eval()
    self.assertAllClose(v1_v_expted, v1_v.tolist())

def _CreateVariables(self):
  super()._CreateVariables()
  p = self.params
  collections = [
      self.__class__.__name__ + '_vars', py_utils.SKIP_LP_REGULARIZATION
  ]
  pc = py_utils.WeightParams(
      shape=[1, 1, 1, p.dim],
      init=py_utils.WeightInit.Constant(0.0),
      dtype=p.dtype,
      collections=collections)
  self.CreateVariable('beta', pc)
  # Note: the real gamma to use is 1 + gamma.
  self.CreateVariable('gamma', pc, lambda x: 1.0 + x)

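# Hypothetical sketch (not from the original source) of how the
# `lambda x: 1.0 + x` theta_fn above plays out at call time: `self.vars.gamma`
# holds the raw variable (initialized to 0), while `self.theta.gamma` carries
# the shifted value the forward pass actually consumes, so an untrained layer
# scales its input by exactly 1. Assumes the lingvo-style `base_layer` and
# `py_utils` modules used throughout this file.
class GammaShiftExampleLayer(base_layer.BaseLayer):

  def _CreateVariables(self):
    super()._CreateVariables()
    pc = py_utils.WeightParams(
        shape=[8],
        init=py_utils.WeightInit.Constant(0.0),
        dtype=tf.float32)
    # theta.gamma == 1.0 + vars.gamma.
    self.CreateVariable('gamma', pc, lambda x: 1.0 + x)

  def FProp(self, theta, inputs):
    # Uses the transformed value, not the raw variable.
    return inputs * theta.gamma
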
def _CreateLayerVariables(self):
  super()._CreateLayerVariables()
  p = self.params
  assert p.input_rank == 3 or p.input_rank == 4
  collections = [
      self.__class__.__name__ + '_vars', py_utils.SKIP_LP_REGULARIZATION
  ]
  shape = [1, 1, 1, p.dim] if p.input_rank == 4 else [1, 1, p.dim]
  pc = py_utils.WeightParams(
      shape=shape,
      init=py_utils.WeightInit.Constant(0.0),
      dtype=p.dtype,
      collections=collections)
  self.CreateVariable('beta', pc)
  self.CreateVariable('gamma', pc)

def __init__(self, params):
  super(DeterministicWeightsLayer, self).__init__(params)
  p = self.params
  if not p.name:
    raise ValueError('Layer must have a specified name!')
  assert p.num_sources > 0, ('Must specify num_sources > 0.')
  params_init = py_utils.WeightInit.Constant(0.0)
  # Weights to be learned.
  pw = py_utils.WeightParams(
      shape=[p.num_sources],
      init=params_init,
      dtype=p.dtype,
      collections=[self.__class__.__name__ + '_vars'])
  with tf.variable_scope(p.name):
    self.CreateVariable('sum_weight', pw)
  p.dropout_tpl.name = 'dropout'
  self.CreateChild('weighted_merger_dropout', p.dropout_tpl)

def __init__(self, params):
  super(SoftCondLayer, self).__init__(params)
  p = self.params
  assert p.name
  assert p.num_experts
  assert p.input_dim
  with tf.variable_scope(p.name):
    # Create Variables for task weight mapping.
    w_p = py_utils.WeightParams(
        shape=[p.input_dim, p.num_experts],
        init=p.params_init,  # TODO(huangyp): try zero init instead.
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])
    self.CreateVariable('w', w_p)
    # Prepends p.num_experts to the tensor shape of every variable created
    # by p.body.
    with py_utils.VariableShapePrefixContext(p.num_experts):
      self.CreateChild('body', p.body)

def testOpportunisticReuse(self):
  pc = py_utils.WeightParams([3, 3])
  _, v1 = py_utils.CreateVariable('v1', pc)
  with self.assertRaises(Exception):
    _ = py_utils.CreateVariable('v1', pc)
  with py_utils.OpportunisticVariableReuseScope(True):
    _, v2 = py_utils.CreateVariable('v1', pc)
    _, x1 = py_utils.CreateVariable('x1', pc)
    with py_utils.OpportunisticVariableReuseScope(False):
      with self.assertRaises(Exception):
        _ = py_utils.CreateVariable('v1', pc)
    _, v3 = py_utils.CreateVariable('v1', pc)
  with self.assertRaises(Exception):
    _ = py_utils.CreateVariable('v1', pc)
  for v in [v2, v3]:
    self.assertTrue(v1 is v)
  self.assertTrue(v1 is not x1)

def CreateVariable(self, name, var_params, theta_fn=None, *args, **kwargs):
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  `theta_fn` is used to apply a simple transformation on the created
  variable's value before it is used by the forward computation. E.g., to
  add the global variational noise according to this layer's parameter,
  one can do::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            name='weight',
            var_params=py_utils.WeightParams(shape=[100, 100]),
            theta_fn=self.AddGlobalVN)

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable.
    theta_fn: A python function that takes a variable's value and returns a
      new value to be used later for computation. Its signature must be
      (tf.Tensor) -> (tf.Tensor).
    *args: List of args passed to `.py_utils.CreateVariable`.
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  self._CheckName(name)
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    var_params = py_utils.WeightParams(
        shape=var_params.shape,
        dtype=var_params.dtype,
        init=var_params.init,
        collections=(var_params.collections +
                     [py_utils.SKIP_LP_REGULARIZATION]))
  self._var_symbolic_shape_map[name] = var_params.shape
  value, var = py_utils.CreateVariable(name, var_params, *args, **kwargs)
  self._private_vars[name] = var
  if theta_fn is not None:
    value = theta_fn(value)
  self._private_theta[name] = value

def __init__(self, params):
  super(MergerLayer, self).__init__(params)
  p = self.params
  if not p.name:
    raise ValueError('Layer must have a specified name!')
  if p.merger_op not in set(self.MERGER_OPS):
    raise ValueError('Merger op must be one of: %s' % self.MERGER_OPS)

  if p.merger_op == 'atten':
    atten_params = p.attention_tpl.Copy()
    atten_params.source_dim = p.source_dim
    atten_params.query_dim = p.query_dim
    atten_params.hidden_dim = p.hidden_dim
    atten_params.dtype = p.dtype
    if atten_params.params_init is None:
      atten_params.params_init = py_utils.WeightInit.Gaussian(
          1. / math.sqrt(atten_params.source_dim + atten_params.query_dim))
    self.CreateChild('atten', atten_params)

  if p.pre_proj_input_dims:
    if not p.pre_proj_output_dim:
      raise ValueError('Output dim should be specified for projection.')
    pre_proj_params = []
    for i, pre_proj_dim in enumerate(p.pre_proj_input_dims):
      proj_p = p.proj_tpl.Copy()
      proj_p.name = 'merger_pre_proj_%d' % i
      proj_p.input_dim = pre_proj_dim
      proj_p.output_dim = p.pre_proj_output_dim
      pre_proj_params.append(proj_p)
    self.CreateChildren('pre_proj', pre_proj_params)

  if p.merger_op == 'weighted_sum':
    assert p.num_sources > 0, ('For merger_op=weighted_sum, must specify '
                               'num_sources > 0.')
    params_init = py_utils.WeightInit.Constant(1.0 / p.num_sources)
    # Weights to be learned.
    pw = py_utils.WeightParams(
        shape=[p.num_sources],
        init=params_init,
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])
    with tf.variable_scope(p.name):
      _, self._sum_weight = py_utils.CreateVariable('sum_weight', pw)

def testScaleGradients(self):
  p = self.TestParams()
  p.input = base_input_generator.BaseSequenceInputGenerator.Params()
  task = p.Instantiate()
  task.CreateVariable(
      'a',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Constant(0)))
  var_a = task.theta.a
  var_grads = py_utils.NestedMap(a=(var_a, tf.ones_like(var_a)))
  scaled_grads_map = task.learners[0].ScaleGradients(var_grads)

  FLAGS.enable_check_numerics = False
  with self.session():
    tf.global_variables_initializer().run()
    self.assertEqual(1.0, scaled_grads_map.grad_scale.eval())
    # The final gradient must be finite.
    self.assertFalse(tf.is_nan(scaled_grads_map.final_var_grads.a[1]).eval())
    self.assertTrue(
        tf.is_finite(scaled_grads_map.final_var_grads.a[1]).eval())

def testScaleGradientsCheckNumerics(self):
  """ScaleGradients when enable_check_numerics=True."""
  FLAGS.enable_check_numerics = True
  p = self.TestParams()
  p.input = base_input_generator.BaseSequenceInputGenerator.Params()
  task = p.Instantiate()
  task.CreateVariable(
      'a',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Constant(0)))
  var_a = task.theta.a
  # Make a NaN gradient.
  var_grads = py_utils.NestedMap(a=py_utils.VarGrad(var_a, 0. * tf.log(0.)))
  scaled_grads_map = task.learners[0].ScaleGradients(var_grads)

  with self.session():
    tf.global_variables_initializer().run()
    self.assertEqual(0., scaled_grads_map.grad_scale.eval())
    # Fetching the gradient raises an exception with enable_check_numerics.
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                'is not finite'):
      _ = scaled_grads_map.final_var_grads.a[1].eval()

def testScaleGradientsInf(self):
  FLAGS.enable_check_numerics = False
  p = self.TestParams()
  p.input = base_input_generator.BaseSequenceInputGenerator.Params()
  task = p.cls(p)
  task.CreateVariable(
      'a',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Constant(0)))
  var_a = task.theta.a
  # Infinite gradient.
  var_grads = py_utils.NestedMap(a=(var_a, tf.log(0.)))
  has_nan_or_inf, grad_scale, final_var_grads = task.ScaleGradients(var_grads)

  with self.session():
    tf.global_variables_initializer().run()
    self.assertTrue(has_nan_or_inf.eval())
    self.assertEqual(0., grad_scale.eval())
    # The final gradient must be finite.
    self.assertFalse(tf.is_nan(final_var_grads.a[1]).eval())
    self.assertTrue(tf.is_finite(final_var_grads.a[1]).eval())

def _CreateLayerVariables(self):
  p = self.params
  w_pc = py_utils.WeightParams(
      shape=[self._ids_per_shard, p.embedding_dim],
      init=p.params_init,
      dtype=p.dtype,
      collections=[self.__class__.__name__ + '_vars'])

  embedding_table_vars = []
  for i in range(p.num_tpu_hosts):
    device_name = self.GetDeviceName(i)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      var_name = self.GetVariableName(i)
      self.CreateVariable(var_name, w_pc)
      embedding_var = self.vars[var_name]
      embedding_table_vars.append(embedding_var)
      # Remove from _private_vars / _private_theta to be added later as wm.
      del self._private_vars[var_name]
      del self._private_theta[var_name]

  self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                   embedding_table_vars)

  if not py_utils.use_tpu():
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [
        tf.identity(v) for v in embedding_table_vars
    ]

  # Only trainer and controller need slot variables and load/retrieve ops.
  if not self.do_eval:
    self._load_op_list, self._retrieve_op_list = (
        self.optimizer.CreateSlotVariablesAndOps(embedding_table_vars, self))

def testScaleGradientsNaN(self):
  FLAGS.enable_check_numerics = False
  p = self.TestParams()
  p.input = base_input_generator.BaseSequenceInputGenerator.Params()
  task = p.Instantiate()
  task.CreateVariable(
      'a',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Constant(0)))
  var_a = task.theta.a
  # Make a NaN gradient.
  var_grads = py_utils.NestedMap(
      a=py_utils.VarGrad(var_a, 0. * tf.math.log(0.)))
  scaled_grads_map = task.learners[0].ScaleGradients(var_grads)

  with self.session():
    self.evaluate(tf.global_variables_initializer())
    self.assertEqual(0., scaled_grads_map.grad_scale.eval())
    # The final gradient must be finite.
    self.assertFalse(
        tf.math.is_nan(scaled_grads_map.final_var_grads.a[1]).eval())
    self.assertTrue(
        tf.math.is_finite(scaled_grads_map.final_var_grads.a[1]).eval())

def testXavier(self):
  with self.session(use_gpu=False, graph=tf.Graph()):
    tf.set_random_seed(1618)
    methods = [py_utils.WeightInit.Xavier]
    dtypes = [tf.float32, tf.float16, tf.complex64]
    shapes = [[2, 3]]
    all_vars = []
    for i, (m, dt, sp) in enumerate(
        itertools.product(methods, dtypes, shapes)):
      pc = py_utils.WeightParams(sp, m(), dt)
      all_vars.append(py_utils.CreateVariable('var_%d' % i, pc)[0])

    v1_v_expted = [[1.051236, -0.959198, 0.796091],
                   [-0.685691, 0.230933, -1.006293]]
    v3_v_expted = [
        [0.149996 - 0.064369j, 0.689145 + 0.017257j, -0.502070 - 0.367683j],
        [0.519782 + 0.470412j, 0.738902 - 0.054006j, 0.028603 + 0.471832j],
    ]

    tf.global_variables_initializer().run()
    v1_v = all_vars[0].eval()
    v3_v = all_vars[2].eval()
    self.assertAllClose(v1_v_expted, v1_v.tolist())
    self.assertAllClose(v3_v_expted, v3_v.tolist())

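# Illustration (an assumption, not taken from the original tests): if
# WeightInit.Xavier follows the standard Glorot uniform scheme, values are
# drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
# For the [2, 3] shape above that gives limit = sqrt(6 / 5) ~= 1.095, which
# brackets every expected value checked in testXavier.
import math


def xavier_uniform_limit(fan_in, fan_out):
  """Returns the symmetric bound of the Glorot uniform distribution."""
  return math.sqrt(6.0 / (fan_in + fan_out))


assert abs(xavier_uniform_limit(2, 3) - 1.0954) < 1e-3
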
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      m_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
      tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
      m_var = tpu_embedding_table.vars[var_name_m]

      v_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
      tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
      v_var = tpu_embedding_table.vars[var_name_v]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, var_name_m)
        _RemovePrivateVar(tpu_embedding_table, var_name_v)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_var,
                  momenta=m_var,
                  velocities=v_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_m, retrieved_v = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adam_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(m_var, retrieved_m),
              tf.assign(v_var, retrieved_v))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list

def __init__(self, params):
  super(MultiTaskModelTest._TestTaskWithVars, self).__init__(params)
  pc = py_utils.WeightParams(shape=[10, 10], dtype=tf.float32)
  self.CreateVariable('weight', pc)

def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      accumulator = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      accumulator_name = (
          tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
      tpu_embedding_table.CreateVariable(
          accumulator_name, accumulator, trainable=False)
      accumulator_var = tpu_embedding_table.vars[accumulator_name]

      linear = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_linear_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
      tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
      linear_var = tpu_embedding_table.vars[linear_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, accumulator_name)
        _RemovePrivateVar(tpu_embedding_table, linear_name)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  linears=linear_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator),
              tf.assign(linear_var, retrieved_linear))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list

def _CreateVariables(self):
  super()._CreateVariables()
  pc = py_utils.WeightParams(shape=[10, 10], dtype=tf.float32)
  self.CreateVariable('weight', pc)

def __init__(self, params):
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super().__init__(params)

  p = self.params

  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._post_train_ops = []
  self._eval_metrics = {}
  self._per_example = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None

  if p.task_global_step:
    with tf.name_scope(None), tf.variable_scope(
        py_utils.GetGlobalVariableScope()):
      var_name = p.name + '_global_step'
      # Create the variable immediately.
      self._CreateVariableInternal(
          var_name,
          base_layer.CreateVariableMeta(
              var_params=py_utils.WeightParams(
                  [], py_utils.WeightInit.Constant(0), tf.int64),
              theta_fn=None,
              kwargs=dict(
                  trainable=False,
                  collections=[tf.GraphKeys.GLOBAL_VARIABLES])))
      summary_utils.scalar(var_name, self._private_vars[var_name])
      self._global_step_var = self._private_vars[var_name]
  else:
    self._global_step_var = py_utils.GetOrCreateGlobalStepVar()

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if self.do_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on
        # num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (
          p.input.num_samples < p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning(
            'input.num_batcher_threads > 1 inside eval mode. '
            'The input generator may not iterate over exactly '
            'one epoch per run')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)

    # For TPU training, we create the input generator in a
    # different scope and AddChild it in later.
    if 'skip_create_child' not in p.input:
      self.CreateChild('input', input_params)

  tp = p.train
  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  if tp:
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()

def _CreateLayerVariables(self):
  super()._CreateLayerVariables()
  self.CreateVariable(
      'x',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Uniform()))

def CreateVariable(self, name, var_params, theta_fn=None, **kwargs):
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  `theta_fn` is used to apply a simple transformation on the created
  variable's value before it is used by the forward computation. E.g., to
  add the global variational noise according to this layer's parameter,
  one can do::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            name='weight',
            var_params=py_utils.WeightParams(shape=[100, 100]),
            theta_fn=self.AddVN)

  In some contexts, e.g. TPU training, variables may not be created
  immediately; instead, the creation request is cached and fulfilled later
  via a call to layer.InstantiateVariables().

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable.
    theta_fn: A python function that takes a variable's value and returns a
      new value to be used later for computation. Its signature must be
      (tf.Tensor) -> (tf.Tensor).
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  if self.params.device_mesh is not None:
    if (len([dim for dim in var_params.shape if dim > 1]) > 1 and
        var_params.tensor_split_dims_mapping is None):
      tf.logging.warning(
          'tensor_split_dims_mapping missing for %s.%s: shape=%s', self.path,
          name, var_params.shape)
  if self._is_variable_free:
    raise ValueError('Cannot create variable in variable free layer.')
  if self._create_variables_status == _CreateLayerVariablesStatus.COMPLETED:
    raise ValueError(
        'CreateVariable call after variable creation has completed! '
        'CreateVariable should be called in __init__ or '
        '_CreateLayerVariables.')
  self._CheckName(name)
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    var_params = py_utils.WeightParams(
        shape=var_params.shape,
        dtype=var_params.dtype,
        init=var_params.init,
        collections=(var_params.collections +
                     [py_utils.SKIP_LP_REGULARIZATION]))
  self._var_symbolic_shape_map[name] = var_params.shape
  meta = CreateVariableMeta(
      var_params=var_params.Copy(), theta_fn=theta_fn, kwargs=kwargs)
  if self._create_variables_status == _CreateLayerVariablesStatus.IN_PROGRESS:
    # If InstantiateVariables has been called, create variable immediately.
    self._CreateVariableInternal(name, meta)
  else:
    # Otherwise cache the variable to be created.
    self._variables_to_create[name] = meta

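# Hypothetical usage sketch (not from the original source) of the pattern the
# docstring above describes: a layer requests its variables from
# _CreateLayerVariables(); depending on context, the request may be fulfilled
# immediately or cached until layer.InstantiateVariables() runs. Assumes the
# lingvo `base_layer` and `py_utils` modules used throughout this file.
class DeferredVarExampleLayer(base_layer.BaseLayer):

  def _CreateLayerVariables(self):
    super()._CreateLayerVariables()
    # Recorded as a CreateVariableMeta; materialized as a tf.Variable either
    # right away or when InstantiateVariables() is invoked on the layer.
    self.CreateVariable(
        'w',
        py_utils.WeightParams(
            shape=[4, 4], init=py_utils.WeightInit.Xavier(),
            dtype=tf.float32))
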
def get_weight_params():
  return py_utils.WeightParams(
      shape=[1, p.batch_size, p.lm.emb.embedding_dim],
      init=py_utils.WeightInit.Constant(
          scale=np.zeros([p.batch_size, p.lm.emb.embedding_dim])),
      dtype=tf.float32,
      collections=[self.__class__.__name__ + '_vars'])

def _CreateVariables(self):
  super()._CreateVariables()
  self.CreateVariable(
      'x',
      py_utils.WeightParams(shape=[], init=py_utils.WeightInit.Constant(0)))

def _CreateLayerVariables(self):
  super()._CreateLayerVariables()
  pc = py_utils.WeightParams(
      shape=[], init=py_utils.WeightInit.Constant(0), dtype=self.params.dtype)
  self.CreateVariable('ext', pc)

def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      w_ada = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name = tpu_embedding_table.GetVariableName(host_id) + '/Adagrad'
      tpu_embedding_table.CreateVariable(var_name, w_ada, trainable=False)
      accumulator_var = tpu_embedding_table.vars[var_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid copying them
        # to TPU (by the tf.cast in tpu_embedding_table.theta).
        # pylint: disable=protected-access
        del tpu_embedding_table._private_vars[var_name]
        del tpu_embedding_table._private_theta[var_name]
        # pylint: enable=protected-access

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list