Example #1
  def __call__(
      self,
      inputs,
      state: LSTMState,
  ) -> Tuple[jnp.ndarray, LSTMState]:
    input_to_hidden = hk.ConvND(
        num_spatial_dims=self.num_spatial_dims,
        output_channels=4 * self.output_channels,
        kernel_shape=self.kernel_shape,
        name="input_to_hidden")

    hidden_to_hidden = hk.ConvND(
        num_spatial_dims=self.num_spatial_dims,
        output_channels=4 * self.output_channels,
        kernel_shape=self.kernel_shape,
        name="hidden_to_hidden")

    gates = input_to_hidden(inputs) + hidden_to_hidden(state.hidden)
    i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

    f = jax.nn.sigmoid(f + 1)
    c = f * state.cell + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return h, LSTMState(h, c)
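
The ConvLSTM above fuses all four gates into a single convolution and then splits the channel axis into four equal chunks. A minimal standalone sketch of that split, with assumed toy shapes (not taken from the source):

import jax.numpy as jnp

gates = jnp.zeros((2, 8, 8, 4 * 16))   # (batch, height, width, 4 * output_channels)
i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)
assert i.shape == (2, 8, 8, 16)         # each gate keeps output_channels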
Example #2
    def __call__(self, x, alpha):
        if alpha is None:
            raise ValueError('alpha must be specified.')
        if self.num_freqs == 0:
            return x

        num_channels = x.shape[-1]

        base_encoder = SinusoidalEncoder(num_freqs=self.num_freqs,
                                         max_freq_log2=self.max_freq_log2,
                                         scale=self.scale)
        features = base_encoder(x)
        identity, features = jnp.split(features, (x.shape[-1], ), axis=-1)

        # Apply the window by broadcasting to save on memory.
        features = jnp.reshape(features, (-1, 2, num_channels))
        window = self.cosine_easing_window(self.num_freqs, alpha)
        window = jnp.reshape(window, (-1, 1, 1))
        features = window * features

        return jnp.concatenate([
            identity,
            features.flatten(),
        ], axis=-1)
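
The split above uses a single explicit index (a one-element tuple) rather than a section count, so the first `x.shape[-1]` entries come back as the identity part. A hedged sketch with toy sizes:

import jax.numpy as jnp

num_channels, num_freqs = 3, 4
encoded = jnp.zeros((num_channels + 2 * num_freqs * num_channels,))   # identity + sin/cos features
identity, features = jnp.split(encoded, (num_channels,), axis=-1)
assert identity.shape == (num_channels,) and features.shape == (2 * num_freqs * num_channels,)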
Example #3
File: recurrent.py  Project: varung/flax
    def __call__(self, carry, inputs):
        """Constructs a convolutional LSTM.

    Args:
      carry: the hidden state of the Conv2DLSTM cell,
        initialized using `Conv2DLSTM.initialize_carry`.
      inputs: input data with dimensions (batch, spatial_dims..., features).
    Returns:
      A tuple with the new carry and the output.
    """
        c, h = carry
        input_to_hidden = partial(Conv,
                                  features=4 * self.features,
                                  kernel_size=self.kernel_size,
                                  strides=self.strides,
                                  padding=self.padding,
                                  use_bias=self.use_bias,
                                  dtype=self.dtype,
                                  name='ih')

        hidden_to_hidden = partial(Conv,
                                   features=4 * self.features,
                                   kernel_size=self.kernel_size,
                                   strides=self.strides,
                                   padding=self.padding,
                                   use_bias=self.use_bias,
                                   dtype=self.dtype,
                                   name='hh')

        gates = input_to_hidden()(inputs) + hidden_to_hidden()(h)
        i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        f = sigmoid(f + 1)
        new_c = f * c + sigmoid(i) * jnp.tanh(g)
        new_h = sigmoid(o) * jnp.tanh(new_c)
        return (new_c, new_h), new_h
Example #4
    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
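
The rotary lookup above relies on the sin and cos halves being stored side by side along the last axis, so a two-way jnp.split recovers them separately. A minimal sketch with illustrative shapes (the table layout is an assumption, not taken from the source):

import jax.numpy as jnp

rotary_dim = 8
embed_positions = jnp.zeros((128, rotary_dim))              # assumed layout: [sin | cos] on the last axis
position_ids = jnp.array([[0, 1, 2]])
sincos = jnp.take(embed_positions, position_ids, axis=0)    # (1, 3, rotary_dim)
sin, cos = jnp.split(sincos, 2, axis=-1)                    # each (1, 3, rotary_dim // 2)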
Example #5
 def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
   onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
   lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
   args_maker = lambda: [rng(shape, dtype)]
   self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
 def split(batched_values):
   return [
       jnp.squeeze(v) for v in jnp.split(
           batched_values, indices_or_sections=b1, axis=0)
   ]
 def f(x):
     return np.split(x, 2)
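
The snippets above exercise both forms of `indices_or_sections`: an integer divides the axis into equal parts, while a sequence gives explicit split points. A minimal sketch of the difference:

import jax.numpy as jnp

x = jnp.arange(10)
equal_parts = jnp.split(x, 2)        # two arrays of length 5
at_indices = jnp.split(x, (3, 7))    # lengths 3, 4, and 3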
Example #8
File: functions.py  Project: sts-sadr/jax
def glu(x, axis=-1):
  """Gated linear unit activation function."""
  size = x.shape[axis]
  assert size % 2 == 0, "axis size must be divisible by 2"
  x1, x2 = jnp.split(x, 2, axis)
  return x1 * sigmoid(x2)
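
A hedged usage sketch for the `glu` above (assuming the function and its `sigmoid` dependency are in scope as defined): the chosen axis is halved, with the second half acting as a sigmoid gate on the first.

import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)   # the gated axis must have even size
y = glu(x, axis=-1)                 # first half * sigmoid(second half), shape (2, 2)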
Example #9
    def __call__(self, input_qkv):
        cfg = self.config
        log_len = log_2_ceil(cfg.max_len - 1)
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        joint_logits = []
        list_vals = []
        for l in range(log_len):
            ctx_len = 2**l
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            if l == 0:
                span_key = jnp.reshape(key, [-1, 1, cfg.num_heads, head_dim])
                span_val = value.reshape(span_key.shape)
                self_logits = jnp.expand_dims(jnp.sum(query * key, axis=-1),
                                              -1)
                joint_logits.append(self_logits)
            else:
                left_query = query[:, :last_pos, :, :].reshape(
                    [-1, ctx_len, cfg.num_heads, head_dim])
                span_query = jnp.max(left_query, axis=1, keepdims=True)
                left_key = key[:, :last_pos, :, :].reshape(left_query.shape)
                left_val = value[:, :last_pos, :, :].reshape(left_query.shape)
                span_val = dot_product_attention(
                    span_query * jnp.sqrt(head_dim),
                    left_key,
                    left_val,
                    dropout_rng=self.get_dropout_png(cfg),
                    dropout_rate=cfg.attention_dropout_rate,
                    broadcast_dropout=False,
                    deterministic=cfg.deterministic,
                    dtype=cfg.dtype)
                span_key = jnp.max(left_key, axis=1, keepdims=True)
            rolled_q = jnp.roll(query, -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    [-1, ctx_len, cfg.num_heads, head_dim])

            rolled_mask = jnp.concatenate(
                [(jnp.arange(cfg.max_len - ctx_len) // ctx_len) % 2,
                 jnp.ones(last_pos + ctx_len - cfg.max_len, dtype=jnp.int32)],
                axis=0)
            rolled_mask = jnp.reshape(rolled_mask, [1, -1, 1, 1])
            rolled_logits = jnp.einsum('...qhd,...khd->...qhk', rolled_q,
                                       span_key)
            # bsize, last_pos, h, 1
            rolled_logits = jnp.reshape(
                rolled_logits, [bsize, -1, cfg.num_heads, 1
                                ]) + rolled_mask.astype(rolled_q.dtype) * -1e9
            orig_logits = jnp.pad(rolled_logits, [(0, 0),
                                                  (0, cfg.max_len - last_pos),
                                                  (0, 0), (0, 0)],
                                  constant_values=-1e9)
            orig_logits = jnp.roll(orig_logits, ctx_len, axis=1)
            joint_logits.append(orig_logits)
            list_vals.append(span_val)
        joint_logits = jnp.concatenate(joint_logits, axis=-1)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_weights = jnp.split(attn_weights, log_len + 1, axis=-1)
        local_weighted_sums = []
        joint_merged = local_weights[0] * value
        for l in range(log_len):
            ctx_len = 2**l
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            rolled_w = jnp.roll(local_weights[l + 1], -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    bsize * num_ctx, ctx_len, cfg.num_heads, 1)
            rolled_v = jnp.reshape(rolled_w * list_vals[l],
                                   [bsize, -1, cfg.num_heads, head_dim])
            rolled_v = jnp.pad(rolled_v, [(0, 0), (0, cfg.max_len - last_pos),
                                          (0, 0), (0, 0)])
            orig_v = jnp.roll(rolled_v, ctx_len, axis=1)
            joint_merged = joint_merged + orig_v
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
Example #10
def XDgroup_data(JD_time, JDs, pol, chans=None, tints=None, bad_ants=True, \
                 use_flags='first', noise=False, use_cal=None, verbose=False):
    """Returns redundant baseline grouping and reformatted dataset, with
    external flags applied, if specified

    :param JD_time: Julian time of 1st dataset, which sets times for others
    :type JD_time: str
    :param JDs: Julian days of data
    :type JDs: list, ndarray
    :param pol: Polarization of data
    :type pol: str
    :param chans: Frequency channel(s) {0, 1023} (None to choose all)
    :type chans: array-like, int, or None
    :param tints: Time integrations {0, 59} (None to choose all)
    :type tints: array-like, int, or None
    :param bad_ants: Flag known bad antennas, optional
    :type bad_ants: bool
    :param use_flags: Use flags to mask data
    :type use_flags: str
    :param noise: Also calculate noise from autocorrelations
    :type noise: bool
    :param use_cal: calfits file extension to use to calibrate data
    :type use_cal: str, None
    :param verbose: Print data gathering steps for each dataset
    :type verbose: bool

    :return hd: HERAData class
    :rtype hd: HERAData class
    :return redg: Grouped baselines, as returned by groupBls
    :rtype redg: ndarray
    :return cdata: Grouped visibilities with flags in numpy MaskedArray format,
    with format consistent with redg and dimensions (freq chans,
    time integrations, baselines)
    :rtype cdata: MaskedArray
    """

    if isinstance(chans, int):
        chans = np.asarray([chans])
    if isinstance(tints, int):
        tints = np.asarray([tints])

    zen_fn = find_zen_file(JD_time)
    flags_fn = find_flag_file(JD_time, use_flags)

    hd = HERAData(zen_fn)
    if tints is None:
        tints = np.arange(hd.Ntimes)

    if bad_ants:
        bad_ants = union_bad_ants(JDs)
    else:
        bad_ants = None

    if use_cal is None:
        cal_path = None
    else:
        cal_path = find_flag_file(JD_time, use_cal)

    if not verbose:
        grp_data = suppressOutput(group_data)
    else:
        grp_data = group_data

    grp = grp_data(zen_fn,
                   pol,
                   chans=chans,
                   tints=tints,
                   bad_ants=bad_ants,
                   flag_path=flags_fn,
                   noise=noise,
                   cal_path=cal_path)
    _, redg, cMData = grp[:3]

    cMData = cMData[np.newaxis, :]
    if noise:
        cNoise = grp[3]
        cNoise = cNoise[np.newaxis, :]

    JD_day = int(float(JD_time))
    if JD_day in JDs:
        JDs = list(JDs)
        JDs.remove(JD_day)

    for jd_i in JDs:
        JD_time_ia = match_lst(JD_time, jd_i)
        # aligning datasets in LAST
        last_df = pd.read_pickle(
            os.path.join(os.path.dirname(__file__), 'jd_lst_map_idr2.pkl'))
        last1 = last_df[last_df['JD_time'] == float(
            JD_time)]['LASTs'].values[0]
        last2 = last_df[last_df['JD_time'] == float(
            JD_time_ia)]['LASTs'].values[0]
        _, offset = find_nearest(last2, last1[0])
        tints_i = (tints + offset) % 60
        scnd_dataset = all(tints + offset > hd.Ntimes - 1)
        single_dataset = all(tints + offset < hd.Ntimes - 1) or scnd_dataset

        if not single_dataset:
            tints_ia, tints_ib = np.split(tints_i, np.where(tints_i == 0)[0])
        else:
            tints_ia = tints_i

        if scnd_dataset:
            next_row = numpy.where(
                last_df['JD_time'] == float(JD_time_ia))[0][0] + 1
            JD_time_ib = last_df.iloc[next_row]['JD_time']
            JD_time_ia = JD_time_ib

        JD_time_ia = check_jdt(JD_time_ia)
        zen_fn_ia = find_zen_file(JD_time_ia)
        flags_fn_ia = find_flag_file(JD_time_ia, use_flags)
        if use_cal is not None:
            cal_path_ia = find_flag_file(JD_time_ia, use_cal)
        else:
            cal_path_ia = None
        grp_a = grp_data(zen_fn_ia, pol, chans=chans, tints=tints_ia, \
                         bad_ants=bad_ants, flag_path=flags_fn_ia, noise=noise, \
                         cal_path=cal_path_ia)
        cMData_ia = grp_a[2]

        if not single_dataset:
            next_row = numpy.where(
                last_df['JD_time'] == float(JD_time_ia))[0][0] + 1
            JD_time_ib = last_df.iloc[next_row]['JD_time']
            JD_time_ib = check_jdt(JD_time_ib)
            zen_fn_ib = find_zen_file(JD_time_ib)
            flags_fn_ib = find_flag_file(JD_time_ib, use_flags)
            if use_cal is not None:
                cal_path_ib = find_flag_file(JD_time_ib, use_cal)
            else:
                cal_path_ib = None
            grp_b = grp_data(zen_fn_ib, pol, chans=chans, tints=tints_ib, \
                             bad_ants=bad_ants, flag_path=flags_fn_ib, \
                             noise=noise, cal_path=cal_path_ib)
            cMData_ib = grp_b[2]

            cMData_i = numpy.ma.concatenate((cMData_ia, cMData_ib), axis=1)
        else:
            cMData_i = cMData_ia

        cMData_i = cMData_i[np.newaxis, :]
        cMData = numpy.ma.concatenate((cMData, cMData_i), axis=0)

        if noise:
            cNoise_ia = grp_a[3]
            if not single_dataset:
                cNoise_ib = grp_b[3]
                cNoise_i = np.concatenate((cNoise_ia, cNoise_ib), axis=1)
            else:
                cNoise_i = cNoise_ia
            cNoise_i = cNoise_i[np.newaxis, :]
            cNoise = np.concatenate((cNoise, cNoise_i), axis=0)

    if noise:
        return hd, redg, cMData, cNoise
    else:
        return hd, redg, cMData
Example #11
def split(ary, indices_or_sections, axis=0):
  if isinstance(ary, JaxArray): ary = ary.value
  if isinstance(indices_or_sections, JaxArray): indices_or_sections = indices_or_sections.value
  return [JaxArray(a) for a in jnp.split(ary, indices_or_sections, axis=axis)]
Example #12
    def apply(self,
              x,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None,
              axis_index_groups=None):
        """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats
        will be used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of
        the batch statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma).
        When the next layer is linear (also e.g. nn.relu), this can be disabled
        since the scaling will be done by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For example,
        `[[0, 1], [2, 3]]` would independently batch-normalize over the examples
        on the first two and last two devices. See `jax.lax.psum` for more details.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)
        axis = axis if isinstance(axis, tuple) else (axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
        if self.is_stateful() or batch_stats:
            ra_mean = self.state('mean',
                                 reduced_feature_shape,
                                 initializers.zeros,
                                 collection=batch_stats)
            ra_var = self.state('var',
                                reduced_feature_shape,
                                initializers.ones,
                                collection=batch_stats)
        else:
            ra_mean = None
            ra_var = None

        if use_running_average:
            if ra_mean is None:
                raise ValueError(
                    'when use_running_averages is True '
                    'either use a stateful context or provide batch_stats')
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if axis_name is not None and not self.is_initializing():
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=axis_name,
                              axis_index_groups=axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if ra_mean and not self.is_initializing():
                ra_mean.value = momentum * ra_mean.value + (1 -
                                                            momentum) * mean
                ra_var.value = momentum * ra_var.value + (1 - momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + epsilon)
        if scale:
            mul = mul * self.param('scale', reduced_feature_shape,
                                   scale_init).reshape(feature_shape)
        y = y * mul
        if bias:
            y = y + self.param('bias', reduced_feature_shape,
                               bias_init).reshape(feature_shape)
        return jnp.asarray(y, dtype)
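
The cross-device reduction above concatenates the mean and mean-of-squares so a single `lax.pmean` handles both, then splits them apart again. A hedged single-device stand-in for that round trip (no `pmap` here, so the pmean step is only indicated in a comment):

import jax.numpy as jnp

mean = jnp.array([1.0, 2.0])
mean2 = jnp.array([1.5, 4.5])
reduced = jnp.concatenate([mean, mean2])   # under pmap this would go through lax.pmean
mean, mean2 = jnp.split(reduced, 2)
var = mean2 - jnp.square(mean)             # [0.5, 0.5]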
Example #13
    def testTupleCGA(self, fullmatrix, conj_grad, order, amat):
        if fullmatrix:
            self.skipTest(
                ("PyTree inputs are not supported by the full-matrix "
                 "implementation of CGA."))

        amat = amat + np.eye(*amat.shape)

        def f(x, y):
            return x.T @ amat @ y + np.dot(y, y)

        def g(x, y):
            return -f(x, y)

        def tuple_f(x, y):
            assert isinstance(x, tuple)
            assert isinstance(y, tuple)
            x = np.concatenate(x)
            y = _tree_concatentate(y)
            return f(x, y)

        def tuple_g(x, y):
            assert isinstance(x, tuple)
            assert isinstance(y, tuple)
            x = np.concatenate(x)
            y = _tree_concatentate(y)
            return g(x, y)

        eta = 0.1
        rtol = atol = 1e-8
        max_iter = 3000

        def convergence_test(x_new, x_old):
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        linear_op_solver = None
        if conj_grad:
            linear_op_solver = cga.cg_fixed_point_solve

        # The hypothesis package takes care of setting the seed of python's
        # random package.
        rng = random.PRNGKey(pyrandom.randint(0, 2**32 - 1))
        rng_x, rng_y = random.split(rng)

        init_vals = (random.uniform(rng_x, shape=(amat.shape[0], )),
                     random.uniform(rng_y, shape=(amat.shape[1], )))

        tuple_y = np.split(init_vals[1], (1, ))
        tuple_y = (tuple_y[0], tuple(np.split(tuple_y[1], (1, ))))
        init_tuple_vals = (tuple(np.split(init_vals[0], (1, ))), tuple_y)

        tuple_sol = cga.cga_iteration(init_tuple_vals,
                                      tuple_f,
                                      tuple_g,
                                      convergence_test,
                                      max_iter,
                                      eta,
                                      use_full_matrix=fullmatrix,
                                      linear_op_solver=linear_op_solver,
                                      solve_order=order)

        solution = cga.cga_iteration(init_vals,
                                     f,
                                     g,
                                     convergence_test,
                                     max_iter,
                                     eta,
                                     use_full_matrix=fullmatrix,
                                     linear_op_solver=linear_op_solver,
                                     solve_order=order)

        # check if tuple type is preserved
        self.assertTrue(isinstance(tuple_sol[0], tuple))
        self.assertTrue(isinstance(tuple_sol[1], tuple))

        # check if tuple type is preserved
        self.assertTrue(isinstance(tuple_sol, tuple))

        # check if output is the same for tuple inputs vs a single array
        self.assertAllClose(
            tuple((_tree_concatentate(x) for x in tuple_sol)),
            solution,
            check_dtypes=True,
            rtol=1e-8,
            atol=1e-8,
        )
Example #14
 def dist_fn(raw_params):
     mus, raw_log_vars = jnp.split(raw_params, 2)
     vars = jnp.exp(raw_log_vars) + min_scale_diag
     return tfd.MultivariateNormalDiag(loc=mus, scale_diag=jnp.sqrt(vars))
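
A minimal sketch of the parameter layout `dist_fn` above assumes: a flat vector of length 2 * d splits evenly into means and raw log-variances (toy values; `min_scale_diag` is a free variable in the source):

import jax.numpy as jnp

raw_params = jnp.zeros(6)                      # d = 3
mus, raw_log_vars = jnp.split(raw_params, 2)
assert mus.shape == raw_log_vars.shape == (3,)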
Example #15
    def par_from_array(arr):
        value_flat = jnp.split(arr, section_sizes)
        value_flat = [x.reshape(s) for x, s in zip(value_flat, section_shapes)]

        params = tree_unflatten(value_tree, value_flat)
        return params
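
A hedged sketch of the flatten/unflatten round trip the helper above performs. Note that `jnp.split` takes split indices, so `section_sizes` in the source is presumably cumulative; the boundaries below are computed explicitly for a toy pytree:

import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
leaves, value_tree = tree_flatten(params)
flat = jnp.concatenate([leaf.ravel() for leaf in leaves])
boundaries = [leaves[0].size]                               # split points between consecutive leaves
chunks = jnp.split(flat, boundaries)
rebuilt = tree_unflatten(value_tree, [c.reshape(leaf.shape) for c, leaf in zip(chunks, leaves)])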
Example #16
    def __call__(self, input_qkv):
        cfg = self.config
        assert cfg.max_len % cfg.max_seg_len == 0

        assert input_qkv.ndim == 3
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        qkv_features = cfg.qkv_dim or input_qkv.shape[-1]
        assert qkv_features % cfg.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // cfg.num_heads

        dense = partial(DenseGeneral,
                        axis=-1,
                        features=(cfg.num_heads, head_dim),
                        kernel_init=cfg.kernel_init,
                        bias_init=cfg.bias_init,
                        use_bias=False)
        query, key, value = (dense(dtype=cfg.dtype, name='query')(input_qkv) /
                             jnp.sqrt(head_dim),
                             dense(dtype=cfg.dtype, name='key')(input_qkv),
                             dense(dtype=cfg.dtype, name='value')(input_qkv))
        num_seg = cfg.max_len // cfg.max_seg_len

        ##################
        cur_query = query.reshape(
            [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])
        cur_key = key.reshape(
            [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])
        cur_value = value.reshape(
            [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])

        num_attn_dims = 2
        col_logit_expr = 'BSUNK,BTUNK->BUNST'
        col_attn_expr = 'BUNST,BTUNK->BSUNK'
        col_strict_mask = make_causal_mask(
            cur_query, length_axis=1, strict=True
        )  # strict lower triangular matrix so that the token won't repeatedly attend to itself
        col_strict_mask = jnp.expand_dims(col_strict_mask, axis=1)
        # (bsize, 1, 1, num_seg, num_seg)
        col_strict_bias = lax.select(
            col_strict_mask > 0,
            jnp.full(col_strict_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(col_strict_mask.shape, -1e10).astype(cfg.dtype))

        row_logit_expr = 'BUSNK,BUTNK->BUNST'
        row_attn_expr = 'BUNST,BUTNK->BUSNK'
        row_mask = make_causal_mask(cur_query, length_axis=2)[:, 0:1, :, :, :]
        # (bsize, 1, 1, max_seg_len, max_seg_len)
        row_bias = lax.select(
            row_mask > 0,
            jnp.full(row_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(row_mask.shape, -1e10).astype(cfg.dtype))

        col_logits = jnp.einsum(col_logit_expr, cur_query,
                                cur_key) + col_strict_bias
        # (bsize, max_seg_len, num_head, num_seg, num_seg)
        row_logits = jnp.einsum(row_logit_expr, cur_query, cur_key) + row_bias
        # (bsize, num_seg, num_head, max_seg_len, max_seg_len)
        ###############################

        col_up2down_query = jax.lax.cummax(cur_query, axis=1)
        col_up2down_key = shift_right(jax.lax.cummax(cur_key, axis=1),
                                      axis=1)  # shift down in some sense
        col_mask = make_causal_mask(cur_query, length_axis=1)
        col_mask = jnp.expand_dims(col_mask, axis=1)
        col_bias = lax.select(
            col_mask > 0,
            jnp.full(col_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(col_mask.shape, -1e10).astype(cfg.dtype))
        col_up2down_logits = jnp.einsum(col_logit_expr, col_up2down_query,
                                        cur_key) + col_bias
        col_up2down_attn_weights = jax.nn.softmax(col_up2down_logits).astype(
            cfg.dtype)
        col_up2down_summary = jnp.einsum(col_attn_expr,
                                         col_up2down_attn_weights, cur_value)
        col_up2down_summary = shift_right(col_up2down_summary,
                                          axis=1)  # shift down in some sense

        row_only_myself_mask = jnp.expand_dims(jnp.eye(cur_query.shape[2]),
                                               (0, 1, 2))
        row_without_myself_bias = lax.select(
            row_only_myself_mask == 0,
            jnp.full(row_only_myself_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(row_only_myself_mask.shape, -1e10).astype(cfg.dtype)
        )  # attend to all tokens in the previous row except for the token right up to the token in the previous token because this is already taken care of in the local col attention
        all_maskout = jnp.full(row_only_myself_mask.shape,
                               -1e10).astype(cfg.dtype)
        row_without_myself_bias = jnp.concatenate(
            [all_maskout] + [row_without_myself_bias] *
            (cur_query.shape[1] - 1),
            axis=1
        )  # the first row also has no previous row to attend, so just mask out all logits calculated here
        previous_row_logits = jnp.einsum(
            row_logit_expr, cur_query,
            col_up2down_key) + row_without_myself_bias

        row_left2right_query = jax.lax.cummax(cur_query, axis=2)
        row_left2right_key = shift_right(jax.lax.cummax(cur_key, axis=2),
                                         axis=2)
        row_left2right_logits = jnp.einsum(
            row_logit_expr, row_left2right_query, cur_key) + row_bias
        row_left2right_attn_weights = jax.nn.softmax(
            row_left2right_logits).astype(cfg.dtype)
        row_left2right_summary = jnp.einsum(row_attn_expr,
                                            row_left2right_attn_weights,
                                            cur_value)
        row_left2right_summary = shift_right(row_left2right_summary, axis=2)

        all_maskout = jnp.full(col_strict_bias.shape, -1e10).astype(cfg.dtype)
        col_strict_without_first_bias = jnp.concatenate(
            [all_maskout] + [col_strict_bias] * (cur_query.shape[2] - 1),
            axis=1)
        top_left_col_logits = jnp.einsum(
            col_logit_expr, cur_query,
            row_left2right_key) + col_strict_without_first_bias
        ##################################
        row_right2left_query = jax.lax.cummax(cur_query, axis=2, reverse=True)
        row_right2left_key = shift_left(jax.lax.cummax(cur_key,
                                                       axis=2,
                                                       reverse=True),
                                        axis=2)
        row_strict_mask = make_causal_mask(cur_query,
                                           length_axis=2,
                                           strict=True)[:, 0:1, :, :, :]
        # (bsize, 1, 1, max_seg_len, max_seg_len)
        row_upper_bias = lax.select(
            row_strict_mask == 0,
            jnp.full(row_strict_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(row_strict_mask.shape, -1e10).astype(cfg.dtype)
        )  # an upper triangular matrix since we attend all tokens on the right
        row_right2left_logits = jnp.einsum(
            row_logit_expr, row_right2left_query, cur_key) + row_upper_bias
        row_right2left_attn_weights = jax.nn.softmax(
            row_right2left_logits).astype(cfg.dtype)
        row_right2left_summary = jnp.einsum(row_attn_expr,
                                            row_right2left_attn_weights,
                                            cur_value)
        row_right2left_summary = shift_left(row_right2left_summary, axis=2)

        col_strict_without_last_bias = jnp.concatenate(
            [col_strict_bias] * (cur_query.shape[2] - 1) + [all_maskout],
            axis=1)
        top_right_col_logits = jnp.einsum(
            col_logit_expr, cur_query,
            row_right2left_key) + col_strict_without_last_bias
        ####

        joint_logits = jnp.concatenate(
            (col_logits.transpose([0, 3, 2, 1, 4]), row_logits,
             previous_row_logits, top_left_col_logits.transpose([
                 0, 3, 2, 1, 4
             ]), top_right_col_logits.transpose([0, 3, 2, 1, 4])),
            axis=-1
        )  # follow row, row first, the shape should be (bsize, num_seg, num_head, max_seg_len, num_seg+max_seg_len+max_seg_len+num_seg+num_seg)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)

        col_att, row_att, previous_row_att, top_left_col_att, top_right_col_att = jnp.split(
            attn_weights, [
                num_seg, num_seg + cfg.max_seg_len, num_seg +
                cfg.max_seg_len * 2, num_seg * 2 + cfg.max_seg_len * 2
            ],
            axis=-1)
        col_att = col_att.transpose([0, 3, 2, 1, 4])
        top_left_col_att = top_left_col_att.transpose([0, 3, 2, 1, 4])
        top_right_col_att = top_right_col_att.transpose([0, 3, 2, 1, 4])
        col_merged = jnp.einsum(col_attn_expr, col_att, cur_value)
        row_merged = jnp.einsum(row_attn_expr, row_att, cur_value)
        previous_row_merged = jnp.einsum(row_attn_expr, previous_row_att,
                                         col_up2down_summary)
        top_left_merged = jnp.einsum(col_attn_expr, top_left_col_att,
                                     row_left2right_summary)
        top_right_merged = jnp.einsum(col_attn_expr, top_right_col_att,
                                      row_right2left_summary)

        joint_merged = (col_merged + row_merged + previous_row_merged +
                        top_left_merged + top_right_merged).reshape([
                            bsize, num_seg * cfg.max_seg_len, cfg.num_heads,
                            head_dim
                        ])
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)

        return x
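
The five-way split above unpacks a fused attention axis whose segments have sizes num_seg, max_seg_len, max_seg_len, num_seg, num_seg, so the split boundaries are the running sums of those sizes. A toy-shape sketch of just that arithmetic (names illustrative):

import jax.numpy as jnp

num_seg, max_seg_len = 4, 8
fused = jnp.zeros((2, num_seg, 2, max_seg_len, 2 * max_seg_len + 3 * num_seg))
col, row, prev_row, top_left, top_right = jnp.split(
    fused,
    [num_seg, num_seg + max_seg_len, num_seg + 2 * max_seg_len, 2 * num_seg + 2 * max_seg_len],
    axis=-1)
assert row.shape[-1] == max_seg_len and top_right.shape[-1] == num_seg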
Example #17
    def __call__(self, input_qkv):
        cfg = self.config
        assert cfg.max_len % cfg.max_seg_len == 0
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        num_seg = cfg.max_len // cfg.max_seg_len
        x_sqr = input_qkv.reshape(
            [bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
        q_row_local, key_row_local, value_row_local, head_dim = get_qkv(
            cfg, x_sqr)
        local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local,
                                  key_row_local)
        row_probs = jax.nn.softmax(local_logits)
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
            row_probs = dropatt(row_probs, dropout_rng,
                                1 - cfg.attention_dropout_rate)
        row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs,
                                  value_row_local)

        key_row = DenseGeneral(features=input_qkv.shape[-1],
                               axis=(-2, -1),
                               kernel_init=cfg.kernel_init,
                               bias_init=cfg.bias_init,
                               use_bias=False,
                               dtype=cfg.dtype)(row_attn_out)
        key_row = nn.Dropout(rate=cfg.dropout_rate)(
            key_row, deterministic=cfg.deterministic)
        key_row = key_row + x_sqr
        key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
        key_row = DenseGeneral(axis=-1,
                               features=(cfg.num_heads, head_dim),
                               kernel_init=cfg.kernel_init,
                               bias_init=cfg.bias_init,
                               use_bias=False,
                               dtype=cfg.dtype)(key_row)
        idx_cols = jnp.arange(cfg.max_seg_len)
        local_mask = nn.make_attention_mask(idx_cols,
                                            idx_cols,
                                            jnp.less,
                                            extra_batch_dims=1)
        local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
        local_logits = local_logits + local_mask

        global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
        idx_rows = jnp.arange(num_seg)
        global_mask = nn.make_attention_mask(idx_rows, idx_rows,
                                             jnp.less_equal)
        global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
        global_logits = global_logits + global_mask

        joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
        attn_probs = jax.nn.softmax(joint_logits, axis=-1)
        local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len],
                                          axis=-1)
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
            local_att = dropatt(local_att, dropout_rng,
                                1 - cfg.attention_dropout_rate)
        local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att,
                                  value_row_local)
        global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att,
                                   row_attn_out)
        joint_merged = jnp.reshape(
            local_merged + global_merged,
            [bsize, cfg.max_len, cfg.num_heads, head_dim])
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
Example #18
    def __call__(self, inputs: Array, train: bool,
                 **kwargs) -> Tuple[Array, Mapping[str, Any]]:
        out = {}

        x = inputs
        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(features=self.hidden_size,
                    kernel_size=self.patches.size,
                    strides=self.patches.size,
                    padding='VALID',
                    name='embedding')(x)

        # Here, x is a grid of embeddings.
        # TODO(dusenberrymw): Switch to self.sow(.).
        out['stem'] = x

        # Transformer.
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x, _ = vit_batchensemble.BatchEnsembleEncoder(name='Transformer',
                                                      **self.transformer)(
                                                          x, train=train)
        out['transformed'] = x

        if self.classifier == 'token':
            x = x[:, 0]
        elif self.classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        out['head_input'] = x

        if self.representation_size is not None:
            x = ed.nn.DenseBatchEnsemble(
                self.representation_size,
                self.transformer.get('ens_size'),
                activation=None,
                alpha_init=ed.nn.utils.make_sign_initializer(
                    self.transformer.get('random_sign_init')),
                gamma_init=ed.nn.utils.make_sign_initializer(
                    self.transformer.get('random_sign_init')),
                name='pre_logits')(x)
            out['pre_logits'] = x
            x = nn.tanh(x)
        else:
            x = vit.IdentityLayer(name='pre_logits')(x)
            out['pre_logits'] = x

        # TODO(markcollier): Fix base model without using stop_gradient.
        if self.fix_base_model:
            x = jax.lax.stop_gradient(x)

        if self.use_gp:
            if self.covmat_momentum < 0.:
                gp_layer_kwargs = {'covmat_kwargs': {'momentum': None}}
            else:
                gp_layer_kwargs = {
                    'covmat_kwargs': {
                        'momentum': self.covmat_momentum
                    }
                }

            if self.multiclass:
                raise NotImplementedError(
                    'Multi-class HetSNGP layer not available.')
            else:
                gp_layer = ed.nn.MCSigmoidDenseFASNGPBE(
                    num_outputs=self.num_classes,
                    num_factors=self.num_factors,
                    temperature=self.temperature,
                    parameter_efficient=self.param_efficient,
                    train_mc_samples=self.mc_samples,
                    test_mc_samples=self.mc_samples,
                    ens_size=self.transformer.get('ens_size'),
                    logits_only=True,
                    name='head',
                    **gp_layer_kwargs)
            x_gp = gp_layer(x, training=train, **kwargs)

            # Gaussian process layer output: a tuple of logits, covmat, and optionally
            # random features.
            out['logits'] = x_gp[0]
            out['covmat'] = x_gp[1]

            logits = x_gp[0]
        else:
            # Note we're using non-BE layers.
            if self.multiclass:
                output_layer = ed.nn.MCSoftmaxDenseFA(
                    self.num_classes,
                    self.num_factors,
                    self.temperature,
                    self.param_efficient,
                    self.mc_samples,
                    self.mc_samples,
                    logits_only=True,
                    return_locs=self.return_locs,
                    name='head')
            else:
                output_layer = ed.nn.MCSigmoidDenseFA(
                    num_outputs=self.num_classes,
                    num_factors=self.num_factors,
                    temperature=self.temperature,
                    parameter_efficient=self.param_efficient,
                    train_mc_samples=self.mc_samples,
                    test_mc_samples=self.mc_samples,
                    logits_only=True,
                    return_locs=self.return_locs,
                    name='head')
            logits = output_layer(x)
            out['logits'] = logits

        if not train:
            if self.multiclass:
                logits = log_average_softmax_probs(
                    jnp.asarray(
                        jnp.split(logits, self.transformer.get('ens_size'))))
                out['pre_ens_logits'] = out['pre_logits']
                out['pre_logits'] = log_average_softmax_probs(
                    jnp.asarray(
                        jnp.split(out['pre_logits'],
                                  self.transformer.get('ens_size'))))
            else:
                logits = log_average_sigmoid_probs(
                    jnp.asarray(
                        jnp.split(logits, self.transformer.get('ens_size'))))
                out['pre_ens_logits'] = out['pre_logits']
                out['pre_logits'] = log_average_sigmoid_probs(
                    jnp.asarray(
                        jnp.split(out['pre_logits'],
                                  self.transformer.get('ens_size'))))

        return logits, out
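
The test-time ensembling above stacks member logits along the batch axis, splits them into ens_size groups, and log-averages the probabilities. A hedged stand-in for the multiclass branch (`log_average_softmax_probs` itself is not shown in the source, so the averaging is written out directly here):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

ens_size, batch, num_classes = 3, 4, 10
logits = jnp.zeros((ens_size * batch, num_classes))
member_logits = jnp.asarray(jnp.split(logits, ens_size))       # (ens_size, batch, num_classes)
log_probs = jax.nn.log_softmax(member_logits, axis=-1)
avg = logsumexp(log_probs, axis=0) - jnp.log(ens_size)         # log of the mean softmax probability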
Example #19
    def __call__(self, input_qkv):
        cfg = self.config
        assert cfg.max_len % cfg.max_seg_len == 0
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        num_seg = cfg.max_len // cfg.max_seg_len
        cur_query = query.reshape(
            [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
        merged_query = jnp.max(cur_query, axis=1,
                               keepdims=True) * jnp.sqrt(head_dim)
        cur_key = key.reshape(
            [-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
        cur_value = value.reshape(
            [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
        dropout_rng = None
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
        s = dot_product_attention(merged_query,
                                  cur_key,
                                  cur_value,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=cfg.attention_dropout_rate,
                                  broadcast_dropout=False,
                                  deterministic=cfg.deterministic,
                                  dtype=cfg.dtype)
        span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
        span_key = jnp.max(cur_key, axis=1, keepdims=True)
        # (bsize, n_seg, n_head, dim_per_head)
        span_key = jnp.reshape(
            span_key, [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

        local_mask = make_causal_mask(cur_query,
                                      length_axis=1).transpose([0, 2, 1, 3])
        local_bias = lax.select(
            local_mask > 0,
            jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize * n_seg, seg_len, n_head, seg_len)
        local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                                  cur_key) + local_bias
        local_logits = jnp.reshape(local_logits,
                                   [bsize, -1, cfg.num_heads, cfg.max_seg_len])
        idx = jnp.broadcast_to(jnp.arange(span_key.shape[1], dtype=jnp.int32),
                               span_key.shape[:2])
        prev_mask = nn.make_attention_mask(idx,
                                           idx,
                                           jnp.greater,
                                           extra_batch_dims=0,
                                           dtype=jnp.float32).transpose(
                                               [0, 2, 1, 3])
        prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
        prev_bias = lax.select(
            prev_mask > 0,
            jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize, max_len, n_head, num_segs)
        prev_logits = jnp.einsum('...qhd,...khd->...qhk', query,
                                 span_key) + prev_bias
        joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
        # (bsize x max_len,  n_head, seg_len + num_segs)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len],
                                        axis=-1)
        local_att = local_att.reshape(
            [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
        local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att,
                                  cur_value)
        prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
        joint_merged = jnp.reshape(local_merged,
                                   prev_merged.shape) + prev_merged
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
Example #20
        def critic_loss(q_params, policy_params, target_q_params, transitions,
                        key):
            batch_size = transitions.observation.shape[0]
            # Note: We might be able to speed up the computation for some of the
            # baselines to making a single network that returns all the values. This
            # avoids computing some of the underlying representations multiple times.
            if config.use_td:
                # For TD learning, the diagonal elements are the immediate next state.
                s, g = jnp.split(transitions.observation, [config.obs_dim],
                                 axis=1)
                next_s, _ = jnp.split(transitions.next_observation,
                                      [config.obs_dim],
                                      axis=1)
                if config.add_mc_to_td:
                    next_fraction = (1 - config.discount) / (
                        (1 - config.discount) + 1)
                    num_next = int(batch_size * next_fraction)
                    new_g = jnp.concatenate([
                        obs_to_goal(next_s[:num_next]),
                        g[num_next:],
                    ],
                                            axis=0)
                else:
                    new_g = obs_to_goal(next_s)
                obs = jnp.concatenate([s, new_g], axis=1)
                transitions = transitions._replace(observation=obs)
            I = jnp.eye(batch_size)  # pylint: disable=invalid-name
            logits = networks.q_network.apply(q_params,
                                              transitions.observation,
                                              transitions.action)

            if config.use_td:
                # Make sure to use the twin Q trick.
                assert len(logits.shape) == 3

                # We evaluate the next-state Q function using random goals
                s, g = jnp.split(transitions.observation, [config.obs_dim],
                                 axis=1)
                del s
                next_s = transitions.next_observation[:, :config.obs_dim]
                goal_indices = jnp.roll(
                    jnp.arange(batch_size, dtype=jnp.int32), -1)
                g = g[goal_indices]
                transitions = transitions._replace(
                    next_observation=jnp.concatenate([next_s, g], axis=1))
                next_dist_params = networks.policy_network.apply(
                    policy_params, transitions.next_observation)
                next_action = networks.sample(next_dist_params, key)
                next_q = networks.q_network.apply(
                    target_q_params, transitions.next_observation,
                    next_action)  # This outputs logits.
                next_q = jax.nn.sigmoid(next_q)
                next_v = jnp.min(next_q, axis=-1)
                next_v = jax.lax.stop_gradient(next_v)
                next_v = jnp.diag(next_v)
                # diag(logits) are predictions for future states.
                # diag(next_q) are predictions for random states, which correspond to
                # the predictions logits[range(B), goal_indices].
                # So, the only thing that's meaningful for next_q is the diagonal. Off
                # diagonal entries are meaningless and shouldn't be used.
                w = next_v / (1 - next_v)
                w_clipping = 20.0
                w = jnp.clip(w, 0, w_clipping)
                # (B, B, 2) --> (B, 2), computes diagonal of each twin Q.
                pos_logits = jax.vmap(jnp.diag, -1, -1)(logits)
                loss_pos = optax.sigmoid_binary_cross_entropy(
                    logits=pos_logits, labels=1)  # [B, 2]

                neg_logits = logits[jnp.arange(batch_size), goal_indices]
                loss_neg1 = w[:, None] * optax.sigmoid_binary_cross_entropy(
                    logits=neg_logits, labels=1)  # [B, 2]
                loss_neg2 = optax.sigmoid_binary_cross_entropy(
                    logits=neg_logits, labels=0)  # [B, 2]

                if config.add_mc_to_td:
                    loss = ((1 + (1 - config.discount)) * loss_pos +
                            config.discount * loss_neg1 + 2 * loss_neg2)
                else:
                    loss = ((1 - config.discount) * loss_pos +
                            config.discount * loss_neg1 + loss_neg2)
                # Take the mean here so that we can compute the accuracy.
                logits = jnp.mean(logits, axis=-1)

            else:  # For the MC losses.

                def loss_fn(_logits):  # pylint: disable=invalid-name
                    if config.use_cpc:
                        return (optax.softmax_cross_entropy(logits=_logits,
                                                            labels=I) +
                                0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
                    else:
                        return optax.sigmoid_binary_cross_entropy(
                            logits=_logits, labels=I)

                if len(logits.shape) == 3:  # twin q
                    # loss.shape = [.., num_q]
                    loss = jax.vmap(loss_fn, in_axes=2, out_axes=-1)(logits)
                    loss = jnp.mean(loss, axis=-1)
                    # Take the mean here so that we can compute the accuracy.
                    logits = jnp.mean(logits, axis=-1)
                else:
                    loss = loss_fn(logits)

            loss = jnp.mean(loss)
            correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
            logits_pos = jnp.sum(logits * I) / jnp.sum(I)
            logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
            if len(logits.shape) == 3:
                logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
            else:
                logsumexp = jax.nn.logsumexp(logits, axis=1)**2
            metrics = {
                'binary_accuracy': jnp.mean((logits > 0) == I),
                'categorical_accuracy': jnp.mean(correct),
                'logits_pos': logits_pos,
                'logits_neg': logits_neg,
                'logsumexp': logsumexp.mean(),
            }

            return loss, metrics
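
A minimal sketch (hypothetical shapes) of the observation split used above: the first obs_dim columns are the state and the remainder is the goal.

import jax.numpy as jnp

obs_dim = 5
observation = jnp.zeros((32, obs_dim + 3))
s, g = jnp.split(observation, [obs_dim], axis=1)
assert s.shape == (32, 5) and g.shape == (32, 3)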
Example #21
 def __call__(self, x):
     x_split = jnp.split(x, self.splits, 1)
     x_out = [c(x_split[i]) for i, c in enumerate(self)]
     x = jnp.concatenate(x_out, axis=1)
     return x
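
A hedged sketch of the split-apply-concatenate pattern above, with plain functions standing in for the per-branch sub-modules (branch functions and sizes are illustrative):

import jax.numpy as jnp

x = jnp.ones((4, 6))
branches = [lambda a: 2.0 * a, lambda a: a + 1.0, lambda a: -a]
chunks = jnp.split(x, 3, axis=1)                                # three (4, 2) blocks
y = jnp.concatenate([f(c) for f, c in zip(branches, chunks)], axis=1)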
Example #22
def train_step(
    model_fn,
    config,
    lr_fn,
    hwfnf,
    state,
    batch,
    coords=None,
    rng=None,
):
    """Perform a single training step."""
    rng_0, rng_1, rng_2, rng_3, rng_4 = random.split(rng, 5)
    inputs, target = batch
    hwf, near, far = hwfnf
    apply_coarse, apply_fine = model_fn
    opt_coarse, opt_fine = state.optimizer_coarse, state.optimizer_fine

    if not config.batching:
        rays = prepare_rays(None, hwf, config, near, far, inputs[:3, :4], None)
        if coords is None:
            coords = jnp.meshgrid(jnp.arange(hwf[0]),
                                  jnp.arange(hwf[1]),
                                  indexing="ij")
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])
        select_idx = random.choice(
            rng_0,
            coords.shape[0],
            shape=[config.num_rand],
            replace=False,
        )
        select_idx = coords[select_idx]
        rays = rays[select_idx[:, 0], select_idx[:, 1]]
        target = target[select_idx[:, 0], select_idx[:, 1]]
    else:
        rays = inputs

    *rays, viewdirs = jnp.split(rays, [3, 6, 7, 8], axis=-1)
    raw2outputs_ = functools.partial(
        raw2outputs,
        raw_noise_std=config.raw_noise_std,
        white_bkgd=config.white_bkgd,
    )

    def loss_fn(params_coarse, params_fine=None):
        """Loss function used for training."""
        pts, z_vals = render_rays(rays, config, rng_1)
        raw_c = apply_coarse({
            "params": params_coarse
        }, pts, viewdirs).reshape([config.num_rand, config.num_samples, 4])
        coarse_res, weights = raw2outputs_(raw_c, z_vals, rays[1], rng=rng_2)
        loss_c = jnp.mean((coarse_res["rgb"] - target)**2.0)
        coarse_res["raw"] = raw_c.astype(jnp.float32)
        coarse_res["loss"] = loss_c

        if config.num_importance > 0:
            pts, z_vals, _ = render_rays_fine(rays[:2], z_vals, weights,
                                              config.num_importance,
                                              config.perturb, rng_3)
            raw_f = apply_fine({
                "params": params_fine
            }, pts, viewdirs).reshape([
                config.num_rand, config.num_samples + config.num_importance, 4
            ])
            fine_res, _ = raw2outputs_(raw_f, z_vals, rays[1], rng=rng_4)
            loss_f = jnp.mean((fine_res["rgb"] - target)**2.0)
            fine_res["raw"] = raw_f.astype(jnp.float32)
            fine_res["loss"] = loss_f
            psnr = psnr_fn(loss_f)
        else:
            psnr = psnr_fn(loss_c)
            loss_f = 0
            fine_res = None

        loss = loss_c + loss_f
        return loss, (psnr, coarse_res, fine_res)

    lr = lr_fn(state.step)
    if config.num_importance > 0:
        aux, (grad_coarse,
              grad_fine) = jax.value_and_grad(loss_fn,
                                              argnums=[0, 1],
                                              has_aux=True)(opt_coarse.target,
                                                            opt_fine.target)

        grad_fine = lax.pmean(grad_fine, axis_name="batch")
        new_opt_fine = opt_fine.apply_gradient(grad_fine, learning_rate=lr)
    else:
        aux, grad_coarse = jax.value_and_grad(loss_fn,
                                              has_aux=True)(opt_coarse.target)
        new_opt_fine = None

    grad_coarse = lax.pmean(grad_coarse, axis_name="batch")
    new_opt_coarse = opt_coarse.apply_gradient(grad_coarse, learning_rate=lr)

    new_state = state.replace(
        step=state.step + 1,
        optimizer_coarse=new_opt_coarse,
        optimizer_fine=new_opt_fine,
    )
    loss, (psnr, coarse_res, fine_res) = aux
    metrics = {
        "loss": loss,
        "loss_c": coarse_res["loss"],
        "psnr": psnr,
        "psnr_c": psnr_fn(coarse_res["loss"]),
        "lr": lr,
    }
    if config.num_importance > 0:
        metrics["loss_f"] = fine_res["loss"]
        metrics["psnr_f"] = psnr_fn(fine_res["loss"])
    metrics = lax.pmean(metrics, axis_name="batch")
    return new_state, metrics, coarse_res, fine_res
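For clarity, the line *rays, viewdirs = jnp.split(rays, [3, 6, 7, 8], axis=-1) above unpacks a packed per-ray buffer into origins, directions, near, far, and view directions. A standalone sketch with a dummy array (the 11-column layout is an assumption inferred from the split indices):

import jax.numpy as jnp

packed = jnp.ones((1024, 11))  # hypothetical layout: 3 + 3 + 1 + 1 + 3 columns
origins, directions, near, far, viewdirs = jnp.split(packed, [3, 6, 7, 8], axis=-1)
# origins: (1024, 3), directions: (1024, 3), near: (1024, 1),
# far: (1024, 1), viewdirs: (1024, 3)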
Example #23
def main():
    with open('reuters_vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    v_dim = len(vocab['num_to_word'])
    e_dim = 1024

    prng_key = random.PRNGKey(0xdeadbeef)
    words = jnp.zeros(v_dim, dtype=jnp.float32)
    word_ix = random.uniform(prng_key, (1, )) * v_dim
    word_ix = int(jnp.floor(word_ix)[0])
    words = jops.index_update(words, word_ix, 1.0)

    # first create the architecture of the model
    mdl = Embedding(v_dim, e_dim)
    # then complete the model spec by giving an example input tensor
    params = mdl.init(prng_key, words)
    # now apply the params to the model with the input
    out = mdl.apply(params, words)
    print(f'out: {out}')
    print(f'shape: {out.shape}')

    # let's train the model on nltk's reuters dataset
    from nltk.corpus import reuters
    train_texts = []
    for fname in reuters.fileids():
        text = reuters.words(fname)
        train_texts.append(text)

    # now generate word-context elements
    window_size = 2
    word_pairs = []
    for words in tqdm(train_texts, desc='make train set'):
        for word_ix, word in enumerate(words):
            for offset in range(1, window_size + 1):
                back_context = word_ix - offset
                if back_context >= 0:
                    word_pairs.append((word, words[back_context]))
                fwd_context = word_ix + offset
                if fwd_context < len(words):
                    word_pairs.append((word, words[fwd_context]))

    # convert words to vocab IDs
    w2n = vocab['word_to_num']

    id_pairs = []
    for word_pair in tqdm(word_pairs, desc='gen word pairs'):
        word = word_pair[0]
        context = word_pair[1]
        if word in w2n and context in w2n:
            w_id, c_id = w2n[word], w2n[context]
            id_pairs.append((w_id, c_id))
    id_pairs = jnp.array(id_pairs)
    print(f'train pairs: {len(id_pairs)}')

    # run grad desc
    id_pairs = id_pairs[0:len(id_pairs) // 100]
    lr = 0.3
    batch_size = 2500

    # TEST: what if I run one at a time?
    '''
    loss_fn = lambda x, y : nll_loss_fn(mdl, params, x, y)
    grad_fn = jax.value_and_grad(loss_fn)
    grad_calc_fn = lambda params, x, y : grad_fn(params, x, y)
    param_update_fn = lambda old, grad: old -  lr * grad

    template_vec = jnp.zeros(v_dim, dtype=jnp.float32)
    for epoch in trange(5):
        for pair in tqdm(id_pairs):
            x = jops.index_update(template_vec, pair[0], 1.)
            y = jops.index_update(template_vec, pair[1], 1.)
            loss_val, grad = grad_calc_fn(params, x, y)
            params = jax.tree_multimap(param_update_fn, params, grad)
    import pdb; pdb.set_trace()
    pass
    '''
    # TEST END

    batches = jnp.split(id_pairs,
                        jnp.arange(batch_size, len(id_pairs), batch_size))

    for epoch in trange(1):
        # TODO: shuffle & batch id_pairs
        pbar = trange(len(batches), desc=f'epoch:--- - loss:------')
        for batch in batches:
            x_vals, y_vals = __id_to_one_hot(batch, v_dim)
            # Differentiate with respect to the parameters, matching the usage above.
            loss_fn = lambda p: nll_loss_fn(mdl, p, x_vals, y_vals)
            grad_fn = jax.value_and_grad(loss_fn)
            loss_val, grad = grad_fn(params)
            params = jax.tree_multimap(lambda old, grad: old - lr * grad,
                                       params, grad)
            pbar.set_description(f'epoch:{epoch:03d} - loss:{loss_val:0.4f}')
            pbar.update()

    import pdb
    pdb.set_trace()
    print('done!')
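The batching call above, jnp.split(id_pairs, jnp.arange(batch_size, len(id_pairs), batch_size)), chunks the dataset into equal batches with a possibly smaller final batch. A self-contained sketch of the same idiom (toy data, not from the original script):

import jax.numpy as jnp

data = jnp.arange(10)
batches = jnp.split(data, jnp.arange(4, len(data), 4))
# -> [Array([0, 1, 2, 3]), Array([4, 5, 6, 7]), Array([8, 9])]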
Example #24
 def unravel(arr):
   chunks = jnp.split(arr, indices[:-1])
   with warnings.catch_warnings():
     warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
     return [lax.convert_element_type(chunk.reshape(shape), dtype)
             for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
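The helper above reverses a flatten-and-concatenate step: it splits the flat array at stored offsets, reshapes each chunk, and casts back to the original dtypes. The closure variables indices, shapes, and from_dtypes come from the enclosing ravel code, so the hedged sketch below reconstructs them for two toy arrays and uses plain astype in place of lax.convert_element_type:

import numpy as np
import jax.numpy as jnp

leaves = [np.zeros((2, 3), np.float32), np.ones((4,), np.int32)]
shapes = [a.shape for a in leaves]
from_dtypes = [a.dtype for a in leaves]
indices = np.cumsum([a.size for a in leaves])        # [6, 10]

flat = jnp.concatenate([jnp.ravel(a).astype(jnp.float32) for a in leaves])

chunks = jnp.split(flat, indices[:-1])               # split at every offset except the last
restored = [chunk.reshape(shape).astype(dtype)
            for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]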
Example #25
def unstack(a, axis=0):
    """The opposite of stack()."""
    shape = a.shape
    return [jnp.squeeze(b, axis=axis) for b in \
            jnp.split(a, shape[axis], axis=axis)]
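A quick usage sketch (not from the original source) showing that unstack undoes jnp.stack; it reuses the unstack definition just shown:

import jax.numpy as jnp

a, b = jnp.zeros((3,)), jnp.ones((3,))
stacked = jnp.stack([a, b], axis=0)   # shape (2, 3)
parts = unstack(stacked, axis=0)      # two arrays of shape (3,)
assert all(p.shape == (3,) for p in parts)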
Example #26
 def shift_and_log_scale_fn(net_params, x1):
     s = net_apply(net_params, x1)
     return np.split(s, 2, axis=1)
Example #27
 def reassemble_concat(x):
     x = tuple(jnp.split(x, 2, axis=0))
     return reassemble(x)
Example #28
    def __call__(self, x):
        """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if self.use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
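The notable jnp.split use above is in the cross-device reduction: the batch mean and mean of squares are concatenated so a single lax.pmean collective covers both, then split back apart. A minimal sketch of that pattern under jax.pmap (shapes and the axis name are assumptions, not taken from the original module):

import jax
import jax.numpy as jnp
from jax import lax

def batch_stats(x):
    mean = jnp.mean(x, axis=0)
    mean2 = jnp.mean(lax.square(x), axis=0)
    # One collective instead of two: reduce the concatenated stats, then split.
    combined = lax.pmean(jnp.concatenate([mean, mean2]), axis_name='batch')
    mean, mean2 = jnp.split(combined, 2)
    var = mean2 - lax.square(mean)
    return mean, var

x = jnp.ones((jax.local_device_count(), 8, 4))  # (devices, per-device batch, features)
mean, var = jax.pmap(batch_stats, axis_name='batch')(x)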
Example #29
 def test_split(self):
   self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
              [(8,)], ['float_'], jtu.rand_default(self.rng()))
   self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
              [(12,)], ['float_'], jtu.rand_default(self.rng()))
Example #30
 def from_tensor(cls, tensor, normalize=False):
     quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
     return cls(quaternion, [tx[..., 0], ty[..., 0], tz[..., 0]],
                normalize=normalize)
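For reference, a sketch of what the split above yields for a 7-component rigid-transform tensor; the component ordering (4 quaternion values followed by a translation) is inferred from the split indices, not stated in the original:

import jax.numpy as jnp

tensor = jnp.arange(7.0)                       # assumed layout: [qw, qx, qy, qz, tx, ty, tz]
quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
# quaternion: shape (4,); tx, ty, tz: shape (1,) each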