Example #1
import logging

import jax.numpy as jnp
import tensorflow as tf


def Initializer(shape, rng):
    """Loads pretrained embeddings; `path` is captured from the enclosing scope."""
    del rng  # Deterministic initializer: the random key is unused.
    logging.info('Loading pretrained embeddings from %s', path)
    with tf.io.gfile.GFile(path, 'rb') as f:
        parameters = jnp.load(f)
    assert jnp.shape(parameters) == shape, ('Expected shape %s, got %s' %
                                            (shape, jnp.shape(parameters)))
    return parameters
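Note that `path` is a free variable here: in the Trax source this `Initializer` is defined inside a factory that captures `path` as a closure. A minimal sketch of that pattern (the factory name below is hypothetical, not the Trax API):

    import jax.numpy as jnp
    import tensorflow as tf

    def PretrainedEmbeddingInitializer(path):  # hypothetical factory name
        """Returns an initializer that closes over `path`."""
        def Initializer(shape, rng):
            del rng  # deterministic: the random key is unused
            with tf.io.gfile.GFile(path, 'rb') as f:
                parameters = jnp.load(f)
            assert jnp.shape(parameters) == shape
            return parameters
        return Initializer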
Example #2
File: training.py Project: RJK99/trax
  def log_summary(self, values, summary_writer, value_prefix, log_prefix,
                  stdout=True):
    """Logs and saves provided metrics.

    Args:
      values: Dict from metric name to metric value.
      summary_writer: Jaxboard summary writer.
      value_prefix: String prepended to summary_writer entry names.
      log_prefix: String prepended to log lines.
      stdout: Whether to also print logs to stdout.
    """
    history = self._history
    should_write_summaries = self.is_chief and summary_writer is not None
    for name, value in values.items():
      full_name = value_prefix + name
      s = tuple(jnp.shape(value))
      if not s:  # Scalar metrics are printed and written as summary scalars.
        self._log_step(
            '%s %s | % .8f' %
            (log_prefix.ljust(5), name.rjust(self._rjust_len), value),
            stdout=stdout)
        if should_write_summaries:
          summary_writer.scalar(full_name, value, self.step)
      else:  # Array-valued metrics are written as image summaries.
        if should_write_summaries:
          summary_writer.image(full_name, value, self.step)
      if history:
        history.append(log_prefix, full_name, self.step, value)
    if should_write_summaries:
      summary_writer.flush()
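The scalar-vs-image dispatch above relies on `jnp.shape` returning an empty tuple for 0-d values, exactly like `np.shape`. A quick self-contained check of that behavior:

    import jax.numpy as jnp

    assert tuple(jnp.shape(0.5)) == ()                    # scalar -> summary_writer.scalar
    assert tuple(jnp.shape(jnp.ones((2, 3)))) == (2, 3)   # array  -> summary_writer.image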
Example #3
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = self.weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(self.weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call -- that is how the state is
            # used and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        fastmath.dynamic_slice_in_dim(self.weights[0],
                                                      state[i],
                                                      inputs.shape[1],
                                                      axis=0))
                self.state = state + inputs.shape[1]
                res = inputs + jnp.stack(emb, 0)
                return res
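The dropout branch above samples one Bernoulli mask that is broadcast across the axes listed in `_dropout_broadcast_dims`, then rescales by `1 / keep_prob` so activations keep their expectation (inverted dropout). A standalone sketch of the same trick in plain JAX, with illustrative shapes:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4, 16, 8))                  # (batch, length, d_feature)
    keep_prob = 0.9
    noise_shape = (1, 16, 8)                  # one mask shared across the batch axis
    keep = jax.random.bernoulli(jax.random.PRNGKey(0), keep_prob, noise_shape)
    y = x * keep.astype(x.dtype) / keep_prob  # inverted dropout: E[y] == x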
Example #4
File: attention.py Project: google/trax
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        weights = self.weights
        if self._d_feature is not None:
            weights, ff = weights
            weights = jnp.dot(weights[:inputs.shape[1], :], ff)
        if len(weights.shape) < 3:  # old checkpoints already have 1 in the first dim
            weights = weights[None, :, :]  # [1, self._max_len, d_feature]
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call.
            emb = fastmath.dynamic_slice_in_dim(weights,
                                                self.state,
                                                inputs.shape[1],
                                                axis=1)
            self.state += inputs.shape[1]
            return inputs + emb
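`dynamic_slice_in_dim` is what lets `start` (and `self.state`) be traced values rather than Python ints; an ordinary `weights[:, start:start + size, :]` slice would fail under tracing. A small demonstration of its semantics:

    import jax
    import jax.numpy as jnp

    w = jnp.arange(10).reshape(1, 10, 1)            # [1, max_len, d_feature]
    px = jax.lax.dynamic_slice_in_dim(w, 3, 4, axis=1)
    print(px[0, :, 0])                              # [3 4 5 6]: length-4 window at offset 3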
Example #5
    def forward(self, inputs):
        gamma, beta, epsilon_l = self.weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += jnp.abs(epsilon_l[0])

        # Reduce over all axes except batch (B) and channels (C).
        axis = tuple(range(1, len(jnp.shape(inputs)) - 1))
        # Mean of squares per channel, shape (B, 1, 1, C) for NHWC inputs.
        nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)
        # Normalized activations, shape (B, W, H, C).
        xhat = inputs / jnp.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
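This looks like the Filter Response Normalization recipe: divide by the root mean square over the spatial axes, per channel, then apply a learned affine (with an optionally learned epsilon). A minimal numeric sanity check, with illustrative shapes:

    import jax.numpy as jnp

    x = jnp.full((2, 4, 4, 3), 2.0)                    # (B, W, H, C)
    axis = tuple(range(1, x.ndim - 1))                 # (1, 2): all but batch and channels
    nu2 = jnp.mean(x**2, axis=axis, keepdims=True)     # (B, 1, 1, C); equals 4.0 here
    xhat = x / jnp.sqrt(nu2 + 1e-6)                    # ~1.0 everywhere: unit RMS per channel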
Example #6
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            px = self.weights[:, :symbol_size, :]
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                if fastmath.is_backend(fastmath.Backend.JAX):
                    keep_prob = jax.lax.tie_in(
                        x, jnp.full((), keep_prob, dtype=x.dtype))
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call -- that is how the state is
            # used and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        jax.lax.dynamic_slice_in_dim(self.weights[0],
                                                     state[i],
                                                     inputs.shape[1],
                                                     axis=0))
                self.state = state + inputs.shape[1]
                return inputs + jnp.stack(emb, 0)
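The `tie_in` guard here is a historical detail: `jax.lax.tie_in` was an XLA hint that tied the `keep_prob` constant to `x`, and it became a no-op (and was later removed) in newer JAX releases. Under that assumption, current code can simply materialize the constant in the input's dtype:

    import jax.numpy as jnp

    x = jnp.ones((2, 3), dtype=jnp.bfloat16)
    keep_prob = jnp.full((), 1.0 - 0.1, dtype=x.dtype)  # no tie_in needed in current JAX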
Example #7
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            length = jnp.shape(x)[1]
            if self._mode != 'train':
                start = 0
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0, self._add_offset)
                # With probability 1 / self._start_from_zero_one_in, zero the
                # random offset so training still sees positions from 0.
                start_from_nonzero = fastmath.random.randint(
                    rng2, (), 0, self._start_from_zero_one_in)
                start_from_nonzero = jnp.minimum(1, start_from_nonzero)
                start *= start_from_nonzero
            px = self._sincos(start, length, inputs.shape[2])
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call -- that is how the state is
            # used and updated below.
            pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2])
            self.state += inputs.shape[1]
            return inputs + pe
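`_sincos` itself is not shown in this snippet. A plausible implementation following the standard Transformer sinusoidal encoding (this body is an assumption, not the Trax source; `d_feature` is taken to be even):

    import jax.numpy as jnp

    def sincos(start, length, d_feature):  # hypothetical stand-in for self._sincos
        # Positions start .. start+length-1 at geometrically spaced frequencies.
        pos = jnp.arange(start, start + length)[:, None]               # (length, 1)
        freq = jnp.exp(jnp.arange(0, d_feature, 2)
                       * -(jnp.log(10000.0) / d_feature))              # (d_feature // 2,)
        pe = jnp.concatenate([jnp.sin(pos * freq), jnp.cos(pos * freq)], axis=-1)
        return pe[None, :, :]                                          # (1, length, d_feature)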