Example #1
  def dot_product_attention(self,
                            query,
                            key,
                            value,
                            dtype=jnp.float32,
                            bias=None,
                            axis=None,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
                            deterministic=False,
                            precision=None):

    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1] and
            query.shape[-1] == key.shape[-1])
    if axis is None:
      axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
      axis = (axis,)
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
      if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
        raise ValueError('Attention axis must be between the batch '
                         'axis and the last-two axes.')
    n = key.ndim

    # Constructing projection tensor.
    if self.redraw_features:
      # TODO(kchoro): Get rid of the constant below.
      query_seed = lax.convert_element_type(
          jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
      rng = random.PRNGKey(query_seed)
      self.projection_matrix = self.draw_weights(rng)

    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1,)
    k_extra_perm = axis + batch_dims + (n - 1,)
    key_extra = key.transpose(k_extra_perm)
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    v_perm = batch_dims + axis + (n - 1,)
    value = value.transpose(v_perm)
    batch_dims_t = tuple(range(len(batch_dims)))
    attention_dims_t = tuple(
        range(len(batch_dims),
              len(batch_dims) + len(axis)))

    # Constructing tensors Q^{'} and K^{'}.
    query_prime = self.kernel_feature_creator(query, self.projection_matrix,
                                              attention_dims_t, batch_dims_t,
                                              precision, True)
    key_prime = self.kernel_feature_creator(key, self.projection_matrix,
                                            attention_dims_t, batch_dims_t,
                                            precision, False)

    if self.unidirectional:
      index = attention_dims_t[0]
      z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
          key_prime.shape[-1],) + (value.shape[-1],)

      W = _numerator(z_slice_shape, precision,
                     jnp.moveaxis(query_prime, index, 0),
                     jnp.moveaxis(key_prime, index, 0),
                     jnp.moveaxis(value, index, 0), self.lax_scan_unroll)

      # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
      W = jnp.moveaxis(W, 0, index)

      if not self.renormalize_attention:
        # Unidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Unidirectional, normalized attention.
        thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
            key_extra.shape[0:len(axis)])

        index = attention_dims_t[0]
        t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
            key_prime.shape[-1],)
        R = _denominator(t_slice_shape, precision,
                         jnp.moveaxis(query_prime, index, 0),
                         jnp.moveaxis(key_prime, index, 0),
                         self.lax_scan_unroll)

        R = jnp.moveaxis(R, 0, index)
    else:
      contract_query = tuple(
          range(len(batch_dims) + len(axis),
                len(batch_dims) + len(axis) + 1))
      contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
      # Constructing Z = (K^{'})^{T}V
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      Z = lax.dot_general(
          key_prime,
          value,
          ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
          precision=precision)
      # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
      # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      # W (bs,  <non-attention dims>, num_heads, <attention dims>, channels_v)
      W = lax.dot_general(
          query_prime,
          Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)),
          precision=precision)
      if not self.renormalize_attention:
        # Bidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Bidirectional, normalized attention.
        thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
            key_extra.shape[0:len(axis)])
        contract_key = tuple(
            range(len(batch_dims),
                  len(batch_dims) + len(axis)))
        contract_thick_all_ones = tuple(
            range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
        # Construct T = (K^{'})^{T} 1_L
        # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        T = lax.dot_general(
            key_prime,
            thick_all_ones, ((contract_key, contract_thick_all_ones),
                             (batch_dims_t, batch_dims_t)),
            precision=precision)

        # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
        # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
        # T   (bs, <non-attention dims>, num_heads, channels_m)
        R = lax.dot_general(
            query_prime,
            T, (((query_prime.ndim - 1,), (T.ndim - 1,)),
                (batch_dims_t, range(0,
                                     len(T.shape) - 1))),
            precision=precision)

    R = R + 2 * self.numerical_stabilizer * (
        jnp.abs(R) <= self.numerical_stabilizer)
    R = jnp.reciprocal(R)
    R = jnp.expand_dims(R, len(R.shape))
    # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
    # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
    result = W * R
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    result = result.transpose(perm_inv)
    return result
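
For intuition, here is a minimal single-head, bidirectional, normalized version of the same linear-attention computation, using simple ReLU kernel features in place of kernel_feature_creator (a sketch under those assumptions, not the method above):

import jax.numpy as jnp

def linear_attention_sketch(query, key, value, eps=1e-6):
    # Hypothetical ReLU features standing in for kernel_feature_creator.
    q_prime = jnp.maximum(query, 0.) + eps   # (L, channels_m)
    k_prime = jnp.maximum(key, 0.) + eps     # (L, channels_m)
    # Z = (K')^T V, shape (channels_m, channels_v).
    z = jnp.einsum('lm,lv->mv', k_prime, value)
    # W = Q' Z = Q' (K')^T V, shape (L, channels_v).
    w = jnp.einsum('lm,mv->lv', q_prime, z)
    # Partition function R = Q' ((K')^T 1_L), shape (L,).
    r = q_prime @ jnp.sum(k_prime, axis=0)
    return w / r[:, None]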
Example #2
def pad_trajectories(trajectories, boundary=20):
    """Pad trajectories to a bucket length that is a multiple of boundary.

  Args:
    trajectories: list[(observation, actions, rewards)], where each observation
      is shaped (t+1,) + OBS and actions & rewards are shaped (t,), with the
      length of the list being B (batch size).
    boundary: int, bucket length, the actions and rewards are padded to integer
      multiples of boundary.

  Returns:
    tuple: (padded_lengths, reward_mask, padded_observations, padded_actions,
        padded_rewards), where padded_observations is shaped (B, T+1) + OBS and
        padded_actions, padded_rewards & reward_mask are shaped (B, T), with T
        being max(t) rounded up to an integer multiple of boundary.
        padded_lengths is how much padding we've added per trajectory and
        reward_mask is 1s for actual rewards and 0s for the padding.
  """

    # Let's compute max(t) over all trajectories.
    t_max = max(r.shape[0] for (_, _, r) in trajectories)

    # t_max is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    # So all obs will be padded to bucket_length + 1 and actions and rewards
    # to bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_lengths = []
    reward_masks = []
    for (o, a, r) in trajectories:
        # Determine the amount to pad; this holds for obs, actions, and rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)
        if num_to_pad == 0:
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(onp.ones_like(r, dtype=np.int32))
            continue

        # First pad observations.
        padding_config = [(0, num_to_pad, 0)]
        for _ in range(o.ndim - 1):
            padding_config.append((0, 0, 0))
        padding_config = tuple(padding_config)

        padding_value = get_padding_value(o.dtype)
        action_padding_value = get_padding_value(a.dtype)
        reward_padding_value = get_padding_value(r.dtype)

        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        assert a.ndim == 1 and r.ndim == 1
        padding_config = ((0, num_to_pad, 0), )

        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later.
        reward_mask = onp.ones_like(r, dtype=np.int32)
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

    return padded_lengths, np.stack(reward_masks), np.stack(
        padded_observations), np.stack(padded_actions), np.stack(
            padded_rewards)
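
A toy invocation (a sketch; it assumes the module's get_padding_value and the np/onp aliases are in scope). Two trajectories of lengths 3 and 5, with OBS = (2,), pad to bucket_length = 20:

trajectories = [
    (onp.zeros((4, 2)), onp.zeros(3), onp.zeros(3)),   # t = 3
    (onp.zeros((6, 2)), onp.zeros(5), onp.zeros(5)),   # t = 5
]
lengths, mask, obs, acts, rews = pad_trajectories(trajectories, boundary=20)
# t_max = 5 rounds up to bucket_length = 20, so:
# obs.shape == (2, 21, 2); acts.shape == rews.shape == mask.shape == (2, 20)
# lengths == [17, 15] (amount of padding added per trajectory)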
Example #3
bessel_i0e = utils.copy_docstring(tf.math.bessel_i0e,
                                  lambda x, name=None: scipy_special.i0e(x))

bessel_i1 = utils.copy_docstring(tf.math.bessel_i1,
                                 lambda x, name=None: scipy_special.i1(x))

bessel_i1e = utils.copy_docstring(tf.math.bessel_i1e,
                                  lambda x, name=None: scipy_special.i1e(x))

betainc = utils.copy_docstring(
    tf.math.betainc, lambda a, b, x, name=None: scipy_special.betainc(a, b, x))

bincount = utils.copy_docstring(tf.math.bincount, _bincount)

ceil = utils.copy_docstring(tf.math.ceil, lambda x, name=None: np.ceil(x))

# confusion_matrix = utils.copy_docstring(
#     tf.math.confusion_matrix,
#     lambda labels, predictions, num_classes=None, weights=None,
#     dtype=tf.int32, name=None: ...)

conj = utils.copy_docstring(tf.math.conj, lambda x, name=None: np.conj(x))

cos = utils.copy_docstring(tf.math.cos, lambda x, name=None: np.cos(x))

cosh = utils.copy_docstring(tf.math.cosh, lambda x, name=None: np.cosh(x))

count_nonzero = utils.copy_docstring(
    tf.math.count_nonzero,
    lambda input, axis=None, keepdims=None, dtype=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        # Completion assumed: the snippet was truncated here; defer to NumPy,
        # ignoring keepdims and dtype in this sketch.
        np.count_nonzero(input, axis=axis)))
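
Each shim carries the TensorFlow docstring while dispatching to NumPy/SciPy; a quick sketch of what that buys (assuming copy_docstring copies __doc__ verbatim):

x = np.array([-1.5, 0.2, 3.7])
ceil(x)                          # array([-1.,  1.,  4.])
bessel_i0e(x)                    # SciPy's exponentially scaled I0
assert ceil.__doc__ == tf.math.ceil.__doc__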
Example #4
def _get_num_steps(step_size, trajectory_length):
    num_steps = jnp.ceil(trajectory_length / step_size)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(jnp.result_type(int))
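
A quick check of the dtype note (a sketch; with jax_enable_x64 disabled, jnp.result_type(int) is int32):

n = _get_num_steps(jnp.asarray(0.25), jnp.asarray(1.0))
# ceil(1.0 / 0.25) == 4.0, so n == 4 with dtype int32 (int64 under x64)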
Example #5
def pad_to_pow2(tensor, axis):
    size = tensor.shape[axis]
    # Round `size` up to the next power of two: 2 ** ceil(log2(size)).
    new_size = int(np.power(2, np.ceil(np.log2(size))))
    return pad_along_axis(tensor, new_size, axis)
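
A sketch of the intended behavior, assuming pad_along_axis zero-pads on the right:

x = np.zeros((3, 5))
y = pad_to_pow2(x, axis=1)
# 5 rounds up to the next power of two, 8, so y.shape == (3, 8)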
Example #6
def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    """A wrapper that handles non-maximum suppression.

  Assumption:
    * The boxes are sorted by scores unless the box is a dot (all coordinates
      are zero).
    * Boxes with higher scores can be used to suppress boxes with lower scores.

  The overall design of the algorithm is to handle boxes tile-by-tile:

  boxes = boxes.pad_to_multiple_of(tile_size)
  num_tiles = len(boxes) // tile_size
  output_boxes = []
  for i in range(num_tiles):
    box_tile = boxes[i*tile_size : (i+1)*tile_size]
    for j in range(i):
      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
      iou = _bbox_overlap(box_tile, suppressing_tile)
      # if the box is suppressed in iou, clear it to a dot
      box_tile *= _update_boxes(iou)
    # Iteratively handle the diagonal tile.
    iou = _bbox_overlap(box_tile, box_tile)
    iou_changed = True
    while iou_changed:
      # boxes that are not suppressed by anything else
      suppressing_boxes = _get_suppressing_boxes(iou)
      # boxes that are suppressed by suppressing_boxes
      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
      # clear iou to 0 for boxes that are suppressed, as they cannot be used
      # to suppress other boxes any more
      new_iou = _clear_iou(iou, suppressed_boxes)
      iou_changed = (new_iou != iou)
      iou = new_iou
    # remaining boxes that can still suppress others are the selected boxes.
    output_boxes.append(_get_suppressing_boxes(iou))
    if len(output_boxes) >= max_output_size:
      break

  Args:
    scores: a tensor with a shape of [batch_size, anchors].
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    max_output_size: a scalar integer `Tensor` representing the maximum number
      of boxes to be selected by non max suppression.
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.
  Returns:
    nms_scores: a tensor with a shape of [batch_size, max_output_size]. It has
      the same dtype as the input scores.
    nms_proposals: a tensor with a shape of [batch_size, max_output_size, 4].
      It has the same dtype as the input boxes.
  """
    batch_size = boxes.shape[0]
    num_boxes = boxes.shape[1]
    pad = int(jnp.ceil(
        float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
    boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
    scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
    num_boxes += pad

    def _loop_cond(in_args):
        unused_boxes, unused_threshold, output_size, idx = in_args
        return jnp.logical_and(
            jnp.min(output_size) < max_output_size,
            idx < num_boxes // _NMS_TILE_SIZE)

    selected_boxes, _, output_size, _ = lax.while_loop(
        _loop_cond, _suppression_loop_body,
        (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
    idx = num_boxes - lax.top_k(
        jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
        jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
        max_output_size)[0].astype(jnp.int32)
    idx = jnp.minimum(idx, num_boxes - 1)
    idx = jnp.reshape(
        idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
    boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
                        [batch_size, max_output_size, 4])
    boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) <
                     jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
    scores = jnp.reshape(
        jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
    scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1]) <
                       jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
    return scores, boxes
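
A hypothetical call with toy shapes (a sketch; _NMS_TILE_SIZE and _suppression_loop_body come from the surrounding module, and real boxes would be score-sorted [y1, x1, y2, x2] coordinates):

key = jax.random.PRNGKey(0)
scores = jnp.linspace(1.0, 0.0, 100)[None, :]            # (1, 100), sorted
boxes = jax.random.uniform(key, (1, 100, 4))
nms_scores, nms_boxes = non_max_suppression_padded(
    scores, boxes, max_output_size=10, iou_threshold=0.5)
# nms_scores: (1, 10); nms_boxes: (1, 10, 4)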
Example #7
def row(i):
    # Maps index i to its row in a binary-tree layout: i = 0 is row 0, and
    # row r >= 1 holds indices 2 ** (r - 1) through 2 ** r - 1.
    return jnp.ceil(jnp.log2(i + 1))
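
For example, the first eight indices map to binary-tree rows:

[int(row(i)) for i in range(8)]   # [0, 1, 2, 2, 3, 3, 3, 3]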
Example #8
def r_constant_jax(currinfo, frames, T_n, rp, adj=True, h=.1):
    """Computes relaxation amounts for a vehicle over its trajectory.

    Args:
      currinfo: list of (time, gamma) entries for a specific vehicle, where
        gamma is the change in headway immediately after a lane change.
        Precomputed in makeleadfolinfo_r; the times are computed during the
        objective evaluation/simulation.
      frames: (t_n, T_nm1) for the vehicle.
      T_n: last time index of the trajectory.
      rp: relaxation constant.
      adj: if True, also compute the part needed for the adjoint calculation.
      h: timestep (.1 for NGSim).

    Note that the precomputed gammas inside rinfo may need to be altered:
    after multiple lane changes in a short time, the headway may look only
    marginally shorter while the vehicle is still relaxing from the lane
    change it just made.
    """
    if len(currinfo) == 0:  # if currinfo is empty we don't have to do anything
        relax = jnp.zeros(T_n - frames[0] + 1)
        return relax, relax

    # Initialize relaxation amounts for the time between t_n and T_n.
    out = jnp.zeros((T_n - frames[0] + 1, 1))
    out2 = jnp.zeros((T_n - frames[0] + 1, 1))
    outlen = 1

    # Maximum index we may write to, because the time between T_nm1 and T_n is
    # not simulated. Plus 1 because of the way slices work.
    maxind = frames[1] - frames[0] + 1

    if rp < h:  # the timestep is the smallest rp can be
        rp = h

    # Number of nonzero entries in r per relaxation event. Cast to a Python
    # int so it can be used below for linspace and slicing.
    mylen = int(jnp.ceil(rp / h)) - 1

    # The relaxation constants, determined only by rp. Each gets multiplied by
    # the 'gamma', the change in headway immediately after the lane change.
    r = jnp.linspace(1 - h / rp, 1 - h / rp * mylen, mylen)

    for i in range(len(currinfo)):
        entry = currinfo[i]  # the current entry for the relaxation phenomenon
        # Current time is entry[0]; we start at frames[0], so this is the
        # current index.
        curind = entry[0] - frames[0]

        for j in range(outlen):
            if out2[curind, j] == 0:
                if curind + mylen > maxind:
                    # We can't insert all of r without running into the
                    # unsimulated end part (and possibly out of bounds). JAX
                    # arrays are immutable, so use .at[].set instead of item
                    # assignment.
                    out = out.at[curind:maxind, j].set(r[0:maxind - curind])
                    out2 = out2.at[curind:maxind, j].set(currinfo[i][1])
                else:  # the normal case
                    out = out.at[curind:curind + mylen, j].set(r)
                    out2 = out2.at[curind:curind + mylen, j].set(currinfo[i][1])
                break
        else:
            # Every existing column is occupied at curind; add a new column.
            newout = jnp.zeros((T_n - frames[0] + 1, 1))
            newout2 = jnp.zeros((T_n - frames[0] + 1, 1))
            if curind + mylen > maxind:
                newout = newout.at[curind:maxind, 0].set(r[0:maxind - curind])
                newout2 = newout2.at[curind:maxind, 0].set(currinfo[i][1])
            else:
                newout = newout.at[curind:curind + mylen, 0].set(r)
                newout2 = newout2.at[curind:curind + mylen, 0].set(currinfo[i][1])
            out = jnp.append(out, newout, axis=1)
            out2 = jnp.append(out2, newout2, axis=1)
            outlen += 1

    # Calculate relaxation amounts, plus the part needed for the adjoint
    # calculation.
    relax = jnp.sum(jnp.multiply(out, out2), 1)

    if adj:
        # Derivative of out (up to the piecewise nature of out/r); once
        # multiplied by out2 (called gamma in the paper) it is the derivative
        # of the relaxation amount.
        outd = -(1 / rp) * (out - 1)
        relaxadj = jnp.sum(jnp.multiply(outd, out2), 1)
    else:
        relaxadj = relax

    return relax, relaxadj
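
A hypothetical toy call: one lane change at time 12 with gamma = 2.0, the trajectory simulated from frame 10 to 20, with data out to T_n = 25:

relax, relaxadj = r_constant_jax([(12, 2.0)], (10, 20), 25, rp=1.0)
# relax has length T_n - frames[0] + 1 = 16; entries outside the simulated
# window (indices >= maxind = 11) stay zero.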
Example #9
def rundmc(
    key,
    wf,
    configs,
    weights=None,
    tstep=0.01,
    nsteps=1000,
    branchtime=5,
    stepoffset=0,
    branchcut_start=3,
    branchcut_stop=6,
    drift_limiter=limdrift,
    verbose=False,
    accumulators=None,
    ekey=("energy", "total"),
    propagate=dmc_propagate,
    feedback=1.0,
    hdf_file=None,
    client=None,
    npartitions=None,
    **kwargs,
):
    """
    Run DMC 
    
    Args:
      wf: A wave-function-like object. recompute(), gradient(), and updateinternals() are used, as well as anything (such as laplacian()) used by accumulators.

      configs: (nconfig, nelec, 3) - initial coordinates to start calculation. 

      weights: (nconfig,) - initial weights to start calculation, defaults to uniform.

      nsteps: number of DMC steps to take

      tstep: Time step for move proposals. Introduces time step error.

      branchtime: number of steps to take between branching

      accumulators: A dictionary of functor objects that take in (coords,wf) and return a dictionary of quantities to be averaged. np.mean(quantity,axis=0) should give the average over configurations. If none, a default energy accumulator will be used.

      ekey: tuple of strings; energy is needed for DMC weights. Access total energy by accumulators[ekey[0]](configs, wf)[ekey[1]].

      verbose: Print out step information 

      drift_limiter: a function that takes a gradient and a cutoff and returns an adjusted gradient

      stepoffset: If continuing a run, what to start the step numbering at.

    Returns: (df,coords,weights)
      df: A list of dictionaries nstep long that contains all results from the accumulators.

      coords: The final coordinates from this calculation.

      weights: The final weights from this calculation
      
    """
    # Restart from HDF file
    if hdf_file is not None and os.path.isfile(hdf_file):
        with h5py.File(hdf_file, "r") as hdf:
            stepoffset = hdf["step"][-1] + 1
            configs.load_hdf(hdf)
            weights = jnp.array(hdf["weights"])
            eref = hdf["eref"][-1]
            esigma = hdf["esigma"][-1]
            if verbose:
                print("Restarted calculation")
    else:
        warmup = 2
        key, subkey = jax.random.split(key)
        df, configs = mc.vmc(
            subkey,
            wf,
            configs,
            accumulators=accumulators,
            client=client,
            npartitions=npartitions,
            verbose=verbose,
        )
        en = df[ekey[0] + ekey[1]][warmup:]
        eref = jnp.mean(en).real
        esigma = jnp.sqrt(jnp.var(en) * jnp.mean(df["nconfig"]))
        if verbose:
            print("eref start", eref, "esigma", esigma)

    nconfig = configs.shape[0]
    if weights is None:
        weights = jnp.ones(nconfig)

    npropagate = int(jnp.ceil(nsteps / branchtime))
    df = []
    for step in range(npropagate):
        key, subkey = jax.random.split(key)
        df_, configs, weights = propagate(
            subkey,
            wf,
            configs,
            weights,
            tstep,
            branchcut_start * esigma,
            branchcut_stop * esigma,
            eref=eref,
            nsteps=branchtime,
            accumulators=accumulators,
            ekey=ekey,
            drift_limiter=drift_limiter,
            **kwargs,
        )

        df_["eref"] = eref
        df_["step"] = step + stepoffset
        df_["esigma"] = esigma
        df_["tstep"] = tstep
        df_["weight_std"] = jnp.std(weights)
        df_["nsteps"] = branchtime

        dmc_file(hdf_file, df_, {}, configs, weights)
        # print(df_)
        df.append(df_)
        eref = df_[ekey[0] + ekey[1]] - feedback * jnp.log(jnp.mean(weights))
        key, subkey = jax.random.split(key)
        configs, weights = branch(subkey, configs, weights)
        if verbose:
            print(
                "energy",
                df_[ekey[0] + ekey[1]],
                "eref",
                df_["eref"],
                "sigma(w)",
                df_["weight_std"],
            )

    df_ret = {}
    for k in df[0].keys():
        df_ret[k] = jnp.asarray([d[k] for d in df])
    return df_ret, configs, weights
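
A hypothetical invocation (a sketch; wf and configs stand for a wave-function object and an initial configuration set from the surrounding package):

key = jax.random.PRNGKey(0)
df, configs, weights = rundmc(
    key, wf, configs, tstep=0.02, nsteps=200, branchtime=5, verbose=True)
# One entry per propagation block: len(df["step"]) == ceil(200 / 5) == 40.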
Example #10
def process_from_disk(metadata, raw_frames, local_batch_size, filter_all,
                      filter_all_dexp, network_metadata):

    n_total_frames = raw_frames.shape[0]

    #if the batch size is odd in double-exposure mode, we fix that here
    if local_batch_size % 2 != 0 and metadata['double_exposure']:
        local_batch_size += 1

    #If the batch size is not given or is too big, we set it so every rank gets work
    if local_batch_size is None or local_batch_size * mpi_size > n_total_frames:
        local_batch_size = n_total_frames // mpi_size

    batch_size = mpi_size * local_batch_size

    printv(
        color(
            "\r Using a local batch size per MPI rank = " +
            str(local_batch_size), bcolors.HEADER))

    #This stores the frame indexes that are being processed by this mpi rank
    my_indexes = []
    n_batches = raw_frames.shape[0] // batch_size

    #Here we correct if the total number of frames is not a multiple of batch_size
    extra = raw_frames.shape[0] - (n_batches * batch_size)

    extra_last_batch = None
    if rank * local_batch_size < extra:
        n_batches = n_batches + 1
        n_ranks_extra = int(np.ceil(extra / local_batch_size))
        #We always overshoot the batch sizes if they don't match perfectly (that is, when extra % local_batch_size != 0)
        #To account for this, we need an index subtraction (extra_last_batch) for the last rank across the ones having extra work
        if rank == n_ranks_extra - 1 and extra % local_batch_size != 0:
            extra_last_batch = -(local_batch_size -
                                 (extra % local_batch_size)) // (
                                     metadata['double_exposure'] + 1)

    n_out_frames = n_batches * local_batch_size // (
        metadata['double_exposure'] + 1)

    out_data_shape = (n_out_frames, metadata["output_frame_width"],
                      metadata["output_frame_width"])
    out_data = np.empty(out_data_shape, dtype=np.float32)
    frames_batch = npo.empty(
        (local_batch_size, raw_frames[0].shape[0], raw_frames[0].shape[1]))

    #Streaming variables
    frames_ready = 0
    frames_sent = 0
    streaming_output_buffer_size = 12
    output_socket = "intermediate_socket" in network_metadata

    for i in range(0, n_batches):

        local_i = ((i * batch_size) + (rank * local_batch_size))

        #we handle uneven shapes here
        upper_bound = min(local_i + local_batch_size, n_total_frames)

        local_range = range(local_i // (metadata['double_exposure'] + 1),
                            upper_bound // (metadata['double_exposure'] + 1))

        my_indexes.extend(local_range)

        i_s = i * local_batch_size // (metadata['double_exposure'] + 1)
        i_e = i_s + local_batch_size // (metadata['double_exposure'] + 1)

        for j in range(local_i, upper_bound):
            frames_batch[j % local_batch_size] = raw_frames[j][:, :]

        if metadata["double_exposure"]:
            centered_rescaled_frames_jax = filter_all_dexp(
                frames_batch[:-1:2], frames_batch[1::2])
        else:
            centered_rescaled_frames_jax = filter_all(frames_batch)

        # TODO: 'centered_rescaled_frames_jax' picks up an additional dimension somehow, should fix this...
        #out_data = jax.ops.index_update(out_data, jax.ops.index[i_s:i_e, :, :], centered_rescaled_frames_jax[:,0,:,:])
        out_data = out_data.at[i_s:i_e, :, :].set(
            centered_rescaled_frames_jax[:, 0, :, :])

        if rank == 0:
            sys.stdout.write(
                color("\r Computing batch = %s/%s " % (i + 1, n_batches),
                      bcolors.HEADER))
            sys.stdout.flush()

        frames_ready += (i_e - i_s)

        #Sending frames to socket
        if output_socket:

            if extra_last_batch is not None and i == n_batches - 1:
                i_e += extra_last_batch  #extra_last_batch is a negative offset, we add it here

            send_socket_data(out_data, my_indexes, i_s, i_e, network_metadata)

    if rank == 0: print("\n")
    return out_data[:extra_last_batch], my_indexes
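
The batching arithmetic can be sanity-checked in isolation; a sketch with hypothetical values, assuming double exposure (two raw frames per output frame) and two MPI ranks:

local_batch_size = 4                       # must be even with double exposure
mpi_size = 2
batch_size = mpi_size * local_batch_size   # 8 raw frames per global batch
n_total_frames = 24
n_batches = n_total_frames // batch_size            # 3 full batches
n_out_frames = n_batches * local_batch_size // 2    # 6 output frames per rank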
Example #11
def _compute_range_weights(guide, grid_shape):
  """Computes range weights for the given guide image and grid shape.

  Args:
    guide: The guide image with shape (h, w).
    grid_shape: The grid shape, an array-like containing [gh, gw, gd, gc].

  Returns:
    An (image_extent, grid_extent) array with the range weight for each
    spatial and grid position.
  """
  guide_padded = _symmetric_pad_ij(guide, grid_shape)

  # Rescale `image` from [0, 1] to [0, grid_depth].
  # These are the floating point k coordinates of each sample.
  grid_depth = grid_shape[2]
  gk_float = guide_padded * grid_depth

  # Each sample with float value kf can splat onto locations:
  # k0 = floor(kf - 0.5)
  # k1 = ceil(kf - 0.5)
  #
  # The subtraction by 0.5 is necessary:
  # - Grid samples are located at half-integer coordinates:
  #   k = 0 places its sample at kf = 0.5.
  # - If kf = 1.4, the tent weight function is nonzero in the range [0.4, 1.4].
  #   Therefore, we need to splat to k0 = 0 and k1 = 1.
  # - If kf = 1.9, the tent weight function is nonzero in the range [0.9, 1.9].
  #   Therefore, we need to splat to k0 = 1 and k1 = 2.
  gk_floor = jnp.floor(gk_float - 0.5)
  gk_ceil = jnp.ceil(gk_float - 0.5)

  # Compute tent weights before clipping.
  wk_floor = smoothed_lerp_weight(gk_floor + 0.5, gk_float)
  wk_ceil = smoothed_lerp_weight(gk_ceil + 0.5, gk_float)

  # Cast to int for indexing.
  gk_floor = gk_floor.astype(jnp.int32)
  gk_ceil = gk_ceil.astype(jnp.int32)

  # Handle boundary conditions:
  # - Set the weight to 0 where the tent weight is positive but outside
  #   [0, grid_depth].
  # - Set the weight to 1 where the sample is between [0, 0.5) and
  #   (depth - 0.5, depth].
  wk_floor = jnp.where((gk_ceil == 0) & (gk_float < 0.5), 0, wk_floor)
  wk_ceil = jnp.where(
      (gk_floor == grid_depth - 1) & (gk_float > grid_depth - 0.5), 0, wk_ceil)
  wk_ceil = jnp.where((gk_ceil == 0) & (gk_float < 0.5), 1, wk_ceil)
  wk_floor = jnp.where(
      (gk_floor == grid_depth - 1) & (gk_float > grid_depth - 0.5), 1, wk_floor)

  # Now clip int coordinates for splatting. Coordinates outside [0, grid_depth)
  # will have zero weight so splatting to them does nothing.
  gk_floor_clipped = gk_floor.clip(0, grid_depth - 1)
  gk_ceil_clipped = gk_ceil.clip(0, grid_depth - 1)

  # Compute the i and j indices where we want to splat the weights wk with +=.
  # grid[ii, jj, gk_floor] += wk_floor
  # grid[ii, jj, gk_ceil] += wk_ceil
  ii, jj = jnp.meshgrid(
      jnp.arange(guide_padded.shape[0]),
      jnp.arange(guide_padded.shape[1]),
      indexing='ij')

  range_weights = jnp.zeros(
      (guide_padded.shape[0], guide_padded.shape[1], grid_depth))
  # jax.ops.index_add has been removed from JAX; the modern equivalent is a
  # scatter-add via .at[...].add(...).
  range_weights = range_weights.at[ii, jj, gk_floor_clipped].add(wk_floor)
  range_weights = range_weights.at[ii, jj, gk_ceil_clipped].add(wk_ceil)

  return range_weights
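
A quick check (a sketch; _symmetric_pad_ij and smoothed_lerp_weight come from the surrounding module). With a constant guide, every pixel splats onto the same pair of depth cells, and the two tent weights at each pixel sum to roughly one:

guide = jnp.full((8, 8), 0.5)     # guide values in [0, 1]
grid_shape = (2, 2, 4, 3)         # (gh, gw, gd, gc), so grid_depth = 4
w = _compute_range_weights(guide, grid_shape)
# w.shape == (padded_h, padded_w, 4); w.sum(axis=-1) is ~1 everywhere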