def _get_bbox_regression_labels(bbox_target_data, num_classes, bbox_in_weights):
    num_targets = tf.expand_dims(tf.range(tf.shape(bbox_target_data)[0]), 1)
    clss = tf.concat(
        [num_targets, tf.to_int32(tf.expand_dims(bbox_target_data[:, 0], 1))], 1)
    tgt_updates = bbox_target_data[:, 1:]
    # Scatter each RoI's 4 regression targets into its class slot of a
    # zero-filled [num_rois, num_classes, 4] tensor, matching the
    # tf.scatter_nd fallback kept in the comments below.
    bbox_targets = tf.tensor_scatter_update(
        tensor=tf.zeros([tf.shape(bbox_target_data)[0], num_classes, 4]),
        indices=clss,
        updates=tgt_updates)
    """
    bbox_targets = tf.scatter_nd(indices=clss, updates=tgt_updates,
                                 shape=[tf.shape(bbox_target_data)[0], num_classes, 4])
    """
    weight_updates = tf.multiply(tf.ones_like(tgt_updates), bbox_in_weights)
    bbox_inside_weights = tf.tensor_scatter_update(
        tensor=tf.zeros([tf.shape(bbox_target_data)[0], num_classes, 4]),
        indices=clss,
        updates=weight_updates)
    """
    bbox_inside_weights = tf.scatter_nd(indices=clss, updates=weight_updates,
                                        shape=[tf.shape(bbox_target_data)[0], num_classes, 4])
    """
    return bbox_targets, bbox_inside_weights
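# Shape sketch (hypothetical numbers, written with the TF 2.x name
# tf.tensor_scatter_nd_update): with 3 RoIs and 5 classes, the index pairs
# combine each RoI index with its class label, and the scatter writes each
# RoI's 4 regression targets into its own class slot, leaving other slots zero.
rois_by_class = tf.zeros([3, 5, 4])
clss_example = tf.constant([[0, 2], [1, 0], [2, 4]])
updates_example = tf.ones([3, 4])
print(tf.tensor_scatter_nd_update(rois_by_class, clss_example,
                                  updates_example).shape)  # (3, 5, 4)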
def update_step(mm, _tuple):
    mean, mean_squared = mm
    gvt, _env_id = _tuple
    _env_id = tf.reshape(_env_id, [1, 1])
    # According to equation (6) in the PopArt-IMPALA paper.
    # Match the specific game with its current vtrace-corrected value estimate.
    first_moment = tf.reshape(
        (1 - self._beta) * tf.gather(mean, _env_id) + self._beta * gvt, [1])
    second_moment = tf.reshape(
        (1 - self._beta) * tf.gather(mean_squared, _env_id) +
        self._beta * tf.square(gvt), [1])
    # Scatter the moments back at the environment's index, so only the
    # statistics for this specific game are updated.
    n_mean = tf.tensor_scatter_update(mean, _env_id, first_moment)
    n_mean_squared = tf.tensor_scatter_update(mean_squared, _env_id, second_moment)
    return n_mean, n_mean_squared
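# A minimal NumPy sketch of the same per-environment moment update (not from
# the source; `popart_moments` and the beta value are illustrative only): the
# first and second moments of the return estimate g are tracked per task, and
# the scale PopArt needs is recovered as sigma = sqrt(mean_sq - mean**2).
import numpy as np

def popart_moments(mean, mean_sq, g, env_id, beta=3e-4):
    mean = mean.copy()
    mean_sq = mean_sq.copy()
    mean[env_id] = (1 - beta) * mean[env_id] + beta * g
    mean_sq[env_id] = (1 - beta) * mean_sq[env_id] + beta * g ** 2
    sigma = np.sqrt(np.maximum(mean_sq - mean ** 2, 1e-6))
    return mean, mean_sq, sigma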
def test_tensor_update():
    zeros = tf.zeros([10, 300], tf.float32)
    # Update the 0-th row of features with all ones. Indices must be rank-2
    # ([num_updates, index_depth]) and the updates must match the sliced
    # shape; the op returns a new tensor rather than mutating `zeros`.
    update_op = tf.tensor_scatter_update(zeros, [[0]], [[1.0] * 300])
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)
        print(session.run(update_op))
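# A minimal eager-mode sketch (not part of the test above): in TF 2.x the op
# is exposed as tf.tensor_scatter_nd_update, and the result must be captured
# because the input tensor is never modified in place.
import tensorflow as tf

base = tf.zeros([10, 300], tf.float32)
updated = tf.tensor_scatter_nd_update(base, indices=[[0]],
                                      updates=tf.ones([1, 300]))
print(updated[0, :5])  # row 0 is now ones; `base` itself still holds zeros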
def _rejection_sample_eager(values, row_splits):
    assert tf.executing_eagerly()
    out = []
    N = tf.size(row_splits) - 1
    consumed = tf.zeros((N,), dtype=tf.bool)
    for i in range(N):
        if not consumed[i]:
            out.append(i)
            vals = tf.expand_dims(values[row_splits[i]:row_splits[i + 1]], axis=-1)
            consumed = tf.tensor_scatter_update(
                consumed, vals, tf.ones(shape=(tf.size(vals),), dtype=tf.bool))
    return tf.stack(out, axis=0)
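# Worked example for the helper above (hypothetical input, assuming TF 2.x
# eager execution): group 0 is visited first and kept, and its values mark
# groups 1 and 2 as consumed, so only index 0 survives.
values = tf.constant([1, 2, 0, 2, 1])
row_splits = tf.constant([0, 2, 4, 5])
print(_rejection_sample_eager(values, row_splits))  # expected: [0]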
def apply_masking(token_ids, target_token_ids, mask_indices, mask_token_id,
                  vocab_size):
    """Applies BERT masking.

    Args:
      token_ids (Tensor): 1-D Tensor of token IDs (ints)
      target_token_ids (Tensor): 1-D Tensor of token IDs (ints)
      mask_indices (Tensor): 1-D Tensor of indices (ints)
      mask_token_id (int): ID of [MASK] token.
      vocab_size (int): total size of vocabulary.

    Returns:
      token_ids_masked (Tensor): 1-D Tensor of token IDs, after target positions
        have been replaced with [MASK], a random token, or left alone.
    """
    num_to_mask = tf.size(mask_indices)
    mask_token_ids = tf.fill([num_to_mask], tf.cast(mask_token_id, tf.int64))
    random_token_ids = tf.random.uniform([num_to_mask],
                                         minval=0,
                                         maxval=vocab_size,
                                         dtype=tf.int64)
    # Uniform [0, 1) floats.
    randomness = tf.random.uniform([num_to_mask])
    # Replace 80% of the target tokens with the [MASK] token.
    mask_values = tf.where(randomness < 0.8, mask_token_ids, target_token_ids)
    # Replace 10% of the target tokens with random tokens; the remaining 10%
    # keep their original IDs.
    mask_values = tf.where(randomness > 0.9, random_token_ids, mask_values)
    # Write the chosen values back into token_ids at mask_indices.
    token_ids_masked = tf.tensor_scatter_update(token_ids,
                                                mask_indices[:, None],
                                                mask_values)
    return token_ids_masked
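# Small usage sketch (hypothetical IDs, assuming a TF version that still
# exposes tf.tensor_scatter_update): positions 1 and 3 are the masking
# targets; roughly 80% become [MASK], 10% a random token, 10% stay unchanged.
token_ids = tf.constant([101, 7592, 2088, 999, 102], dtype=tf.int64)
mask_indices = tf.constant([1, 3], dtype=tf.int64)
target_token_ids = tf.gather(token_ids, mask_indices)
masked = apply_masking(token_ids, target_token_ids, mask_indices,
                       mask_token_id=103, vocab_size=30522)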
def topk_sgd(W, k):
    """Top-K SGD sparsification with memory."""
    W_shape = W.get_shape().as_list()
    W_size = np.prod(W_shape)
    k = min(np.prod(W_shape), k)
    w = tf.reshape(W, shape=(-1,))
    residue = tf.Variable(tf.zeros(shape=(W_size,)), dtype=tf.float32,
                          trainable=False)
    x = w + residue
    _, indices = tf.math.top_k(tf.abs(x), k, sorted=False)
    # Zero out the top-k entries; whatever is left over becomes the new residue.
    new_residue = tf.tensor_scatter_update(x, tf.expand_dims(indices, 1),
                                           tf.zeros(k, tf.float32))
    xh = x - new_residue
    Wh = tf.reshape(xh, W_shape)
    with tf.control_dependencies([Wh, new_residue]):
        update_residue = residue.assign(new_residue)
    return Wh, update_residue, residue
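# Construction sketch (hypothetical shapes, TF 1.x graph mode as in the
# function above): only the k largest-magnitude entries of W plus the
# accumulated residue are kept in Wh; the remainder is remembered for later
# steps.
W = tf.Variable(tf.random.normal([256, 128]))
Wh, update_residue, residue = topk_sgd(W, k=1024)
# After variable initialization, sess.run([Wh, update_residue]) evaluates the
# sparsified tensor and refreshes the residue variable in one step.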
def _subsample_labels(tensor, size):
    indices = tf.slice(tf.where(tf.equal(tensor, 1)), [0, 0], [-1, 1])
    indices = random_choice(indices, size, replace=False)
    updates = tf.zeros([size, 1], tf.int32)
    return tf.tensor_scatter_update(tensor, indices, updates)
def _swap_dims(tensor, dim_1, dim_2):
    updates = tf.gather(tensor, indices=[dim_1, dim_2])
    return tf.tensor_scatter_update(tensor, indices=[[dim_2], [dim_1]],
                                    updates=updates)
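# Usage sketch (hypothetical input, assuming eager execution): swap rows 0
# and 2 of a small matrix.
t = tf.constant([[1, 2], [3, 4], [5, 6]])
print(_swap_dims(t, 0, 2))  # [[5, 6], [3, 4], [1, 2]]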
def draw(self, vertices, color=None):
    X1 = tf32(vertices[0].position.x)
    X2 = tf32(vertices[1].position.x)
    X3 = tf32(vertices[2].position.x)
    Y1 = tf32(vertices[0].position.y)
    Y2 = tf32(vertices[1].position.y)
    Y3 = tf32(vertices[2].position.y)
    U1 = tf32(vertices[0].uv.x)
    U2 = tf32(vertices[1].uv.x)
    U3 = tf32(vertices[2].uv.x)
    V1 = tf32(vertices[0].uv.y)
    V2 = tf32(vertices[1].uv.y)
    V3 = tf32(vertices[2].uv.y)
    W = ti32(self.width)
    H = ti32(self.height)

    # 28.4 fixed-point coordinates
    Y1 = fix28_4((0.5 * Y1 + 0.5) * tf32(H) - 0.5)
    Y2 = fix28_4((0.5 * Y2 + 0.5) * tf32(H) - 0.5)
    Y3 = fix28_4((0.5 * Y3 + 0.5) * tf32(H) - 0.5)
    X1 = fix28_4((0.5 * X1 + 0.5) * tf32(W) - 0.5)
    X2 = fix28_4((0.5 * X2 + 0.5) * tf32(W) - 0.5)
    X3 = fix28_4((0.5 * X3 + 0.5) * tf32(W) - 0.5)

    # Bounding rectangle
    # minx = main.rsh((main.min3(X1, X2, X3) + 0xF), 4)
    # maxx = main.rsh((main.max3(X1, X2, X3) + 0xF), 4)
    # miny = main.rsh((main.min3(Y1, Y2, Y3) + 0xF), 4)
    # maxy = main.rsh((main.max3(Y1, Y2, Y3) + 0xF), 4)
    # minx = main.max2(minx, 0)
    # maxx = main.min2(maxx, W)
    # miny = main.max2(miny, 0)
    # maxy = main.min2(maxy, H)

    # Calculate edges.
    EX1 = main.f32(X2 - X1) / 16.0
    EX2 = main.f32(X3 - X1) / 16.0
    EY1 = main.f32(Y2 - Y1) / 16.0
    EY2 = main.f32(Y3 - Y1) / 16.0

    # Deltas
    DX12 = X1 - X2
    DX23 = X2 - X3
    DX31 = X3 - X1
    DY12 = Y1 - Y2
    DY23 = Y2 - Y3
    DY31 = Y3 - Y1

    # Half-edge constants
    C1 = DY12 * X1 - DX12 * Y1
    C2 = DY23 * X2 - DX23 * Y2
    C3 = DY31 * X3 - DX31 * Y3

    # Check the "outside" triangle point against the first edge. If the
    # triangle is backfacing, this will reverse the face.
    flip = main.sign(C1 + DX12 * Y3 - DY12 * X3)
    DX12 *= flip
    DX23 *= flip
    DX31 *= flip
    DY12 *= flip
    DY23 *= flip
    DY31 *= flip
    C1 *= flip
    C2 *= flip
    C3 *= flip

    # Correct for fill convention
    C1 += main.i32(
        tf.logical_or(
            tf.less(DY12, 0),
            tf.logical_and(tf.equal(DY12, 0), tf.greater(DX12, 0))))
    C2 += main.i32(
        tf.logical_or(
            tf.less(DY23, 0),
            tf.logical_and(tf.equal(DY23, 0), tf.greater(DX23, 0))))
    C3 += main.i32(
        tf.logical_or(
            tf.less(DY31, 0),
            tf.logical_and(tf.equal(DY31, 0), tf.greater(DX31, 0))))

    # http://www.lysator.liu.se/~mikaelk/doc/perspectivetexture/dcdx.txt
    #
    #          (c3 - c1) * (y2 - y1) - (c2 - c1) * (y3 - y1)
    # dc/dx = ---------------------------------------------
    #          (x3 - x1) * (y2 - y1) - (x2 - x1) * (y3 - y1)
    #
    #          (c2 - c1) * (x3 - x1) - (c3 - c1) * (x2 - x1)
    # dc/dy = ---------------------------------------------
    #          (x3 - x1) * (y2 - y1) - (x2 - x1) * (y3 - y1)
    EU1 = U2 - U1
    EU2 = U3 - U1
    EV1 = V2 - V1
    EV2 = V3 - V1
    DUDX = EU2 * EY1 - EU1 * EY2
    DVDX = EV2 * EY1 - EV1 * EY2
    DUDY = EU1 * EX2 - EU2 * EX1
    DVDY = EV1 * EX2 - EV2 * EX1
    DETR = EX2 * EY1 - EX1 * EY2
    DUDX /= DETR
    DVDX /= DETR
    DUDY /= DETR
    DVDY /= DETR

    # DUDX = (U3 - U1) * (Y2 - Y1) - (U2 - U1) * (Y3 - Y1)
    # DVDX = (V3 - V1) * (Y2 - Y1) - (V2 - V1) * (Y3 - Y1)
    # DUDY = (U2 - U1) * (X3 - X1) - (U3 - U1) * (X2 - X1)
    # DVDY = (V2 - V1) * (X3 - X1) - (V3 - V1) * (X2 - X1)
    # DETR = (X3 - X1) * (Y2 - Y1) - (X2 - X1) * (Y3 - Y1)
    # DETR = main.f32(DETR)
    # DUDX /= DETR / 16.0
    # DVDX /= DETR / 16.0
    # DUDY /= DETR / 16.0
    # DVDY /= DETR / 16.0

    color = tf.fill([H, W, 4], 128.0) if color is None else color

    x, y = tf.meshgrid(tf.range(0, W), tf.range(0, H))
    p = tf.stack([y, x], 2)
    h, w, _ = p.shape.as_list()

    valid1 = C1 + DX12 * y * 16 - DY12 * x * 16
    valid2 = C2 + DX23 * y * 16 - DY23 * x * 16
    valid3 = C3 + DX31 * y * 16 - DY31 * x * 16
    # valid = tf.greater(main.min3(main.sign(valid1), main.sign(valid2), main.sign(valid3)), 0)
    valid = tf.equal(
        main.sign(valid1) + main.sign(valid2) + main.sign(valid3), 3)

    # Calculate UV at screen origin.
    U = U1 - DUDX * tf32(X1) / 16 - DUDY * tf32(Y1) / 16
    V = V1 - DVDX * tf32(X1) / 16 - DVDY * tf32(Y1) / 16

    u = tf.cumsum(tf.fill([h, w], DUDX), 1)
    u += tf.cumsum(tf.fill([h, w], DUDY), 0)
    u += U
    v = tf.cumsum(tf.fill([h, w], DVDX), 1)
    v += tf.cumsum(tf.fill([h, w], DVDY), 0)
    v += V
    uv = tf.stack([v, u], 2)
    uv = tf.boolean_mask(uv, valid)
    p = tf.boolean_mask(p, valid)

    tex0 = tf.cast(self.tex0, tf.float32)
    th, tw = self.tex0.shape[0:2]
    R, G, B = tf.unstack(
        tf.gather_nd(
            tex0,
            main.clamp(main.iround(main.wrap(uv, "reflect") * [tw, th]), 0,
                       [tw - 1, th - 1])),
        axis=-1)
    A = tf.ones_like(R) * 255
    frag_color = tf.stack([R, G, B, A], 1)
    color = tf.tensor_scatter_update(color, p, frag_color)
    color = main.clamp(color, 0.0, 255.0)
    return color
def score_cooccurance(tf_a1):
    import tensorflow as tf
    co_occurance_threshold = 2
    exclude_common_words = 1000
    N = 50

    p = (tf_a1 + tf.abs(tf_a1)) / 2
    input_tf = tf.concat([p, tf.zeros((1, p.shape[1]), p.dtype)], axis=0)
    tf_a2 = tf.sort(sent_wids, axis=1)
    first_col_change = tf.zeros([tf_a2.shape[0], 1], dtype=tf.int32)
    last_cols_change = tf.cast(tf.equal(tf_a2[:, 1:], tf_a2[:, :-1]), tf.int32)
    change_bool = tf.concat([first_col_change, last_cols_change], axis=-1)
    not_change_bool = 1 - change_bool
    tf_a2_changed = tf_a2 * not_change_bool + change_bool * N

    # This part only selects word indexes which are not very common.
    y, idx, count = tf.unique_with_counts(tf.reshape(tf_a2_changed, [-1]))
    count_mask = tf.reshape(tf.gather(count, idx), tf_a2_changed.shape)
    result = tf.where(
        tf.logical_and(tf.less(count_mask, exclude_common_words),
                       tf.not_equal(tf_a2_changed, N)), tf_a2_changed,
        tf.math.negative(tf.ones_like(tf_a2_changed)))

    idx = tf.where(
        tf.count_nonzero(tf.gather(input_tf, result, axis=0), axis=1) >=
        co_occurance_threshold)
    y, x = idx[:, 0], idx[:, 1]
    rows_tf = tf.gather(result, y, axis=0)
    columns_tf = tf.cast(x[:, None], tf.int32)
    out = tf.zeros(shape=tf.shape(p), dtype=tf.float32)
    rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
    columns_tf = tf.reshape(tf.tile(columns_tf,
                                    multiples=[1, tf.shape(result)[1]]),
                            shape=[-1, 1])
    sparse_indices = tf.reshape(tf.concat([rows_tf, columns_tf], axis=-1),
                                shape=[-1, 2])
    v = tf.gather_nd(input_tf, sparse_indices)
    v = tf.reshape(v, [-1, tf.shape(result)[1]])
    p_good_rows = tf.tensor_scatter_update(out,
                                           tf.cast(sparse_indices, tf.int32),
                                           tf.reshape(v, shape=[-1]))
    p_sum_not_in_goodrows = tf.reduce_sum(p - p_good_rows)
    number_of_good_items = tf.where(p_good_rows)
    enegy = p_sum_not_in_goodrows / tf.cast(
        tf.shape(number_of_good_items)[0] + 1, dtype=tf.float32)
    energy_matrice = tf.scatter_nd(
        tf.cast(number_of_good_items, tf.int32),
        enegy * tf.ones(shape=(tf.shape(number_of_good_items)[0])),
        shape=tf.shape(p))
    result_p = p_good_rows + energy_matrice

    # The same procedure for the negative weights.
    n = (tf_a1 - tf.abs(tf_a1)) / 2
    input_tf_n = tf.concat([n, tf.zeros((1, n.shape[1]), n.dtype)], axis=0)
    idx_n = tf.where(
        tf.count_nonzero(tf.gather(input_tf_n, result, axis=0), axis=1) >=
        co_occurance_threshold)
    y_n, x_n = idx_n[:, 0], idx_n[:, 1]
    rows_tf_n = tf.gather(result, y_n, axis=0)
    columns_tf_n = tf.cast(x_n[:, None], tf.int32)
    out_n = tf.zeros(shape=tf.shape(n), dtype=tf.float32)
    rows_tf_n = tf.reshape(rows_tf_n, shape=[-1, 1])
    columns_tf_n = tf.reshape(tf.tile(columns_tf_n,
                                      multiples=[1, tf.shape(result)[1]]),
                              shape=[-1, 1])
    sparse_indices_n = tf.reshape(tf.concat([rows_tf_n, columns_tf_n], axis=-1),
                                  shape=[-1, 2])
    v_n = tf.gather_nd(input_tf_n, sparse_indices_n)
    v_n = tf.reshape(v_n, [-1, tf.shape(result)[1]])
    n_good_rows = tf.tensor_scatter_update(out_n,
                                           tf.cast(sparse_indices_n, tf.int32),
                                           tf.reshape(v_n, shape=[-1]))
    n_sum_not_in_goodrows = tf.reduce_sum(n - n_good_rows)
    n_number_of_good_items = tf.where(n_good_rows)
    enegy_n = n_sum_not_in_goodrows / tf.cast(
        tf.shape(n_number_of_good_items)[0] + 1, dtype=tf.float32)
    energy_matrice_n = tf.scatter_nd(
        tf.cast(n_number_of_good_items, tf.int32),
        enegy_n * tf.ones(shape=(tf.shape(n_number_of_good_items)[0])),
        shape=tf.shape(n))
    result_n = n_good_rows + energy_matrice_n

    res = result_p + result_n
    return res
def _grow_alive_seq(self, state):
    """Grow alive sequences by one token, and collect top 2*beam_size sequences.

    2*beam_size sequences are collected because some sequences may have reached
    the EOS token. 2*beam_size ensures that at least beam_size sequences are
    still alive.

    Args:
      state: A dictionary with the current loop state.

    Returns:
      Tuple of
      (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1],
       Scores of returned sequences [batch_size, 2 * beam_size],
       New alive cache, for each of the 2 * beam_size sequences)
    """
    i = state[_StateKeys.CUR_INDEX]
    alive_seq = state[_StateKeys.ALIVE_SEQ]
    alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
    alive_cache = state[_StateKeys.ALIVE_CACHE]

    beams_to_keep = 2 * self.beam_size

    # Get logits for the next candidate IDs for the alive sequences. Get the
    # new cache values at the same time.
    if self.padded_decode:
        flat_ids = tf.reshape(
            tf.slice(alive_seq, [0, 0, i],
                     [self.batch_size, self.beam_size, 1]),
            [self.batch_size * self.beam_size, -1])
    else:
        flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
    flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)

    flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)

    # Unflatten logits to shape [batch_size, beam_size, vocab_size]
    logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size)
    new_cache = nest.map_structure(
        lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size),
        flat_cache)

    # Convert logits to normalized log probs
    candidate_log_probs = _log_prob_from_logits(logits)

    # Calculate new log probabilities if each of the alive sequences were
    # extended by the candidate IDs.
    # Shape [batch_size, beam_size, vocab_size]
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    # Each batch item has beam_size * vocab_size candidate sequences. For each
    # batch item, get the k candidates with the highest log probabilities.
    flat_log_probs = tf.reshape(log_probs,
                                [-1, self.beam_size * self.vocab_size])
    topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

    # Extract the alive sequences that generate the highest log probabilities
    # after being extended.
    topk_beam_indices = topk_indices // self.vocab_size
    topk_seq, new_cache = _gather_beams([alive_seq, new_cache],
                                        topk_beam_indices, self.batch_size,
                                        beams_to_keep)

    # Append the most probable IDs to the topk sequences
    topk_ids = topk_indices % self.vocab_size
    if self.padded_decode:
        topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
        topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
        topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
    else:
        topk_ids = tf.expand_dims(topk_ids, axis=2)
        topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
    return topk_seq, topk_log_probs, new_cache
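# Tiny illustration (hypothetical numbers, not from the model) of how the
# flattened top-k indices above decompose into a beam index and a token ID.
vocab_size = 1000
flat_index = 2513
print(flat_index // vocab_size, flat_index % vocab_size)  # beam 2, token 513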
def generate_proposal_target(rpn_rois, rpn_scores, tgt_boxes, num_classes,
                             scope=None):
    """
    Generate target proposals
    :param rpn_rois: Regions of interest produced by the RPN
    :param rpn_scores: Scores of the RoIs
    :param tgt_boxes: Target (ground-truth) bounding boxes
    :param num_classes: Number of classes for classification
    :param scope
    :return: labels, rois, roi_scores, bbox_targets, bbox_inside_weights,
        bbox_outside_weights
    """
    check_invoke()
    global params
    with tf.name_scope(scope, "gen_proposal_tgt"):
        all_rois = rpn_rois
        all_scores = rpn_scores
        if params.use_tgt:
            # Include the ground-truth boxes among the candidate RoIs
            # (batch index 0 is prepended to match the RoI layout).
            all_rois = tf.concat([
                all_rois,
                tf.concat(
                    [tf.zeros([tf.shape(tgt_boxes)[0], 1]), tgt_boxes[:, :-1]],
                    axis=1)
            ], axis=0)
        overlaps = boxes_iou(tf.slice(all_rois, [0, 1], [-1, -1]),
                             tgt_boxes[:, :-1])
        max_indices = tf.squeeze(
            tf.argmax(overlaps, axis=0, output_type=tf.int32))
        max_overlaps = tf.gather_nd(
            params=tf.transpose(overlaps, perm=[1, 0, 2]),
            indices=tf.concat([
                tf.expand_dims(
                    tf.range(tf.shape(max_indices)[0], dtype=tf.int32), 1),
                tf.expand_dims(max_indices, 1)
            ], 1))
        labels = tf.gather(tgt_boxes, max_indices)[:, 4]

        fg_thresh = params.fg_thresh
        bg_thresh_range = params.bg_thresh_range
        fg_masks = tf.cast(tf.greater_equal(max_overlaps, fg_thresh), tf.int32)
        bg_masks = tf.cast(
            tf.greater_equal(max_overlaps, bg_thresh_range[0])
            & tf.less(max_overlaps, bg_thresh_range[1]), tf.int32)
        num_fg = tf.reduce_sum(fg_masks)
        num_bg = tf.reduce_sum(bg_masks)

        rois_per_image = params.roi_batch_size / params.image_batch_size
        fg_rois_per_image = fg_thresh * rois_per_image
        rois_per_image = tf.cast(rois_per_image, tf.int32)
        fg_rois_per_image = tf.cast(fg_rois_per_image, tf.int32)
        fg_rois_per_image = tf.minimum(num_fg, fg_rois_per_image)
        bg_rois_per_image = num_bg

        to_replace, fg_rois_per_image, bg_rois_per_image = tf.case(
            {
                tf.greater(fg_rois_per_image, 0) & tf.equal(bg_rois_per_image, 0):
                    lambda: [
                        tf.logical_not(
                            tf.greater(rois_per_image, fg_rois_per_image)),
                        rois_per_image, 0
                    ],
                tf.equal(fg_rois_per_image, 0) & tf.greater(bg_rois_per_image, 0):
                    lambda: [
                        tf.greater(rois_per_image, bg_rois_per_image), 0,
                        rois_per_image
                    ],
                tf.greater(fg_rois_per_image, 0) & tf.greater(bg_rois_per_image, 0):
                    lambda: [
                        tf.greater(rois_per_image - fg_rois_per_image, num_bg),
                        fg_rois_per_image, rois_per_image - fg_rois_per_image
                    ]
            },
            default=lambda: [tf.cast(False, tf.bool), 0, 0],
            exclusive=True)

        fg_indices = _sample_indices(fg_masks, fg_rois_per_image,
                                     tf.logical_not(to_replace))
        bg_indices = _sample_indices(bg_masks, bg_rois_per_image, to_replace)
        keep_indices = tf.concat([fg_indices, bg_indices], axis=0)

        fg_labels = tf.to_int32(tf.gather(labels, indices=fg_indices))
        labels = tf.tensor_scatter_update(
            tensor=tf.zeros([rois_per_image], dtype=tf.int32),
            indices=tf.expand_dims(tf.range(fg_rois_per_image), 1),
            updates=fg_labels)
        """
        labels = tf.scatter_nd(indices=tf.expand_dims(tf.range(fg_rois_per_image), 1),
                               updates=fg_labels, shape=[rois_per_image])
        """
        rois = tf.gather(all_rois, indices=keep_indices)
        roi_scores = tf.gather(all_scores, keep_indices)

        bbox_target_data = _compute_target(
            ex_rois=rois[:, 1:5],
            tgt_rois=tf.gather(tgt_boxes,
                               tf.gather(max_indices, keep_indices))[:, :4],
            means=params.bbox_norm_means,
            stddevs=params.bbox_norm_stddevs,
            pre_norm=params.bbox_pre_norm,
            labels=labels)
        bbox_targets, bbox_inside_weights = _get_bbox_regression_labels(
            bbox_target_data, num_classes, params.bbox_in_weights)

        bbox_inside_weights = tf.reshape(bbox_inside_weights,
                                         [-1, 4 * num_classes])
        bbox_outside_weights = tf.cast(tf.greater(bbox_inside_weights, 0),
                                       tf.float32)
        bbox_targets = tf.reshape(bbox_targets, [-1, 4 * num_classes])
        rois = tf.reshape(rois, [-1, 5])
        roi_scores = tf.reshape(roi_scores, [-1])
        return (tf.to_int64(labels), rois, roi_scores, bbox_targets,
                bbox_inside_weights, bbox_outside_weights)
padded_real = tf.pad(tf.real(tf_obj_1),
                     [[p[0], n2 - (n1 + p[0])],
                      [p[1], n2 - (n1 + p[1])]],
                     constant_values=1)
padded_imag = tf.pad(tf.imag(tf_obj_1),
                     [[p[0], n2 - (n1 + p[0])],
                      [p[1], n2 - (n1 + p[1])]],
                     constant_values=0)
tf_obj_real_pads.append(padded_real)
tf_obj_imag_pads.append(padded_imag)

tf_obj_real_pads = tf.stack(tf_obj_real_pads)
tf_obj_imag_pads = tf.stack(tf_obj_imag_pads)
tf_obj_pads = tf.complex(tf_obj_real_pads, tf_obj_imag_pads)

batch_obj_views = tf.gather(tf_obj_pads, batch_indices)
batch_view_indices = tf.gather(tf_view_indices, batch_indices)
gen_view_fn = lambda view: tf.tensor_scatter_update(
    tensor=view[0],
    indices=tf.reshape(view[1], [-1, 1]),
    updates=tf.reshape(tf_obj_1, [-1]))
batch_obj_views = tf.map_fn(fn=gen_view_fn,
                            elems=[batch_views, batch_view_indices],
                            dtype='complex64')
batch_obj_views = tf.reshape(batch_obj_views, [batch_size, *probe_init.shape])

batch_view_indices = tf.gather(tf_view_indices, batch_indices)
gen_view_real_fn = lambda view_indices: tf.scatter_nd(
    tf.reshape(view_indices, [-1, 1]),
    tf.reshape(tf_obj_real_pad - 1, [-1]),
    [probe_init.size]) + 1
gen_view_imag_fn = lambda view_indices: tf.scatter_nd(
    tf.reshape(view_indices, [-1, 1]),
    tf.reshape(tf_obj_imag_pad, [-1]),
    [probe_init.size])
batch_obj_real_views = tf.map_fn(gen_view_real_fn, batch_view_indices,
                                 dtype=tf.float32)
batch_obj_imag_views = tf.map_fn(gen_view_imag_fn, batch_view_indices,
                                 dtype=tf.float32)