Example #1
    def forward_with_state(self, xs, weights, state, rng):
        self._validate_forward_inputs(xs)
        (step, layers_state) = state
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _split_rngs(rng, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, (step + 1, layers_state))

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(layers_state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(layers_state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            # warmup decays from 1.0 at step 0 to 0.0 at skipping_warmup_steps
            # and stays at 0.0 afterwards.
            w_steps = float(self._skipping_warmup_steps)
            warmup = np.maximum(0.0,
                                (w_steps - step.astype(np.float32)) / w_steps)
            # low is the minimum number of layers to *not* skip, from n_layers to 0
            low = warmup * float(n_layers)
            # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
            # because (high - n_layers) / high is the probability we're not skipping
            # (after warmup); so high - n_layers = high - high * skip_fraction
            high = float(n_layers) / self._skip_fraction
            # We want the same rng0 on all cores.
            if math.device_count() > 1:
                rng0 = math.psum(rng0, 'batch')
            n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
        else:
            n_forward_layers = float(n_layers)
        # Run layers skipping after a certain number.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, layers_state,
                                    rngs):
            inputs = _inputs_from_stack(layer, stack)
            outputs, s = math.cond(  # Skip (do identity) if > n_forward_layers.
                pred=(math.lt(cur_layer_idx, n_forward_layers)),
                true_operand=(inputs, p, s, rng),  # This tuple is t below.
                true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
                false_operand=(inputs, p, s, rng),
                false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
            )
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, (step + 1, new_state)
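To make the warmup arithmetic above concrete, here is a minimal NumPy sketch of the schedule; n_layers, skip_fraction and skipping_warmup_steps are made-up values for illustration.

import numpy as np

n_layers, skip_fraction, w_steps = 6, 0.1, 100.0
for step in [0, 50, 100, 200]:
    warmup = np.maximum(0.0, (w_steps - step) / w_steps)
    low = warmup * n_layers            # at step 0 every layer must run
    high = n_layers / skip_fraction    # P(no skip) = (high - n_layers) / high = 1 - skip_fraction
    print(step, low, high)             # n_forward_layers ~ Uniform(low, high)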
Example #2
    def forward_with_state(self,
                           xs,
                           weights=tl.EMPTY_WEIGHTS,
                           state=tl.EMPTY_STATE,
                           **kwargs):
        self._validate_forward_inputs(xs)
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _pop_rng_and_split(kwargs, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, state)

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            n_forward_layers = random.uniform(rng0, (), np.float32, 0.0,
                                              n_layers)
        else:
            n_forward_layers = float(n_layers)
        # Run layers skipping after a certain number.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, state, rngs):
            inputs = _inputs_from_stack(layer, stack)
            outputs, s = math.cond(  # Skip (do identity) if > n_forward_layers.
                pred=(math.lt(cur_layer_idx, n_forward_layers)),
                true_operand=(inputs, p, s, rng),  # This tuple is t below.
                true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
                false_operand=(inputs, p, s, rng),
                false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
            )
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, new_state
Example #3
    def forward_with_state(self,
                           xs,
                           weights=tl.EMPTY_WEIGHTS,
                           state=tl.EMPTY_STATE,
                           **kwargs):
        self._validate_forward_inputs(xs)
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _pop_rng_and_split(kwargs, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, state)

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            n_forward_layers = random.uniform(rng0, (), np.float32, 0.0,
                                              n_layers)
        else:
            n_forward_layers = float(n_layers)
        # Run layers skipping after a certain number.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, state, rngs):
            inputs = _inputs_from_stack(layer, stack)
            # TODO(chowdhery): port to jax.lax.cond once it has a JVP rule.
            outputs, s = layer._forward_internal(inputs, p, s, rng)  # pylint: disable=protected-access
            condition = math.lt(cur_layer_idx,
                                n_forward_layers).astype(np.float32)
            outputs = condition * outputs + (1 - condition) * inputs
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, new_state
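The float-mask blend above is a differentiable stand-in for a hard select; a minimal NumPy sketch (with made-up values) of the same trick:

import numpy as np

inputs = np.array([1.0, 2.0, 3.0])
outputs = np.array([10.0, 20.0, 30.0])  # stands in for the sublayer's output
condition = np.float32(0.0)             # cur_layer_idx >= n_forward_layers, so skip
print(condition * outputs + (1 - condition) * inputs)  # [1. 2. 3.] -- identity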
Example #4
    def Init(shape, rng):
        """Returns random values for initializing weights of the given `shape`."""
        fan_in, fan_out = _GetFans(shape, out_dim, in_dim)
        gain = scale
        if mode == 'fan_in':
            gain /= fan_in
        elif mode == 'fan_out':
            gain /= fan_out
        elif mode == 'fan_avg':
            gain /= (fan_in + fan_out) / 2
        if distribution == 'truncated_normal':
            # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
            stddev = jnp.sqrt(gain) / .87962566103423978
            new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
            return new_weights.astype('float32')
        elif distribution == 'normal':
            new_weights = random.normal(rng, shape) * jnp.sqrt(gain)
            return new_weights.astype('float32')
        elif distribution == 'uniform':
            lim = jnp.sqrt(3. * gain)
            return random.uniform(rng, shape, jnp.float32, -lim, lim)
        else:
            raise ValueError('invalid distribution for ScaleInitializer')
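A minimal sketch of the variance-scaling arithmetic above for a plain 2-D weight matrix, assuming fan_in and fan_out are simply the two matrix dimensions (the real _GetFans also honors out_dim/in_dim and convolutional shapes); the sizes below are made up.

import numpy as np

fan_in, fan_out, scale = 512, 256, 1.0
gain = scale / fan_in                        # mode == 'fan_in'
# truncnorm on [-2, 2] has std ~= 0.8796, so divide to restore unit variance.
stddev = np.sqrt(gain) / .87962566103423978
print(gain, stddev)                          # ~0.00195, ~0.0502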
Example #5
def RandomUniformInitializer(lim=1.0):
    """Returns an initializer for random uniform coefficients."""
    return lambda shape, rng: random.uniform(
        rng, shape, jnp.float32, -lim, lim)
Example #6
def AtariConvInit(kernel_shape, rng, dtype=jnp.float32):
    """The standard init for Conv laters and Atari."""
    filter_height, filter_width, fan_in, _ = kernel_shape
    std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
    return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
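A minimal usage sketch, assuming `random` and `jnp` in the module are jax.random and jax.numpy; the 8x8x4x32 kernel shape is just an illustrative Atari-style example.

import jax

kernel = AtariConvInit((8, 8, 4, 32), jax.random.PRNGKey(0))
print(kernel.shape)  # (8, 8, 4, 32); entries lie in [-std, std] with std = 1/16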
Example #7
def RandomUniformInitializer(lim=1.0):
    """Returns an initializer for random uniform coefficients."""
    # Make sure shape contains no int tensors: _PureShape converts each entry to a plain int.
    return lambda shape, rng: random.uniform(  # pylint: disable=g-long-lambda
        rng, _PureShape(shape), jnp.float32, -lim, lim)
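A minimal usage sketch of either RandomUniformInitializer variant above, again assuming the module-level `random` is jax.random; the shape and limit are made-up values.

import jax

init = RandomUniformInitializer(lim=0.05)
w = init((128, 64), jax.random.PRNGKey(0))  # float32 values in [-0.05, 0.05)
print(w.shape, w.dtype)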