Example no. 1
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Translate relative position to a bucket number for relative attention.

  The relative position is defined as memory_position - query_position, i.e.
  the distance in tokens from the attending position to the attended-to
  position.  If bidirectional=False, then positive relative positions are
  invalid.

  We use smaller buckets for small absolute relative_position and larger buckets
  for larger absolute relative_positions.  All relative positions >=max_distance
  map to the same bucket.  All relative positions <=-max_distance map to the
  same bucket.  This should allow for more graceful generalization to longer
  sequences than the model has been trained on.

  Args:
    relative_position: an int32 Tensor
    bidirectional: a boolean - whether the attention is bidirectional
    num_buckets: an integer
    max_distance: an integer
  Returns:
    a Tensor with the same shape as relative_position, containing int32
      values in the range [0, num_buckets)
  """
  ret = 0
  n = -relative_position
  if bidirectional:
    num_buckets //= 2
    ret += mtf.to_int32(mtf.less(n, 0)) * num_buckets
    n = mtf.abs(n)
  else:
    n = mtf.maximum(n, 0)
  # now n is in the range [0, inf)
  max_exact = num_buckets // 2
  is_small = mtf.less(n, max_exact)
  val_if_large = max_exact + mtf.to_int32(
      mtf.log(mtf.to_float(n) / max_exact)
      / math.log(max_distance / max_exact) * (num_buckets - max_exact))
  val_if_large = mtf.minimum(val_if_large, num_buckets - 1)
  ret += mtf.where(is_small, n, val_if_large)
  return ret
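
For reference, a minimal NumPy sketch of the same bucketing rule, independent of Mesh TensorFlow. The np.maximum(n, 1) guard is an addition here that only silences the log(0) warning on entries the final where() discards:

import numpy as np

def relative_position_bucket_np(relative_position, bidirectional=True,
                                num_buckets=32, max_distance=128):
    ret = 0
    n = -np.asarray(relative_position)
    if bidirectional:
        num_buckets //= 2
        ret = ret + (n < 0).astype(np.int32) * num_buckets
        n = np.abs(n)
    else:
        n = np.maximum(n, 0)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # np.maximum(n, 1) just avoids log(0); those entries are discarded below.
    val_if_large = max_exact + (
        np.log(np.maximum(n, 1) / max_exact)
        / np.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    return ret + np.where(is_small, n, val_if_large)

# Small offsets get their own buckets; distant ones share log-spaced buckets.
print(relative_position_bucket_np(np.array([-130, -8, -1, 0, 1, 8, 130])))
# -> [15  8  1  0 17 24 31]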
Example no. 2
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
    if noising_spec["type"] == "mask":
        # Replace a randomly-chosen fraction (noising_spec["prob"]) of input
        # tokens with 0.
        return targets * mtf.cast(
            mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                        noising_spec["prob"]), targets.dtype)
    elif noising_spec["type"] == "random_zipfian":
        # Replace a randomly-chosen fraction (noising_spec["prob"]) of input
        # tokens.  Rather than drawing the replacement tokens uniformly, we
        # sample from a distribution favoring lower token-ids, assuming that
        # the ids have been assigned in frequency order.  The probability of
        # choosing an id is proportional to 1/(id+10).
        logits = mtf.log(1.0 / (mtf.range(
            targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
        logits = mtf.broadcast(logits,
                               new_shape=targets.shape + logits.shape)
        r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
        use_noise = mtf.less(
            mtf.random_uniform(targets.mesh, targets.shape),
            noising_spec["prob"])
        return mtf.where(use_noise, r, targets)
    elif noising_spec["type"] == "transformer":
        # Train a small transformer to fill in masked-out values, then
        # sample from it.
        hparams = self._hparams
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            raise NotImplementedError("Not implemented")
        noiser_hparams = copy.copy(self._hparams)
        noiser_hparams.del_hparam("mode")
        noiser_hparams.override_from_dict(noising_spec["overrides"])
        with tf.variable_scope("noiser"):
            noiser = MtfTransformer(noiser_hparams,
                                    mode=hparams.mode,
                                    problem_hparams=self._problem_hparams)
            logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                self._original_features, targets.mesh)
            samples = mtf.sample_with_temperature(logits,
                                                  self.targets_vocab_dim)
        losses.append(loss)
        return samples
    else:
        raise ValueError("unknown noising spec %s" % noising_spec)
Example no. 3
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")
    if activation_fn == "gelu": # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu": # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu": # https://arxiv.org/abs/1511.07289
        return mtf.elu
    
    # swish activations
    elif activation_fn == "swish": # https://arxiv.org/abs/1710.05941
        return mtf.swish
    
    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig": 
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid": 
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin": 
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh": 
        return lambda x: mtf.maximum(x, mtf.tanh(x))
    
    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish": # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp": # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht": # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull": # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x ** 2)
    elif activation_fn == "snake": # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x) ** 2
    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
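
A usage sketch; the params dict shape is assumed from the lookup above:

params = {"activation_fn": "mish"}  # omit the key to fall back to "gelu"
act = get_activation_fn(params)
# `act` maps an mtf.Tensor to an mtf.Tensor, e.g. h = act(dense_out).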
Example no. 5
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    if targets.dtype.is_integer:
        # hard targets
        if set(targets.shape.dims) != set(
                logits.shape.dims).difference([vocab_dim]):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss
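
The loss above is -sum(targets * log(entmax(logits))) over the vocab dimension (note that z_loss is accepted but unused in this snippet). A dependency-free NumPy sketch of the hard-target case, with softmax standing in for entmax purely to stay self-contained:

import numpy as np

def cross_entropy_hard_np(logits, target_ids):
    # Stable log of the normalized distribution.
    logp = logits - logits.max(axis=-1, keepdims=True)
    logp -= np.log(np.exp(logp).sum(axis=-1, keepdims=True))
    one_hot = np.eye(logits.shape[-1])[target_ids]
    return -(one_hot * logp).sum(axis=-1)

print(cross_entropy_hard_np(np.array([[2.0, 0.5, -1.0]]), np.array([0])))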
Example no. 6
def _arcsinh(x):
    return mtf.log(x + mtf.sqrt(1 + x**2))
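
This is the identity arcsinh(x) = log(x + sqrt(1 + x^2)); a quick NumPy check:

import numpy as np
x = np.linspace(-3.0, 3.0, 7)
print(np.allclose(np.log(x + np.sqrt(1 + x**2)), np.arcsinh(x)))  # True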
Example no. 7
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        return mtf.get_variable(x.mesh,
                                f"activation-{random.randint(0, 2 ** 32):x}",
                                [],
                                initializer=tf.constant_initializer(init),
                                dtype=x.dtype)

    def _pos_var(x, val):
        return mtf.softplus(_var(x, 0)) + val

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "lrelu001":
        return lambda x: mtf.leaky_relu(x, alpha=0.01)
    elif activation_fn == "lrelu020":
        return lambda x: mtf.leaky_relu(x, alpha=0.20)

    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)

    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
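            # Note: random.random() runs once, when the graph is built (TF1),
            # so a single negative slope is baked in rather than resampled
            # every step.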
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp +
                                                                    1)

        return _elish_fn

    elif activation_fn == "silu":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    elif activation_fn == "arcsinh":
        return _arcsinh

    # parametric
    elif activation_fn == "aria":  # https://arxiv.org/abs/1805.08878
        return lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(
            x, 1) * mtf.exp(_var(x, -1) * x)**(1 / _pos_var(x, 1))))
    elif activation_fn == "prelu":  # https://arxiv.org/abs/1502.01852
        return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2))
    elif activation_fn == "parcsinh":
        return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1))
    elif activation_fn == "psoftplus":
        return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0)
    elif activation_fn == "proottanh":
        return lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(
            x, 3)) * mtf.tanh(x)

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1

    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
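
The parametric branches rely on _pos_var to keep learned scalars strictly positive. A NumPy sketch of the reparameterization (softplus of an unconstrained raw value, plus a floor):

import numpy as np

def pos_var_np(raw, val):
    # softplus(raw) + val > val for every raw, so the parameter can be
    # optimized without constraints while staying in (val, inf).
    return np.log1p(np.exp(raw)) + val

print(pos_var_np(0.0, 1.0))    # ~1.693 at initialization (raw = 0)
print(pos_var_np(-10.0, 1.0))  # ~1.0000454, still above the floor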
Example no. 8
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")
    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu

    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)

    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
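            # Note: random.random() runs once, when the graph is built (TF1),
            # so a single negative slope is baked in rather than resampled
            # every step.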
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp +
                                                                    1)

        return _elish_fn

    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1

    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
Example no. 9
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Compute a random top-1 gating."""
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations
  if train and policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity; ensure we do not exceed it.  Compute each
  # token's position within its chosen expert using an exclusive cumsum.
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  expert_mask *= mtf.cast(
      mtf.less(position_in_expert, expert_capacity_float),
      dtype=raw_gates.dtype)
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  # Zero out gates for tokens that overflowed the expert's capacity,
  # sparsifying the expert_gate.
  expert_gate *= expert_mask_flat

  combine_tensor = (
      expert_gate * expert_mask_flat *
      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
      mtf.one_hot(
          mtf.to_int32(position_in_expert),
          expert_capacity_dim,
          dtype=raw_gates.dtype))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
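
The capacity logic above hinges on an exclusive cumulative sum of the one-hot expert mask. A NumPy sketch for a single expert and group (names are illustrative):

import numpy as np

def positions_within_capacity(expert_mask, capacity):
    # expert_mask: 0/1 per token for one expert, in routing order.
    position_in_expert = (np.cumsum(expert_mask) - expert_mask) * expert_mask
    keep = expert_mask * (position_in_expert < capacity)
    return position_in_expert, keep

mask = np.array([1, 0, 1, 1, 1, 0, 1])
pos, keep = positions_within_capacity(mask, capacity=3)
print(pos)   # [0 0 1 2 3 0 4]
print(keep)  # [1 0 1 1 0 0 0] -> tokens beyond capacity 3 are dropped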
Example no. 10
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-1 gating with no-token-left behind behavior."""
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind): route as many tokens as
  # possible to their top-i choice before trying top-(i+1).
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors accumulating values over the iterative routing process.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

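    # Despite its name, this mask marks tokens that still fit within the
    # expert's capacity (cumulative count <= capacity).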
    expert_overflow = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * expert_overflow

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
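
A NumPy sketch of the iterative no-token-left-behind idea for one group: every token tries its top-1 expert first, and only unrouted tokens try their next choice (all names are illustrative):

import numpy as np

def route_no_token_left_behind(top_k_indices, num_experts, capacity):
    # top_k_indices: [num_tokens, k] expert preferences per token.
    num_tokens, k = top_k_indices.shape
    assignment = -np.ones(num_tokens, dtype=int)  # -1 == unrouted
    load = np.zeros(num_experts, dtype=int)
    for i in range(k):                  # try top-1 first, then top-2, ...
        for t in range(num_tokens):
            e = top_k_indices[t, i]
            if assignment[t] < 0 and load[e] < capacity:
                assignment[t] = e
                load[e] += 1
    return assignment

prefs = np.array([[0, 1], [0, 1], [0, 1], [1, 0]])
print(route_no_token_left_behind(prefs, num_experts=2, capacity=2))
# -> [0 0 1 1]: the third token overflows expert 0 and falls back to expert 1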
Example no. 11
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of function computing LPT deplacement.

    Returns output tensorflow and mesh tensorflow tensors
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        raise ValueError("dtype must be tf.float32 or tf.float64")
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    kny = 1 * np.pi * nc / bs
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    #k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    #
    # Begin simulation

    ## Compute initial conditions, distributed
    #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    #field = fieldvar
    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    ##
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)
    size = 3

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    #k_dims = [d.shape[0] for d in kv]
    #k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    #diff = (model - mtfdata)
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)
    #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype)
    #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype)
    #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder(
        tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    #diff = diff / 0.25
    #diff = (diff + mtfoff)*mtfistd #For some reason, doing things wrong this one
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
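
The _cwise_smooth weight is a Gaussian low-pass filter in Fourier space; since kk carries an (nc/bs)^2 factor and the scale carries (bs/nc)^2, the weights reduce to exp(-(k * R0)^2) with both in grid units. A 1-D NumPy analogue of the same filter:

import numpy as np

def gaussian_smooth_1d(field, R0, bs, nc):
    k = 2 * np.pi * np.fft.fftfreq(nc)       # wavenumbers in grid units
    kk = (k * nc / bs) ** 2
    wts = np.exp(-kk * (R0 * bs / nc) ** 2)  # the bs/nc factors cancel
    return np.fft.ifft(np.fft.fft(field) * wts).real

nc, bs = 64, 100.0
field = np.random.default_rng(0).normal(size=nc)
smooth = gaussian_smooth_1d(field, R0=4.0, bs=bs, nc=nc)
print(smooth.std() < field.std())  # True: small scales are suppressed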
Example no. 12
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """
    Prototype of function computing LPT deplacement.

    Returns output tensorflow and mesh tensorflow tensors
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        raise ValueError("dtype must be tf.float32 or tf.float64")
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    kny = 1 * np.pi * nc / bs
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    #k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    #
    # Begin simulation

    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))

    ##
    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)

    ##
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    size = 3

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    ##RSD below
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    masssm = masssm + 1e-5
    imasssm = mtf.pow(masssm, -1.)  # inverse of the smoothed mass field

    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(final_field, final_state, output_shape=part_shape, hr_shape=hr_shape, \
                              halo_size=halo_size, splittables=splittables, mesh=mesh, weights=vzweights)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    #Shift
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X],
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    Xrsd = mtf.slicewise(lambda x, vz: x + tf.stack(
        [tf.zeros_like(vz), tf.zeros_like(vz), vz], 4), [X, vzweights],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)
    modelread = mcomp.cic_readout_fr(model, [X],
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd], output_shape=part_shape, hr_shape=hr_shape, \
                                  halo_size=halo_size, splittables=splittables, mesh=mesh, weights=modelread)

    model = modelrsd
    print(modelrsd)

    #Likelihood and prior here
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    #diff = (model - mtfdata)
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        # Note: assumes `off` is also provided, since mtfoff is only defined
        # in that case.
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        # For some reason, this weighting misbehaves.
        diff = (diff + mtfoff) * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
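
The data term compares fields through log(x + M0) rather than x itself, which keeps residuals from being dominated by the densest cells; a tiny NumPy illustration:

import numpy as np

model = np.array([0.1, 1.0, 10.0, 100.0])
data = np.array([0.2, 1.1, 12.0, 130.0])
M0 = 1e-2
print(model - data)                            # spans ~2.5 orders of magnitude
print(np.log(model + M0) - np.log(data + M0))  # all within one order of magnitude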
Example no. 13
    'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
    'cosid': lambda x: mtf.cos(x) - x,
    'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
    'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
    'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
    'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
    'lisht': lambda x: x * mtf.tanh(x),
    'seagull': lambda x: mtf.log(1 + x**2),
    'snake': lambda x: x + mtf.sin(x)**2,
    'roottanh': lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
    'softplusmone': lambda x: mtf.softplus(x) - 1
}


def get_activation_fn(params):
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print(
            "Defauling to GELU activation (see here: https://arxiv.org/abs/1606.08415)"
Example no. 14
    'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
    'cosid': lambda x: mtf.cos(x) - x,
    'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
    'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
    'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
    'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
    'lisht': lambda x: x * mtf.tanh(x),
    'seagull': lambda x: mtf.log(1 + x**2),
    'snake': lambda x: x + mtf.sin(x)**2,
    'roottanh': lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
    'softplusmone': lambda x: mtf.softplus(x) - 1
}


def get_activation_fn(params):
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print(
            "Defauling to GELU activation (see here: https://arxiv.org/abs/1606.08415)"