Code example #1
0
 def pnum_tensor(self):
     """Return the replicated pnum tensor, building it on first access.

     The tensor maps each replica to its logical processor number via
     `tpu_replicated_input` over `self._physical_to_logical`; the result
     is cached in `self._pnum_tensor` so construction happens only once.
     """
     if self._pnum_tensor is None:
         with utils.outside_all_rewrites():
             tf.logging.info("Create pnum_tensor")
             self._pnum_tensor = tpu_ops.tpu_replicated_input(
                 self._physical_to_logical, name="pnum_constants")
     return self._pnum_tensor
Code example #2
0
File: simd_mesh_impl.py  Project: tspannhw/mesh
    def __init__(self, variable, mesh_impl):
      """Create a LaidOutVariable.

      Creates one slice variable per processor in the mesh, each placed on
      that processor's TPU device, wraps them in a ReplicatedVariable
      exposed as a LaidOutTensor, and builds the ops that copy data
      between the master variable and the slices.

      Args:
        variable: a Variable (Operation)
        mesh_impl: a MeshImpl
      """
      self._variable = variable
      self._mesh_impl = mesh_impl
      # Full (unsliced) shape and dtype come from the variable op's output.
      shape = variable.outputs[0].shape
      dtype = variable.outputs[0].dtype
      # Shape of each per-processor slice of the full variable.
      slice_shape = mesh_impl.slice_shape(shape)
      base_name = variable.name
      slices = []
      # One slice variable per processor, pinned to that replica's TPU device.
      for pnum in xrange(mesh_impl.size):
        slice_var_name = base_name + "_slice_%d" % pnum
        tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
        # The initializer is unimportant, since the slice variables will be
        # overwritten.  zeros_initializer() is here to avoid the default
        # initialization which adds lots of useless operations to the TF graph.
        with ops.device(tpu_device):
          slices.append(
              tf.get_variable(
                  slice_var_name,
                  slice_shape,
                  dtype=dtype,
                  # Empty collections: keep the slices out of the default
                  # variable collections.
                  collections=[],
                  initializer=tf.zeros_initializer()))
      # Expose the slices as a single laid-out tensor backed by a
      # ReplicatedVariable.
      self._laid_out_tensor = mesh_impl.LaidOutTensor(
          [tpu_variables.ReplicatedVariable(base_name, slices)])
      # Master <-> slices transfer ops are built on the master variable's
      # device, outside any TPU rewrites.
      with tf.device(variable.master.device), utils.outside_all_rewrites():
        self._copy_master_to_slices = self.assign_to_slices(
            mesh_impl.make_slices(variable.master, shape),
            assign_to_tensor_list=slices)
        self._copy_slices_to_master = tf.assign(
            variable.master,
            mesh_impl.combine_slices(slices, shape,
                                     device=variable.master.device))
Code example #3
0
        def __init__(self, variable, mesh_impl):
            """Create a LaidOutVariable.

            Creates one slice variable per processor (using cached device
            stacks to speed up graph construction), wraps them in a
            ReplicatedVariable exposed as a LaidOutTensor, and builds the
            master <-> slices copy ops.

            Args:
              variable: a Variable (Operation)
              mesh_impl: a MeshImpl
            """
            self._variable = variable
            self._mesh_impl = mesh_impl
            # Full (unsliced) shape from the variable op's output.
            shape = variable.outputs[0].shape
            # Shape of each per-processor slice of the full variable.
            slice_shape = mesh_impl.slice_shape(shape)
            base_name = variable.name
            slices = []
            slices_with_master_dtype = []
            # A single zeros tensor (in the slice dtype) reused as the
            # initial value of every slice variable; built on the master
            # device, outside any TPU rewrites.
            with tf.device(
                    variable.master_device), utils.outside_all_rewrites():
                zero_tensor = tf.zeros(slice_shape, dtype=variable.slice_dtype)

            # pylint: disable=protected-access
            # Remember the graph's current device stack so it can be restored
            # after the per-slice loop below overwrites it.
            init_device_stack = tf.get_default_graph()._device_function_stack

            # Populate the per-processor device-stack cache once; later
            # LaidOutVariables reuse it instead of re-entering tf.device().
            if not mesh_impl.graph_device_function_stacks:
                for pnum in xrange(mesh_impl.size):
                    tpu_device = mesh_impl.device_assignment.tpu_device(
                        replica=pnum)
                    with tf.device(tpu_device):
                        mesh_impl.graph_device_function_stacks.append(
                            tf.get_default_graph()._device_function_stack.copy(
                            ))

            for physical_pnum in xrange(mesh_impl.size):
                slice_var_name = base_name + "_slice_%d" % physical_pnum
                # Use tf.Variable instead of tf.get_variable since latter adds lots of
                # useless operations to the TF graph.
                # Note: Repeatedly 'with tf.device():' slows down the graph
                # construction. Therefore we directly use the cached device_stack here.
                tf.get_default_graph()._device_function_stack = (
                    mesh_impl.graph_device_function_stacks[physical_pnum])

                slices.append(
                    tf.Variable(initial_value=zero_tensor,
                                trainable=True,
                                # Empty collections: keep the slices out of
                                # the default variable collections.
                                collections=[],
                                dtype=variable.slice_dtype,
                                name=slice_var_name,
                                expected_shape=slice_shape))

            # Restore the initial stack
            tf.get_default_graph()._device_function_stack = init_device_stack
            # pylint: enable=protected-access

            # Expose the slices as a single laid-out tensor backed by a
            # ReplicatedVariable.
            self._laid_out_tensor = mesh_impl.LaidOutTensor(
                [tpu_variables.ReplicatedVariable(base_name, slices)])
            with tf.device(
                    variable.master_device), utils.outside_all_rewrites():
                # When MTF_SEQUENCE_MODE=1, chain each master->slices copy
                # after the previous variable's copy via a control dependency,
                # so the copies run one at a time.
                # NOTE(review): presumably this serialization limits peak
                # memory during initialization -- TODO confirm.
                if os.environ.get("MTF_SEQUENCE_MODE", "") == "1":
                    if mesh_impl.copy_master_to_slice_ops:
                        with tf.control_dependencies(
                            [mesh_impl.copy_master_to_slice_ops[-1]]):
                            self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
                                variable.get_master(), shape, slices,
                                slice_shape)
                    else:
                        self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
                            variable.get_master(), shape, slices, slice_shape)

                    mesh_impl.copy_master_to_slice_ops.append(
                        self._copy_master_to_slices)
                else:
                    self._copy_master_to_slices = self._gen_copy_master_to_slices_op(
                        variable.get_master(), shape, slices, slice_shape)
                # Cast slices to the master dtype, then reorder them so index
                # logical_pnum holds slice l2p(logical_pnum) -- i.e. convert
                # from physical to logical processor order before combining.
                slices_with_master_dtype = [
                    tf.cast(s, variable.master_dtype) for s in slices
                ]
                slices_with_master_dtype = [
                    slices_with_master_dtype[mesh_impl.l2p(logical_pnum)]
                    for logical_pnum in range(mesh_impl.size)
                ]
                self._copy_slices_to_master = variable.assign_to_master(
                    mesh_impl.combine_slices(slices_with_master_dtype,
                                             shape,
                                             device=variable.master_device))