def pnum_tensor(self):
  if self._pnum_tensor is not None:
    return self._pnum_tensor
  # Lazily build the per-core processor-number constant (one list element
  # per replica) the first time it is requested.
  with utils.outside_all_rewrites():
    tf.logging.info("Create pnum_tensor")
    self._pnum_tensor = tpu_ops.tpu_replicated_input(
        self._physical_to_logical, name="pnum_constants")
    return self._pnum_tensor
def __init__(self, variable, mesh_impl):
  """Create a LaidOutVariable.

  Args:
    variable: a Variable (Operation)
    mesh_impl: a MeshImpl
  """
  self._variable = variable
  self._mesh_impl = mesh_impl
  shape = variable.outputs[0].shape
  dtype = variable.outputs[0].dtype
  slice_shape = mesh_impl.slice_shape(shape)
  base_name = variable.name
  slices = []
  for pnum in xrange(mesh_impl.size):
    slice_var_name = base_name + "_slice_%d" % pnum
    tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
    # The initializer is unimportant, since the slice variables will be
    # overwritten.  zeros_initializer() is here to avoid the default
    # initialization which adds lots of useless operations to the TF graph.
    with ops.device(tpu_device):
      slices.append(
          tf.get_variable(
              slice_var_name,
              slice_shape,
              dtype=dtype,
              collections=[],
              initializer=tf.zeros_initializer()))
  self._laid_out_tensor = mesh_impl.LaidOutTensor(
      [tpu_variables.ReplicatedVariable(base_name, slices)])
  with tf.device(variable.master.device), utils.outside_all_rewrites():
    self._copy_master_to_slices = self.assign_to_slices(
        mesh_impl.make_slices(variable.master, shape),
        assign_to_tensor_list=slices)
    self._copy_slices_to_master = tf.assign(
        variable.master,
        mesh_impl.combine_slices(slices, shape,
                                 device=variable.master.device))
def __init__(self, variable, mesh_impl):
  """Create a LaidOutVariable.

  Args:
    variable: a Variable (Operation)
    mesh_impl: a MeshImpl
  """
  self._variable = variable
  self._mesh_impl = mesh_impl
  shape = variable.outputs[0].shape
  slice_shape = mesh_impl.slice_shape(shape)
  base_name = variable.name
  slices = []
  slices_with_master_dtype = []
  with tf.device(variable.master_device), utils.outside_all_rewrites():
    zero_tensor = tf.zeros(slice_shape, dtype=variable.slice_dtype)

  # pylint: disable=protected-access
  init_device_stack = tf.get_default_graph()._device_function_stack

  if not mesh_impl.graph_device_function_stacks:
    for pnum in xrange(mesh_impl.size):
      tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
      with tf.device(tpu_device):
        mesh_impl.graph_device_function_stacks.append(
            tf.get_default_graph()._device_function_stack.copy())

  for physical_pnum in xrange(mesh_impl.size):
    slice_var_name = base_name + "_slice_%d" % physical_pnum
    # Use tf.Variable instead of tf.get_variable since the latter adds lots of
    # useless operations to the TF graph.
    # Note: repeatedly entering 'with tf.device():' slows down graph
    # construction, so we directly use the cached device_function_stack here.
    tf.get_default_graph()._device_function_stack = (
        mesh_impl.graph_device_function_stacks[physical_pnum])
    slices.append(
        tf.Variable(
            initial_value=zero_tensor,
            trainable=True,
            collections=[],
            dtype=variable.slice_dtype,
            name=slice_var_name,
            expected_shape=slice_shape))
  # Restore the initial device function stack.
  tf.get_default_graph()._device_function_stack = init_device_stack
  # pylint: enable=protected-access

  self._laid_out_tensor = mesh_impl.LaidOutTensor(
      [tpu_variables.ReplicatedVariable(base_name, slices)])

  with tf.device(variable.master_device), utils.outside_all_rewrites():
    if os.environ.get("MTF_SEQUENCE_MODE", "") == "1":
      if mesh_impl.copy_master_to_slice_ops:
        with tf.control_dependencies(
            [mesh_impl.copy_master_to_slice_ops[-1]]):
          self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
              variable.get_master(), shape, slices, slice_shape)
      else:
        self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
            variable.get_master(), shape, slices, slice_shape)
      mesh_impl.copy_master_to_slice_ops.append(self._copy_master_to_slices)
    else:
      self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
          variable.get_master(), shape, slices, slice_shape)
    slices_with_master_dtype = [
        tf.cast(s, variable.master_dtype) for s in slices
    ]
    slices_with_master_dtype = [
        slices_with_master_dtype[mesh_impl.l2p(logical_pnum)]
        for logical_pnum in range(mesh_impl.size)
    ]
    self._copy_slices_to_master = variable.assign_to_master(
        mesh_impl.combine_slices(slices_with_master_dtype, shape,
                                 device=variable.master_device))