def dense(x, num_outputs, use_bias=True, name='dense'):
    """Custom fully connected layer."""
    num_inputs = x.shape[-1].value
    with tf.variable_scope(name):
        w = tf.get_variable(name='kernel',
                            shape=[num_inputs, num_outputs],
                            initializer=_dense_kernel_initializer,
                            trainable=True,
                            use_resource=True)
        w = common_utils.shard_weight(w, NUM_XLA_SHARDS)
        if USE_BFLOAT16:
            w = tf.cast(w, tf.bfloat16)

        x = tf.linalg.matmul(x, w)
        if use_bias:
            b = tf.get_variable(name='bias',
                                shape=[num_outputs],
                                initializer=tf.initializers.zeros(),
                                trainable=True,
                                use_resource=True)
            b = common_utils.shard_weight(b, NUM_XLA_SHARDS)
            if USE_BFLOAT16:
                b = tf.cast(b, tf.bfloat16)
            x = tf.nn.bias_add(x, b, name='bias_add')
        return x
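
# Usage sketch (not part of the original module): assumes `features` is a
# rank-2 float (or bfloat16) tensor and that _dense_kernel_initializer,
# NUM_XLA_SHARDS and USE_BFLOAT16 are defined at module level, as elsewhere
# in this file. The layer name is illustrative only.
#
#   logits = dense(features, num_outputs=1000, name='classifier_head')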


def conv2d(x, filter_size, num_out_filters, stride=1,
           use_bias=False, padding='SAME', data_format='NHWC', name='conv2d',
           w=None, b=None):
  """Conv."""
  with tf.variable_scope(name):
    # Note: the input-channel lookup, strides, and bias_add below assume
    # data_format='NHWC'.
    num_inp_filters = x.shape[-1].value

    # Only create the kernel when one was not passed in, mirroring the
    # handling of `b` below.
    if w is None:
      w = tf.get_variable(
          name='kernel',
          shape=[filter_size, filter_size, num_inp_filters, num_out_filters],
          initializer=_conv_kernel_initializer,
          trainable=True,
          use_resource=True)
      w = common_utils.shard_weight(w, NUM_XLA_SHARDS)

    if USE_BFLOAT16:
      w = tf.cast(w, tf.bfloat16)
    x = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1],
                     padding=padding, data_format=data_format)

    if use_bias:
      if b is None:
        b = tf.get_variable(
            name='bias',
            shape=[num_out_filters],
            initializer=tf.initializers.zeros(),
            trainable=True,
            use_resource=True)
        b = common_utils.shard_weight(b, NUM_XLA_SHARDS)
        if USE_BFLOAT16:
          b = tf.cast(b, tf.bfloat16)
      x = tf.nn.bias_add(x, b, name='bias_add')
    return x
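
# Usage sketch (assumption, not in the original file): a 3x3 stride-2
# convolution over an NHWC image batch; `images` and `stem_filters` are
# placeholders for the caller's tensor and filter count.
#
#   net = conv2d(images, filter_size=3, num_out_filters=stem_filters,
#                stride=2, name='stem_conv')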


def dw_conv2d(x,
              filter_size,
              stride,
              depth_multiplier=1,
              padding='SAME',
              data_format='NHWC',
              name='dw_conv_2d',
              w=None):
    """Custom depthwise conv."""
    if depth_multiplier > 1:
        raise NotImplementedError('depth_multiplier > 1 is not supported.')

    with tf.variable_scope(name):
        num_inp_filters = x.shape[-1].value
        # Only create the depthwise kernel when one was not passed in.
        if w is None:
            w = tf.get_variable(
                name='depthwise_kernel',
                shape=[filter_size, filter_size, num_inp_filters, 1],
                initializer=_conv_kernel_initializer,
                trainable=True,
                use_resource=True)
            w = common_utils.shard_weight(w, NUM_XLA_SHARDS)

        if USE_BFLOAT16:
            w = tf.cast(w, tf.bfloat16)
        x = tf.nn.depthwise_conv2d(x,
                                   filter=w,
                                   strides=[1, stride, stride, 1],
                                   padding=padding,
                                   data_format=data_format)
        return x
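
# Usage sketch (assumption): a depthwise-separable block built from the two
# layers above; `net` and `num_filters` are the caller's tensor and output
# width.
#
#   net = dw_conv2d(net, filter_size=3, stride=1, name='dw')
#   net = conv2d(net, filter_size=1, num_out_filters=num_filters, name='pw')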


def sync_batch_norm(x, params, training, name='batch_norm'):
    """Cross-replica (synchronized) batch normalization."""
    size = x.shape[-1].value

    with tf.variable_scope(name):
        gamma = tf.get_variable(name='gamma',
                                shape=[size],
                                initializer=tf.initializers.ones(),
                                trainable=True)
        beta = tf.get_variable(name='beta',
                               shape=[size],
                               initializer=tf.initializers.zeros(),
                               trainable=True)
        moving_mean = tf.get_variable(name='moving_mean',
                                      shape=[size],
                                      initializer=tf.initializers.zeros(),
                                      trainable=False)
        moving_variance = tf.get_variable(name='moving_variance',
                                          shape=[size],
                                          initializer=tf.initializers.ones(),
                                          trainable=False)

        gamma = common_utils.shard_weight(gamma, NUM_XLA_SHARDS)
        beta = common_utils.shard_weight(beta, NUM_XLA_SHARDS)
        if not training:
            moving_mean = common_utils.shard_weight(moving_mean,
                                                    NUM_XLA_SHARDS)
            moving_variance = common_utils.shard_weight(
                moving_variance, NUM_XLA_SHARDS)

    x = tf.cast(x, tf.float32)
    if training:
        if params.use_tpu:
            num_replicas = params.num_replicas
            if num_replicas <= 8:
                group_assign = None
                group_shards = tf.cast(num_replicas, tf.float32)
            else:
                if params.batch_norm_batch_size is None:
                    group_shards = max(8, num_replicas // 8)
                else:
                    group_shards = params.batch_norm_batch_size // x.shape[
                        0].value
                    group_shards = min(group_shards, params.num_replicas)
                    group_shards = max(group_shards, 1)

                # Round group_shards down to a power of two (minimum 2).
                log_group_shards = max(
                    1, int(np.log(group_shards) / np.log(2.)))
                group_shards = int(np.power(2., log_group_shards))

                group_assign = np.arange(num_replicas, dtype=np.int32)
                group_assign = group_assign.reshape([-1, group_shards])
                group_assign = group_assign.tolist()
                group_shards = tf.cast(group_shards, tf.float32)

            mean = tf.reduce_mean(x, [0, 1, 2])
            mean = tf.tpu.cross_replica_sum(mean / group_shards, group_assign)

            # Var[x] = E[x^2] - E[x]^2
            mean_sq = tf.reduce_mean(tf.math.square(x), [0, 1, 2])
            mean_sq = tf.tpu.cross_replica_sum(mean_sq / group_shards,
                                               group_assign)
            variance = mean_sq - tf.math.square(mean)
        else:
            mean, variance = tf.nn.moments(x, [0, 1, 2])

        x = tf.nn.batch_normalization(
            x,
            mean=mean,
            variance=variance,
            offset=beta,
            scale=gamma,
            variance_epsilon=params.batch_norm_epsilon)

        if USE_BFLOAT16:
            x = tf.cast(x, tf.bfloat16, name='batch_norm_recast')

        if (isinstance(moving_mean, tf.Variable)
                and isinstance(moving_variance, tf.Variable)):
            decay = tf.cast(1. - params.batch_norm_decay, tf.float32)

            def u(moving, normal, name):
                if params.use_tpu:
                    num_replicas_fp = tf.cast(params.num_replicas, tf.float32)
                    normal = tf.tpu.cross_replica_sum(normal) / num_replicas_fp
                diff = decay * (moving - normal)
                return tf.assign_sub(moving, diff, use_locking=True, name=name)

            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                 u(moving_mean, mean, name='moving_mean'))
            tf.add_to_collection(
                tf.GraphKeys.UPDATE_OPS,
                u(moving_variance, variance, name='moving_variance'))
            return x
        else:
            return x, mean, variance
    else:
        if params.use_tpu:
            x = tf.nn.batch_normalization(
                x,
                mean=moving_mean,
                variance=moving_variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=params.batch_norm_epsilon)
        else:
            x, _, _ = tf.nn.fused_batch_norm(x,
                                             scale=gamma,
                                             offset=beta,
                                             mean=moving_mean,
                                             variance=moving_variance,
                                             epsilon=params.batch_norm_epsilon,
                                             is_training=False)

        if USE_BFLOAT16:
            x = tf.cast(x, tf.bfloat16)
        return x
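
# Usage sketch (assumption): `params` carries use_tpu, num_replicas,
# batch_norm_batch_size, batch_norm_epsilon and batch_norm_decay, as read by
# sync_batch_norm above. During training the moving-average updates are added
# to tf.GraphKeys.UPDATE_OPS, so the train op must depend on them.
#
#   net = conv2d(net, filter_size=3, num_out_filters=64, name='conv')
#   net = sync_batch_norm(net, params, training=True, name='bn')
#   net = tf.nn.relu(net)
#   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#     train_op = optimizer.minimize(loss)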