def _calculate_mean_and_var(self, x, axes, keep_dims):

        with backend.name_scope('moments'):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(
                x, dtypes.float32) if x.dtype == dtypes.float16 else x
            replica_ctx = ds.get_replica_context()
            if replica_ctx:
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y),
                                                        axis=axes,
                                                        keepdims=True)
                batch_size = math_ops.cast(
                    array_ops.shape_v2(y)[axes[0]], dtypes.float32)
                # TODO(b/163099951): batch the all-reduces once we sort out the ordering
                # issue for NCCL. We don't have a mechanism to launch NCCL in the same
                # order in each replica nowadays, so we limit NCCL to batch all-reduces.
                y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_sum)
                y_squared_sum = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, local_squared_sum)
                global_batch_size = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, batch_size)

                axes_vals = [(array_ops.shape_v2(y))[axes[i]]
                             for i in range(1, len(axes))]
                multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                           dtypes.float32)
                multiplier = multiplier * global_batch_size

                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
            else:
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = math_ops.reduce_mean(y,
                                            axes,
                                            keepdims=True,
                                            name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = math_ops.reduce_mean(math_ops.squared_difference(
                    y, array_ops.stop_gradient(mean)),
                                                axes,
                                                keepdims=True,
                                                name='variance')
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
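
A minimal NumPy sketch (not from the original source) of the identity the distributed branch relies on: summing values and squared values across replicas and dividing by the global element count reproduces the global moments via var = E[x^2] - E[x]^2. The replica shapes below are made up.

import numpy as np

# Two hypothetical replicas holding NHWC data; reduce over batch and spatial dims.
replicas = [np.random.rand(4, 8, 8, 3), np.random.rand(4, 8, 8, 3)]
axes = (0, 1, 2)

y_sum = sum(r.sum(axis=axes) for r in replicas)
y_squared_sum = sum((r ** 2).sum(axis=axes) for r in replicas)
# multiplier = global batch size * product of the remaining reduced dims
multiplier = sum(r.shape[0] for r in replicas) * 8 * 8

mean = y_sum / multiplier
variance = y_squared_sum / multiplier - mean ** 2  # var = E[x^2] - E[x]^2

full = np.concatenate(replicas, axis=0)
assert np.allclose(mean, full.mean(axis=axes))
assert np.allclose(variance, full.var(axis=axes))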
Example 2
    def _calculate_mean_and_var(self, x, axes, keep_dims):

        with ops.name_scope('moments', values=[x, axes]):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x

            if horovod_enabled():
                num_shards = hvd.size()
            else:
                num_shards = 1

            if num_shards > 1:
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, keepdims=True)
                batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)

                y_sum = hvd.allreduce(local_sum, average=False)
                y_squared_sum = hvd.allreduce(local_squared_sum, average=False)

                global_batch_size = batch_size * num_shards
                axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
                multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), dtypes.float32)
                multiplier = multiplier * global_batch_size

                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
            else:
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = math_ops.reduce_mean(
                    math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
                    axes,
                    keepdims=True,
                    name='variance')
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
Example 3
    def call(self, inputs):
        if self._channels_first:
            rank = inputs.shape.rank
            if rank and rank > 1:
                # Switch to channels-last format.
                permutation = [0]
                permutation.extend(range(2, rank))
                permutation.append(1)
                inputs = array_ops.transpose(inputs, perm=permutation)

        if context.executing_eagerly():
            # Full static shape is guaranteed to be available.
            # Performance: Using `constant_op` is much faster than passing a list.
            flattened_shape = constant_op.constant([inputs.shape[0], -1])
            return gen_array_ops.reshape(inputs, flattened_shape)
        else:
            input_shape = inputs.shape
            rank = input_shape.rank
            if rank == 1:
                return array_ops.expand_dims_v2(inputs, axis=1)
            else:
                batch_dim = tensor_shape.dimension_value(input_shape[0])
                non_batch_dims = input_shape[1:]
                # Reshape in a way that preserves as much shape info as possible.
                if non_batch_dims.is_fully_defined():
                    last_dim = int(
                        functools.reduce(operator.mul, non_batch_dims))
                    flattened_shape = constant_op.constant([-1, last_dim])
                elif batch_dim is not None:
                    flattened_shape = constant_op.constant(
                        [int(batch_dim), -1])
                else:
                    flattened_shape = [array_ops.shape_v2(inputs)[0], -1]
                return array_ops.reshape(inputs, flattened_shape)
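
A hedged usage sketch with the public API rather than the internal snippet above: the same logic backs tf.keras.layers.Flatten, so a fully defined non-batch shape yields a static flattened dimension even when the batch size is unknown.

import tensorflow as tf

flatten = tf.keras.layers.Flatten()

@tf.function(input_signature=[tf.TensorSpec([None, 4, 5], tf.float32)])
def f(x):
    return flatten(x)

# Non-batch dims are fully defined, so the static output shape is (None, 20).
print(f.get_concrete_function().output_shapes)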
Example 4
    def call(self, inputs):
        bins = [math_ops.cast(array_ops.squeeze(self.bins), dtypes.float32)]

        def _bucketize_fn(inputs):
            return gen_boosted_trees_ops.BoostedTreesBucketize(
                float_values=[math_ops.cast(inputs, dtypes.float32)],
                bucket_boundaries=bins)[0]

        if tf_utils.is_ragged(inputs):
            integer_buckets = ragged_functional_ops.map_flat_values(
                _bucketize_fn, inputs)
            # Ragged map_flat_values doesn't touch the non-values tensors in the
            # ragged composite tensor. If this op is the only op a Keras model,
            # this can cause errors in Graph mode, so wrap the tensor in an identity.
            return array_ops.identity(integer_buckets)
        elif isinstance(inputs, sparse_tensor.SparseTensor):
            return sparse_tensor.SparseTensor(
                indices=array_ops.identity(inputs.indices),
                values=_bucketize_fn(inputs.values),
                dense_shape=array_ops.identity(inputs.dense_shape))
        else:
            static_shape = inputs.get_shape()
            if any(dim is None for dim in static_shape.as_list()[1:]):
                raise NotImplementedError(
                    "Discretization Layer requires known non-batch shape, "
                    "found {}".format(static_shape))

            dynamic_shape = array_ops.shape_v2(inputs)
            # BoostedTreesBucketize only handles rank 1 inputs, so flatten
            # everything after the batch dimension and vectorized_map over
            # each sample.
            reshaped = array_ops.reshape(inputs, [dynamic_shape[0], -1])
            return array_ops.reshape(
                control_flow_ops.vectorized_map(_bucketize_fn, reshaped),
                dynamic_shape)
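
For orientation, a hedged sketch using the public Discretization layer (available as tf.keras.layers.Discretization in recent TF releases), which wraps the same bucketize kernel; the boundaries are made up.

import tensorflow as tf

# Buckets for boundaries [0., 1., 2.] are (-inf, 0), [0, 1), [1, 2), [2, inf);
# boundary values fall into the upper bucket.
layer = tf.keras.layers.Discretization(bin_boundaries=[0.0, 1.0, 2.0])
print(layer(tf.constant([[-1.5, 0.5], [1.0, 3.0]])))  # [[0 1], [2 3]]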
Example 5
def _pad_util(input_tensor, full_axis_dim):
    """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
    missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
    tensor_rank = array_ops.rank(input_tensor)
    paddings_axis = [[0, missing_axis_dim]]
    paddings = array_ops.concat([
        paddings_axis,
        array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
    ],
                                axis=0)
    padded_input_tensor = array_ops.pad(input_tensor, paddings)
    return padded_input_tensor
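
A quick hedged illustration of _pad_util with made-up numbers, using the public tf.pad it ultimately reduces to:

import tensorflow as tf

x = tf.ones([2, 3])
full_axis_dim = 5
# paddings = [[0, 3], [0, 0]]: append three zero rows, leave other dims alone.
padded = tf.pad(x, [[0, full_axis_dim - x.shape[0]], [0, 0]])
print(padded.shape)  # (5, 3)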
Example 6
 def test_vocab_list_sparse_input(self):
   layer = categorical.CategoryLookup(
       vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
   inp = np.asarray([['omar', ''], ['stringer', 'marlo'], ['marlo', 'omar']])
   indices = array_ops.where_v2(math_ops.not_equal(inp, ''))
   sp_inp = sparse_tensor.SparseTensor(
       indices,
       array_ops.gather_nd_v2(inp, indices),
       dense_shape=array_ops.shape_v2(inp, out_type=dtypes.int64))
   output = layer(sp_inp)
   self.assertIsInstance(output, sparse_tensor.SparseTensor)
   self.assertAllClose(
       np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), output.indices)
   self.assertAllClose(np.asarray([0, 1, 2, 2, 0]), output.values)
Example 7
 def test_vocab_list_sparse_input(self):
   vocabulary_list = ['A', 'B', 'C', 'D', 'E']
   layer = categorical.CategoryLookup(
       vocabulary=vocabulary_list, num_oov_tokens=0)
   inp = np.asarray([['A', ''], ['E', 'C'], ['D', 'A']])
   indices = array_ops.where_v2(math_ops.not_equal(inp, ''))
   sp_inp = sparse_tensor.SparseTensor(
       indices,
       array_ops.gather_nd_v2(inp, indices),
       dense_shape=array_ops.shape_v2(inp, out_type=dtypes.int64))
   output = layer(sp_inp)
   self.assertIsInstance(output, sparse_tensor.SparseTensor)
   self.assertAllClose(
       np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), output.indices)
   self.assertAllClose(np.asarray([0, 4, 2, 3, 0]), output.values)
Example 8
        def slice_batch(index, batch):
            flattened_batch = nest.flatten(batch)
            flattened_output = []

            norm_index = math_ops.cast(index %
                                       self._num_local_devices_per_replica,
                                       dtype=dtypes.int32)
            norm_index += self._partition_offset
            coords = self._mesh.coords(norm_index)
            coords = array_ops.reshape(coords, (1, -1))

            for element, shard_counts, idx_matrix in zip(
                    flattened_batch, self._all_shard_counts,
                    self._index_matrices):
                indexes = math_ops.matmul(coords, idx_matrix)
                start = array_ops.reshape(indexes, (-1, ))
                size = array_ops.shape_v2(
                    element, out_type=dtypes.int32) // shard_counts
                flattened_output.append(
                    array_ops.slice(element, begin=start, size=size))

            return nest.pack_sequence_as(batch, flattened_output)
Example 9
    def _subdiv_calculate_mean_and_var(self, inputs, reduction_axes,
                                       keep_dims):
        # Calculate the local sum and sum of squares over the reduction axes.
        net_sum = math_ops.reduce_sum(inputs,
                                      axis=reduction_axes,
                                      keepdims=keep_dims)
        squared_mean = math_ops.reduce_sum(math_ops.square(inputs),
                                           axis=reduction_axes,
                                           keepdims=keep_dims)

        if self._support_zero_size_input():
            # Keras assumes that batch dimension is the first dimension for Batch
            # Normalization.
            input_batch_size = array_ops.shape(inputs)[0]
        else:
            input_batch_size = None

        # Get the number of elements reduced per sample over the non-batch
        # reduction axes (local); the batch size is divided out separately below.
        axes_vals = [(array_ops.shape_v2(inputs))[i]
                     for i in range(1, len(reduction_axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)

        squared_mean = squared_mean / multiplier
        net_sum = net_sum / multiplier

        if input_batch_size is None:
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)
        else:
            batches_ = math_ops.cast(input_batch_size, self._param_dtype)
            mean = net_sum / batches_
            variance = squared_mean / batches_ - math_ops.square(
                array_ops.stop_gradient(mean))

        return mean, net_sum, variance, squared_mean, input_batch_size
Example 10
 def model(in_tensor):
     shape = array_ops.shape_v2(in_tensor)
     fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
     return math_ops.matmul(fill, in_tensor)
Example 11
    def all_gather(self,
                   input_tensor,
                   axis,
                   communication_hint='AUTO',
                   timeout=0):
        """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
        if context.executing_eagerly():
            raise RuntimeError('all_gather in eager mode is not supported')

        with ops.device(self._device), \
             ops.control_dependencies([array_ops.identity(input_tensor)]):
            # 1. Transpose
            # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
            # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
            # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
            # place it back.
            perm_pre = array_ops.concat(
                ([axis], math_ops.range(axis),
                 math_ops.range(axis + 1, array_ops.rank(input_tensor))),
                axis=0)
            input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
            # 2. Pad
            gathered_shape = self._all_gather(array_ops.expand_dims_v2(
                array_ops.shape_v2(input_tensor_t), axis=0),
                                              communication_hint,
                                              timeout=timeout)
            first_dims = gathered_shape[:, 0]
            full_axis_dim = math_ops.reduce_max(first_dims)
            padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

            # 3. Gather
            gather_padded_out_tensor = self._all_gather(padded_input_tensor,
                                                        communication_hint,
                                                        timeout=timeout)
            # 4. Unpad
            split_tensors = []
            for i in range(self._group_size):
                start_pos = i * full_axis_dim
                split_tensors.append(
                    gather_padded_out_tensor[start_pos:start_pos +
                                             first_dims[i]])
            out_tensor_t = array_ops.concat(split_tensors, 0)

            # 5. Transpose back
            perm_after = array_ops.concat(
                (math_ops.range(1, axis + 1), [0],
                 math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
                axis=0)
            return array_ops.transpose(out_tensor_t, perm=perm_after)
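
A standalone sketch (public TF ops, shapes taken from the docstring comment) of the transpose bookkeeping in steps 1 and 5:

import tensorflow as tf

x = tf.zeros([2, 2, 5, 1])
axis = 3
perm_pre = tf.concat(([axis], tf.range(axis),
                      tf.range(axis + 1, tf.rank(x))), axis=0)
xt = tf.transpose(x, perm=perm_pre)  # gather axis moved first: [1, 2, 2, 5]
perm_after = tf.concat((tf.range(1, axis + 1), [0],
                        tf.range(axis + 1, tf.rank(xt))), axis=0)
print(perm_pre.numpy(), perm_after.numpy())     # [3 0 1 2] [1 2 3 0]
print(tf.transpose(xt, perm=perm_after).shape)  # (2, 2, 5, 1)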
Example 12
    def call(self, y_true, y_pred):
        """Invokes the `Loss` instance.

        Args:
            y_true: Ground truth values.
            y_pred: The predicted values.

        Returns:
            Loss values in the form of a Tensor
        """
        gamma = self.gamma
        from_logits = self.from_logits
        axis = -1

        y_true = tf.cast(y_true, y_pred.dtype)
        y_true = ops.convert_to_tensor_v2(y_true)
        y_pred = ops.convert_to_tensor_v2(y_pred)

        probs = y_pred

        # Reformat y_pred shapes
        if (not from_logits and
                not isinstance(y_pred,
                               (ops.EagerTensor, variables_module.Variable))
                and y_pred.op.type == 'Softmax') and not hasattr(
                    y_pred, '_keras_history'):
            assert len(y_pred.op.inputs) == 1
            y_pred = y_pred.op.inputs[0]
            from_logits = True

        # Clip y_pred to a minimum and maximum value
        if not from_logits:
            epsilon_ = constant_op.constant(K.epsilon(),
                                            y_pred.dtype.base_dtype)
            y_pred = clip_ops.clip_by_value(y_pred, epsilon_, 1 - epsilon_)
            y_pred = math_ops.log(y_pred)

        # Get dimensions of predictions tensor
        if isinstance(y_pred.shape, (tuple, list)):
            output_rank = len(y_pred.shape)
        else:
            output_rank = y_pred.shape.ndims
        if output_rank is not None:
            axis %= output_rank
            if axis != output_rank - 1:
                permutation = list(
                    itertools.chain(range(axis), range(axis + 1, output_rank),
                                    [axis]))
                y_pred = array_ops.transpose(y_pred, perm=permutation)
        elif axis != -1:
            raise ValueError(
                'Cannot compute sparse categorical crossentropy with `axis={}` on an '
                'output tensor with unknown rank'.format(axis))

        # Reformat y_true shape and data type.
        y_true = cast(y_true, 'int64')

        output_shape = array_ops.shape_v2(y_pred)
        target_rank = y_true.shape.ndims

        update_shape = (target_rank is not None and output_rank is not None
                        and target_rank != output_rank - 1)
        if update_shape:
            y_true = flatten(y_true)
            y_pred = array_ops.reshape(y_pred, [-1, output_shape[-1]])

        # Calculate cross-entropy loss
        if py_any(_is_symbolic_tensor(v) for v in [y_true, y_pred]):
            with get_graph().as_default():
                loss = nn.sparse_softmax_cross_entropy_with_logits_v2(
                    labels=y_true, logits=y_pred)
        else:
            loss = nn.sparse_softmax_cross_entropy_with_logits_v2(
                labels=y_true, logits=y_pred)

        if update_shape and output_rank >= 3:
            loss = array_ops.reshape(loss, output_shape[:-1])

        # Calculate focal modulation to be applied
        gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
        scalar_gamma = gamma.shape.rank == 0

        y_true_rank = y_true.shape.rank
        if not scalar_gamma:
            gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank)

        focal_modulation = K.pow(1 - tf.math.reduce_mean(probs, axis=1), gamma)
        focal_modulation = tf.gather(focal_modulation,
                                     y_true,
                                     axis=0,
                                     batch_dims=y_true_rank)

        loss = focal_modulation * loss

        return loss
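
A hedged numeric sketch of the focal weighting idea: in the canonical focal loss (Lin et al.), well-classified examples are down-weighted by (1 - p_t) ** gamma. The snippet above derives its modulation from reduced class probabilities, but gamma has the same effect.

gamma = 2.0
for p in (0.2, 0.5, 0.9, 0.99):
    # Higher confidence -> smaller weight: 0.64, 0.25, 0.01, 0.0001
    print(p, (1.0 - p) ** gamma)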
Example 13
    def all_gather(
            self,
            input_tensor: core.TensorLike,
            axis: core.TensorLike,
            options: Optional[collective_util.Options] = None) -> core.Tensor:
        """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
        if context.executing_eagerly():
            raise RuntimeError('all_gather is not supported in eager mode.')

        with ops.device(self._device), \
             ops.control_dependencies([array_ops.identity(input_tensor)]):
            # 1. Transpose
            # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
            # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
            # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
            # place it back.
            perm_pre = array_ops.concat(
                ([axis], math_ops.range(axis),
                 math_ops.range(axis + 1, array_ops.rank(input_tensor))),
                axis=0)
            input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
            # 2. Pad
            gathered_shape = self._all_gather(
                array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t),
                                         axis=0), options)
            first_dims = gathered_shape[:, 0]
            full_axis_dim = math_ops.reduce_max(first_dims)
            padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

            # 3. Gather
            gather_padded_out_tensor = self._all_gather(
                padded_input_tensor, options)
            # 4. Unpad
            split_tensors = []
            for i in range(self._group_size):
                start_pos = i * full_axis_dim
                split_tensors.append(
                    gather_padded_out_tensor[start_pos:start_pos +
                                             first_dims[i]])
            out_tensor_t = array_ops.concat(split_tensors, 0)

            # 5. Transpose back
            perm_after = array_ops.concat(
                (math_ops.range(1, axis + 1), [0],
                 math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
                axis=0)
            return array_ops.transpose(out_tensor_t, perm=perm_after)
Example 14
    def _subdiv_calculate_mean_and_var(self, x, axes, keep_dims):

        with K.name_scope('moments'):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(
                x, dtypes.float32) if x.dtype == dtypes.float16 else x
            replica_ctx = ds.get_replica_context()

            if replica_ctx:
                # Sums local to this replica.

                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y),
                                                        axis=axes,
                                                        keepdims=True)
                batch_size = math_ops.cast(
                    array_ops.shape_v2(y)[0], dtypes.float32)
                # TODO(b/163099951): batch the all-reduces once we sort out the ordering
                # issue for NCCL. We don't have a mechanism to launch NCCL in the same
                # order in each replica nowadays, so we limit NCCL to batch all-reduces.
                # Sum the local sums across all replicas.
                y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_sum)
                # Sum the local squared sums across all replicas.
                y_squared_sum = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, local_squared_sum)
                # Sum the per-replica batch sizes across all devices.
                input_batch_size = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, batch_size)

                # Get the number of elements reduced per sample over the
                # non-batch axes (local).
                axes_vals = [(array_ops.shape_v2(y))[i]
                             for i in range(1, len(axes))]
                multiplier_ = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                            dtypes.float32)
                multiplier = multiplier_ * input_batch_size

                # Compute mean and variance from the reduced sums (locally).
                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
                net_sum = y_sum / multiplier_
                squared_mean = y_squared_sum / multiplier_

            else:
                net_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                squared_mean = math_ops.reduce_sum(math_ops.square(y),
                                                   axis=axes,
                                                   keepdims=True)

                if self._support_zero_size_input():
                    # Keras assumes that batch dimension is the first dimension for Batch
                    # Normalization.
                    input_batch_size = array_ops.shape(y)[0]
                else:
                    input_batch_size = None

                # Get the number of elements reduced per sample over the
                # non-batch axes (local).
                axes_vals = [(array_ops.shape_v2(y))[i]
                             for i in range(1, len(axes))]
                multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                           dtypes.float32)

                squared_mean = squared_mean / multiplier
                net_sum = net_sum / multiplier

                if input_batch_size is None:
                    mean, variance = nn.moments(y, axes, keep_dims=True)
                    input_batch_size = 0
                else:
                    batches_ = math_ops.cast(input_batch_size,
                                             self._param_dtype)
                    # With a single replica, compute the true mean directly,
                    # keeping the dims for proper broadcasting.
                    mean = net_sum / batches_
                    variance = squared_mean / batches_ - math_ops.square(mean)

            input_batch_size = math_ops.cast(input_batch_size, dtypes.int32)
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                net_sum = array_ops.squeeze(net_sum, axes)
                variance = array_ops.squeeze(variance, axes)
                squared_mean = array_ops.squeeze(squared_mean, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(net_sum, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16),
                        math_ops.cast(squared_mean,
                                      dtypes.float16), input_batch_size)
            else:
                return (mean, net_sum, variance, squared_mean,
                        input_batch_size)
Example 15
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            axis,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
  """Build a subgraph that does one full all-gather, using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_tensors: tensors within a single worker graph that are to be gathered
      together; must be one per device. Input tensors cannot have rank 0.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      gathering. The gathering will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    communication_hint: string providing hint to runtime for choosing collective
      implementation. Available options are `AUTO`, `NCCL`, and `RING`.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors
    timeout: a float or None. The timeout in seconds.

  Returns:
    An array of final tensors, one per device, computed by the full gather.
  """
  if len(input_tensors) != len(devices):
    raise ValueError(
        'collective requires one input tensor for each device, %d != %d' %
        (len(input_tensors), len(devices)))

  if group_size < 2:
    return input_tensors
  group_key = collective_keys.get_group_key(devices)
  instance_key_tensor = collective_keys.get_op_instance_key()
  instance_key_shape = collective_keys.get_op_instance_key()

  out_tensors = []
  for idx, input_tensor in enumerate(input_tensors):
    with ops.device(devices[idx]), ops.control_dependencies(
        _control_input(devices, control_inputs, idx)):
      # 1. Transpose
      # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
      # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
      # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
      # place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = collective_ops.all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          group_size,
          group_key,
          instance_key_shape,
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = collective_ops.all_gather(
          padded_input_tensor,
          group_size,
          group_key,
          instance_key_tensor,
          communication_hint,
          timeout=timeout)
      # 4. Unpad
      split_tensors = []
      for i in range(first_dims.shape[0]):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after)
      out_tensors.append(out_tensor)
  return out_tensors
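
A hedged sketch with made-up numbers of the unpad step (step 4): each group member's slice is cut from the padded gather using its true first-dim size, and the padding rows are dropped.

import numpy as np

full_axis_dim = 4
first_dims = [3, 2]                      # true first-dim sizes per group member
gathered = np.arange(2 * full_axis_dim)  # stand-in for the padded gather
parts = [gathered[i * full_axis_dim : i * full_axis_dim + d]
         for i, d in enumerate(first_dims)]
print(np.concatenate(parts))             # [0 1 2 4 5]: padding rows 3 and 7 dropped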