def get_random_scale(min_scale_factor, max_scale_factor, step_size):
    """Gets a random scale value.

    Args:
      min_scale_factor: Minimum scale value.
      max_scale_factor: Maximum scale value.
      step_size: The step size from minimum to maximum value. 0 means the
        value is sampled continuously.

    Returns:
      A random scale value selected between minimum and maximum value. Note
      the shape differs by branch: a scalar when min == max or when
      step_size != 0, but a shape-[1] tensor when step_size == 0.

    Raises:
      ValueError: min_scale_factor has unexpected value.
    """
    if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
        raise ValueError('Unexpected value of min_scale_factor.')

    if min_scale_factor == max_scale_factor:
        return tf.to_float(min_scale_factor)

    # When step_size = 0, we sample the value uniformly from [min, max).
    if step_size == 0:
        return tf.random_uniform([1],
                                 minval=min_scale_factor,
                                 maxval=max_scale_factor)

    # When step_size != 0, we randomly select one discrete value from [min, max].
    # The small tolerance keeps floating-point error from truncating away the
    # last grid point: e.g. (1.2 - 0.5) / 0.1 == 6.999999999999999, which
    # plain int() would turn into 7 steps instead of 8.
    num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1 + 1e-6)
    scale_factors = tf.lin_space(min_scale_factor, max_scale_factor, num_steps)
    # Shuffle and take the first element: a uniformly random grid point.
    shuffled_scale_factors = tf.random_shuffle(scale_factors)
    return shuffled_scale_factors[0]
def subsample_indicator(indicator, num_samples):
    """Randomly keeps at most `num_samples` of the `True` entries of `indicator`.

    Of the M positions where `indicator` is `True`, a random subset of size
    min(M, num_samples) stays `True`; every other position becomes `False`.
    When `num_samples` >= M, the result equals the input indicator.

    Args:
      indicator: a 1-dimensional boolean tensor indicating which elements
        are allowed to be sampled and which are not.
      num_samples: int32 scalar tensor

    Returns:
      a boolean tensor with the same shape as input (indicator) tensor
    """
    # Candidate positions in random order, flattened to a 1-D vector.
    candidate_positions = tf.reshape(
        tf.random_shuffle(tf.where(indicator)), [-1])
    keep_count = tf.minimum(tf.size(candidate_positions), num_samples)
    kept_positions = tf.slice(
        candidate_positions, [0], tf.reshape(keep_count, [1]))
    dense_selection = ops.indices_to_dense_vector(
        kept_positions, tf.shape(indicator)[0])
    return tf.equal(dense_selection, 1)
def _call(self, inputs):
    """Samples `num_samples` neighbour columns for each id.

    Args:
      inputs: a pair `(ids, num_samples)`.

    Returns:
      The rows of `self.adj_info` looked up for `ids`, with their columns
      randomly permuted (one permutation shared by all rows) and truncated
      to the first `num_samples` columns.
    """
    ids, num_samples = inputs
    print("Uniform Neighbour sampler {}".format(num_samples))
    neighbours = tf.nn.embedding_lookup(self.adj_info, ids)
    # tf.random_shuffle permutes axis 0, so transpose around it to permute
    # the column axis instead.
    neighbours_t = tf.transpose(neighbours)
    neighbours = tf.transpose(tf.random_shuffle(neighbours_t))
    return tf.slice(neighbours, [0, 0], [-1, num_samples])
def get_random_box_center(boxes, image_height, image_width):
    """Picks one box at random and returns its clipped center and its width.

    Arguments:
        boxes: a float tensor with shape [num_persons, 4].
        image_height, image_width: float tensors with shape [].
    Returns:
        box_center: a float tensor with shape [1, 2].
        box_width: a float tensor with shape [].
    """
    # Shuffling and taking row 0 yields a uniformly random box of shape [4].
    ymin, xmin, ymax, xmax = tf.unstack(tf.random_shuffle(boxes)[0])
    box_height = ymax - ymin
    box_width = xmax - xmin

    center_y = ymin + 0.5 * box_height
    center_x = xmin + 0.5 * box_width

    # Rotation happens around this point, so keep it away from the image
    # border.
    center_y = tf.clip_by_value(
        center_y, 0.25 * image_height, 0.75 * image_height)
    center_x = tf.clip_by_value(
        center_x, 0.2 * image_width, 0.8 * image_width)

    box_center = tf.reshape(tf.stack([center_y, center_x]), [1, 2])
    return box_center, box_width
def _call(self, inputs):
    """Returns up to `num_samples` randomly ordered neighbours per id.

    Args:
      inputs: a pair `(ids, num_samples)`.

    Returns:
      `self.adj_info` rows for `ids`, columns permuted by a single random
      permutation shared across rows, sliced to `num_samples` columns.
    """
    ids, num_samples = inputs
    rows = tf.nn.embedding_lookup(self.adj_info, ids)
    # tf.random_shuffle only permutes axis 0; the surrounding transposes make
    # it permute the neighbour (column) axis.
    shuffled_rows = tf.transpose(tf.random_shuffle(tf.transpose(rows)))
    return tf.slice(shuffled_rows, [0, 0], [-1, num_samples])
def disable_some_fgs():
    """Negates the labels of a random surplus of foreground proposals.

    Randomly chooses `tf.shape(fg_inds)[0] - max_fg` entries of `fg_inds`
    and replaces the corresponding entries of `proposals_label` with their
    negation. Reads `fg_inds`, `max_fg`, `proposals_label`,
    `proposals_label_shape` and `self._seed` from the enclosing scope.
    """
    # Shuffle along dimension 0; the first `surplus` indices get disabled.
    permuted_fg = tf.random_shuffle(fg_inds, seed=self._seed)
    surplus = (tf.shape(fg_inds)[0] - max_fg)
    # This should only run when there are more than `max_fg` foregrounds,
    # so fail loudly if that precondition is violated.
    surplus_check = tf.assert_positive(
        surplus,
        message="disable_place in disable_some_fgs is negative.")
    with tf.control_dependencies([surplus_check]):
        to_disable = permuted_fg[:surplus]
    disabled_mask = tf.sparse_to_dense(
        sparse_indices=to_disable,
        sparse_values=True,
        default_value=False,
        output_shape=tf.cast(proposals_label_shape, tf.int64),
        # The indices come from a shuffle, so they are not sorted.
        validate_indices=False)
    # The negated label (rather than a fixed marker) aids debugging.
    return tf.where(
        condition=disabled_mask,
        x=tf.negative(proposals_label),
        y=proposals_label)
def MiNetwork(x_in, y_in):
    """Builds a statistics network scoring joint and shuffled (x, y) pairs.

    The batch is doubled: the first half pairs each `x_in` with its own
    `y_in`; the second half pairs it with `y_in` shuffled across the batch.
    Both halves pass through the same conv/dense towers. (Presumably used for
    a MINE-style mutual-information estimate — confirm against callers.)
    Layers use fixed names ('M_0'..'M_9'), so building this twice in the
    same variable scope needs variable reuse set up by the caller.

    Args:
      x_in: input tensor fed to the x tower (must be valid conv2d input).
      y_in: input tensor fed to the y tower (must be valid conv2d input).

    Returns:
      T_xy: network outputs for the matched (x, y) pairs.
      T_x_y: network outputs for the pairs with y shuffled.
    """
    H = 10
    # tf.random_shuffle expects an integer op seed. np.random.randint with a
    # size argument returns a length-1 ndarray, which only worked through an
    # implicit (deprecated) ndarray -> int conversion; take a Python int.
    seed = int(np.random.randint(0, 1000, 1)[0])
    y_shuffle = tf.gather(
        y_in, tf.random_shuffle(tf.range(tf.shape(y_in)[0]), seed=seed))

    x_conc = tf.concat([x_in, x_in], axis=0)
    y_conc = tf.concat([y_in, y_shuffle], axis=0)

    # propagate the forward pass
    layerx = tf.layers.conv2d(x_conc, 16, 2, (1, 1), use_bias=True, name='M_0')
    layerx = tf.layers.flatten(layerx, name='M_1')
    layerx = tf.layers.dense(layerx, 512, name='M_2', use_bias=True)
    layerx = tf.layers.dense(layerx, H, name='M_3', use_bias=True)
    #========================================
    layery = tf.layers.conv2d(y_conc, 16, 2, (1, 1), name='M_4', use_bias=True)
    layery = tf.layers.flatten(layery, name='M_5')
    layery = tf.layers.dense(layery, 512, name='M_6', use_bias=True)
    layery = tf.layers.dense(layery, H, name='M_7', use_bias=True)
    layer2 = tf.nn.relu(layerx + layery, name='M_8')
    output = tf.layers.dense(layer2, 1, name='M_9', use_bias=False)

    # split in T_xy and T_x_y predictions
    N_samples = tf.shape(x_in)[0]
    T_xy = output[:N_samples]
    T_x_y = output[N_samples:]
    return T_xy, T_x_y
def sample_msa(protein, max_seq, keep_extra):
    """Randomly subsamples MSA rows, optionally keeping the rest as `extra_*`.

    Row 0 always stays first; rows 1..N-1 are shuffled and the first
    `min(max_seq, N)` rows of that order are kept. The `protein` dict is
    modified in place and also returned.

    Args:
      protein: batch to sample msa from.
      max_seq: number of sequences to sample.
      keep_extra: When True sequences not sampled are put into fields
        starting with 'extra_*'.

    Returns:
      Protein with sampled msa.
    """
    total = tf.shape(protein['msa'])[0]
    others = tf.random_shuffle(tf.range(1, total))
    order = tf.concat([[0], others], axis=0)
    keep = tf.minimum(max_seq, total)
    selected, unselected = tf.split(order, [keep, total - keep])

    for key in _MSA_FEATURE_NAMES:
        if key not in protein:
            continue
        if keep_extra:
            protein['extra_' + key] = tf.gather(protein[key], unselected)
        protein[key] = tf.gather(protein[key], selected)
    return protein
def random_pivots(
    inputs, num_pivots, sample_size=SAMPLE_SIZE, seed=None, scope=None
):
    """Pivots initialization function.

    Each initial pivot is the mean of `sample_size` randomly sampled rows of
    `inputs`.

    Args:
      inputs: A float32 `Tensor` of training data.
      num_pivots: Number of pivots to initialize.
      sample_size: Number of points to sample for each pivot.
      seed: Optional random seed.
      scope: Optional variable scope.

    Returns:
      float32 `Tensor` of shape [num_pivots, D], where D is the product of
      the non-leading dimensions of `inputs` ([num_pivots, 1] for 1-D inputs).
    """
    with tf.variable_scope(scope, 'RandomPivots', [inputs]):
        num_inputs = tf.shape(inputs)[0]
        num_draws = num_pivots * sample_size
        # Repeat every index `num_copies` times before shuffling so that the
        # first `num_draws` shuffled entries always exist. A fixed count of
        # `num_pivots` copies is only enough when num_inputs >= sample_size;
        # otherwise tf.gather below goes out of range (an error on CPU,
        # silent zeros on GPU). Taking the max keeps the original sampling
        # whenever that count was already sufficient.
        num_copies = tf.maximum(
            num_pivots, (num_draws + num_inputs - 1) // num_inputs)
        indices = tf.tile(tf.range(num_inputs), tf.reshape(num_copies, [1]))
        indices = tf.random_shuffle(indices, seed=seed)
        indices = tf.gather(indices, tf.range(num_draws))
        samples = tf.gather(inputs, indices)
        # Row i * num_pivots + j of `samples` contributes to pivot j.
        samples = tf.reshape(samples, [sample_size, num_pivots, -1])
        pivots = tf.reduce_sum(samples, 0) / sample_size
        return pivots
def scheduled_sample_count(ground_truth_x,
                           generated_x,
                           batch_size,
                           scheduled_sample_var):
    """Builds a batch mixing ground-truth and generated examples.

    A random permutation of the batch positions is split: the first
    `scheduled_sample_var` positions take the ground-truth example at that
    position, the remaining positions take the generated one.

    Args:
      ground_truth_x: tensor of ground-truth data points.
      generated_x: tensor of generated data points.
      batch_size: batch size
      scheduled_sample_var: number of ground-truth examples to include in batch.

    Returns:
      New batch with `scheduled_sample_var` examples from ground_truth_x and
      the rest from generated_x.
    """
    perm = tf.random_shuffle(tf.range(batch_size))
    gt_positions = tf.gather(perm, tf.range(scheduled_sample_var))
    gen_positions = tf.gather(perm, tf.range(scheduled_sample_var, batch_size))
    gt_rows = tf.gather(ground_truth_x, gt_positions)
    gen_rows = tf.gather(generated_x, gen_positions)
    # Write every chosen row back to its own batch position.
    mixed = tf.dynamic_stitch([gt_positions, gen_positions],
                              [gt_rows, gen_rows])
    # A static Python batch size can be recorded on the output shape.
    if isinstance(batch_size, int):
        mixed.set_shape([batch_size] + common.shape_list(mixed)[1:])
    return mixed
def _face_model_map_fn(example):
    """Prepares one example dict for the face model.

    Optionally shifts the vertices and/or shuffles them (re-indexing faces to
    match), dequantizes the vertices to floats and adds all-ones masks.
    Reads `apply_random_shift`, `random_shift`, `shuffle_vertices`,
    `modules` and `quantization_bits` from the enclosing scope.
    """
    verts = example['vertices']

    if apply_random_shift:
        verts = random_shift(verts)
    example['num_vertices'] = tf.shape(verts)[0]

    if shuffle_vertices:
        perm = tf.random_shuffle(tf.range(example['num_vertices']))
        verts = tf.gather(verts, perm)
        # Face entries index vertices with an offset of 2; values 0 and 1 map
        # to themselves (presumably special tokens - confirm). argsort of a
        # permutation is its inverse, so old vertex ids map to new ones.
        face_perm = tf.concat(
            [tf.constant([0, 1], dtype=tf.int32), tf.argsort(perm) + 2],
            axis=0)
        example['faces'] = tf.cast(
            tf.gather(face_perm, example['faces']), tf.int64)

    # Quantized vertices -> floats for the face model.
    example['vertices'] = modules.dequantize_verts(verts, quantization_bits)
    example['vertices_mask'] = tf.ones_like(
        example['vertices'][..., 0], dtype=tf.float32)
    example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
    return example
def next_production_rule_info_batch_text_summary(expression_strings,
                                                 partial_sequences,
                                                 partial_sequence_lengths,
                                                 next_production_rules,
                                                 unmasked_probabilities_batch,
                                                 masked_probabilities_batch,
                                                 grammar,
                                                 target_length=None):
    """Creates a text summary for a batch of next production rule predictions.

    Args:
      expression_strings: String tensor with shape [batch_size].
      partial_sequences: Integer tensor with shape [batch_size, max_length].
      partial_sequence_lengths: Integer tensor with shape [batch_size].
      next_production_rules: Integer tensor with shape [batch_size]. The
        indices of the next production rules.
      unmasked_probabilities_batch: Float tensor with shape
        [batch_size, num_production_rules]. The probabilities from the model
        prediction without valid production rule mask.
      masked_probabilities_batch: Tensor with shape
        [batch_size, num_production_rules]. The probabilities from the model
        prediction after the valid production rule mask is applied.
      grammar: arithmetic_grammar.Grammar object.
      target_length: Integer. Only examples with partial sequence length
        equal to target_length will be used. If None (the default), all
        examples in batch will be used.

    Returns:
      summary: String Tensor containing a Summary proto. It shows up to 10
        entries of the streamed values, in random order.
      update_op: Op that updates summary (and the underlying stream).
    """
    if target_length is not None:
        (expression_strings, partial_sequences, partial_sequence_lengths,
         next_production_rules, unmasked_probabilities_batch,
         masked_probabilities_batch) = mask_by_partial_sequence_length(
             tensors=(expression_strings, partial_sequences,
                      partial_sequence_lengths, next_production_rules,
                      unmasked_probabilities_batch,
                      masked_probabilities_batch),
             partial_sequence_lengths=partial_sequence_lengths,
             target_length=target_length)
        suffix = '/length_%d' % target_length
    else:
        suffix = ''

    # Format each example as a string in Python (outside the graph).
    info = tf.py_func(
        functools.partial(next_production_rule_info_batch, grammar=grammar),
        [
            expression_strings, partial_sequences, partial_sequence_lengths,
            next_production_rules, unmasked_probabilities_batch,
            masked_probabilities_batch
        ],
        tf.string,
        name='py_func-next_production_rule_info_batch_text_summary' + suffix)
    info.set_shape([expression_strings.shape[0]])
    value, update_op = contrib_metrics.streaming_concat(info)
    value = tf.random_shuffle(value)  # So we see different summaries.
    summary = tf.summary.text('next_production_rule_info' + suffix, value[:10])
    return summary, update_op
def fit(self, data_set: DataSet):
    """Trains the MLP on `data_set` and saves a timestamped checkpoint.

    The whole data set is held in non-trainable graph variables. Each of
    `self._num_epochs` iterations runs one `tf_train_op` step over it,
    reshuffled each step when `self._shuffle_training` is set. Variables
    under the "mlp" scope are saved to
    `self._checkpoint_dir / model_<timestamp>` (no meta graph), and that path
    is stored in `self._latest_checkpoint`.
    """
    # generic parameter checks
    super().fit(data_set)

    self._num_labels = len(data_set.label_map)

    graph = tf.Graph()
    with graph.as_default():
        tf_inputs = tf.Variable(initial_value=data_set.features,
                                trainable=False, dtype=tf.float32)
        tf_labels = tf.Variable(initial_value=data_set.labels_numeric,
                                trainable=False, dtype=tf.int32)

        if self._shuffle_training:
            # Draw ONE permutation and apply it to both tensors. Two separate
            # tf.random_shuffle ops with equal seeds only stay aligned through
            # a kernel implementation detail; gathering with a shared index
            # vector guarantees features and labels stay paired.
            permutation = tf.random_shuffle(
                tf.range(tf.shape(tf_inputs)[0]), seed=42)
            tf_inputs = tf.gather(tf_inputs, permutation)
            tf_labels = tf.gather(tf_labels, permutation)

        with tf.variable_scope("mlp"):
            tf_logits = self._model.inference(tf_inputs, self._keep_prob,
                                              self._num_labels)
            tf_loss = self._model.loss(tf_logits, tf_labels)
            tf_train_op = self._model.optimize(tf_loss, self._learning_rate)

        tf_init_op = tf.global_variables_initializer()
        tf_saver = tf.train.Saver(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="mlp"))

    # The context manager closes the session even if training/saving raises.
    with tf.Session(graph=graph) as session:
        session.run(tf_init_op)

        for _ in range(self._num_epochs):
            session.run(tf_train_op)

        # timestamped model file
        self._latest_checkpoint = (
            self._checkpoint_dir
            / "model_{:%Y%m%d%H%M%S%f}".format(datetime.datetime.now()))
        tf_saver.save(session, str(self._latest_checkpoint),
                      write_meta_graph=False)
def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0):
    """Applies brightness/contrast/saturation/hue jitter in a random order.

    Args:
      image: The input image tensor.
      brightness: A float, specifying the brightness for color jitter.
      contrast: A float, specifying the contrast for color jitter.
      saturation: A float, specifying the saturation for color jitter.
      hue: A float, specifying the hue for color jitter.

    Returns:
      The distorted image tensor, clipped to [0, 1] after every transform.
    """
    with tf.name_scope('distort_color'):

        def apply_transform(i, x):
            """Applies transform i (0: brightness, 1: contrast, 2: saturation,
            3: hue) to x. A strength of 0 leaves x unchanged."""

            def brightness_foo():
                return x if brightness == 0 else tf.image.random_brightness(
                    x, max_delta=brightness)

            def contrast_foo():
                return x if contrast == 0 else tf.image.random_contrast(
                    x, lower=1 - contrast, upper=1 + contrast)

            def saturation_foo():
                return x if saturation == 0 else tf.image.random_saturation(
                    x, lower=1 - saturation, upper=1 + saturation)

            def hue_foo():
                return x if hue == 0 else tf.image.random_hue(
                    x, max_delta=hue)

            def first_pair():
                return tf.cond(tf.less(i, 1), brightness_foo, contrast_foo)

            def second_pair():
                return tf.cond(tf.less(i, 3), saturation_foo, hue_foo)

            # Two-level binary dispatch on the runtime index i.
            return tf.cond(tf.less(i, 2), first_pair, second_pair)

        order = tf.random_shuffle(tf.range(4))
        for step in range(4):
            image = apply_transform(order[step], image)
            image = tf.clip_by_value(image, 0., 1.)
        return image
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keeps at most `max_extra_msa` rows of each `extra_*` MSA feature.

    The same random selection of rows (in shuffled order) is applied to every
    `extra_*` feature present; `protein` is modified in place and returned.
    """
    n_rows = tf.shape(protein['extra_msa'])[0]
    n_keep = tf.minimum(max_extra_msa, n_rows)
    keep_rows = tf.random_shuffle(tf.range(0, n_rows))[:n_keep]
    for name in _MSA_FEATURE_NAMES:
        extra_name = 'extra_' + name
        if extra_name in protein:
            protein[extra_name] = tf.gather(protein[extra_name], keep_rows)
    return protein
def generate_curves(self):
    """Builds the ops for one batch of context/target curve samples.

    Training: `self._max_num_context` context points followed by
    `self._max_num_context` additional target points, with x drawn uniformly
    from [-2, 2); the targets include the context points.
    Testing: x is an evenly spaced grid on [-2, 2) with step 0.01; the
    context is `2 * self._max_num_context` points drawn at random from the
    first 400 grid points.

    Returns:
      A `CNPRegressionDescription` with
      query=((context_x, context_y), target_x), target_y,
      num_total_points=tf.shape(target_x)[1] and num_context_points.
    """
    # Point counts are deterministic. Random counts drawn with
    # tf.random_uniform were created here and immediately overwritten, so
    # those unused random ops are no longer built.
    num_context = self._max_num_context

    # If we are testing we want to have more targets and have them evenly
    # distributed in order to plot the function.
    if self._testing:
        num_context = self._max_num_context * 2
        num_target = 400
        num_total_points = num_target
        x_values = tf.tile(
            tf.expand_dims(tf.range(-2., 2., 1. / 100, dtype=tf.float32),
                           axis=0),
            [self._batch_size, 1])
        x_values = tf.expand_dims(x_values, axis=-1)
    # During training the x-positions are sampled uniformly at random.
    else:
        num_target = self._max_num_context
        num_total_points = num_context + num_target
        x_values = tf.random_uniform(
            [self._batch_size, num_total_points, self._x_size], -2, 2)

    y_values = self.get_y_values(x_values, num_total_points)

    if self._testing:
        # Select the targets
        target_x = x_values
        target_y = y_values

        # Select the observations
        idx = tf.random_shuffle(tf.range(num_target))
        context_x = tf.gather(x_values, idx[:num_context], axis=1)
        context_y = tf.gather(y_values, idx[:num_context], axis=1)

    else:
        # Select the targets which will consist of the context points as well
        # as some new target points
        target_x = x_values[:, :num_target + num_context, :]
        target_y = y_values[:, :num_target + num_context, :]

        # Select the observations
        context_x = x_values[:, :num_context, :]
        context_y = y_values[:, :num_context, :]

    query = ((context_x, context_y), target_x)

    return CNPRegressionDescription(
        query=query,
        target_y=target_y,
        num_total_points=tf.shape(target_x)[1],
        num_context_points=num_context)
def shuffle_codes(z):
    """Permutes each latent dimension independently across the batch.

    Args:
      z: [batch_size, num_latent] representation.

    Returns:
      shuffled: [batch_size, num_latent] tensor where every column of `z` has
        been shuffled with its own random permutation.
    """
    columns = [tf.random_shuffle(z[:, dim]) for dim in range(z.get_shape()[1])]
    return tf.stack(columns, 1, name="latent_shuffled")
def crop_proposal():
    """Proposes NUM_CROP_PASSES random crops and selects at most one valid one.

    Reads `boxes`, `num_boxes`, `constants` and `calc_iou_tensor` from the
    enclosing scope. Crop coordinates are drawn inside [0, 1] (presumably a
    normalized image frame - confirm). `boxes` is indexed as
    (left, top, right, bottom) when computing centers.

    Returns:
      use_crop: boolean scalar, True if any proposed crop passed all checks.
      output_ltrb: [4] float tensor, the selected crop as
        (left, top, right, bottom), or all zeros when no crop is valid.
      output_masks: [num_boxes] boolean tensor, True for boxes whose center
        lies strictly inside the selected crop (all False when none is valid).
    """
    def rand_vec(minval, maxval):
        # One uniform draw per crop pass, shape (NUM_CROP_PASSES, 1).
        return tf.random_uniform(
            shape=(constants.NUM_CROP_PASSES, 1), minval=minval, maxval=maxval,
            dtype=tf.float32)

    # Width/height in [0.3, 1); the offset range keeps the crop inside [0, 1].
    width, height = rand_vec(0.3, 1), rand_vec(0.3, 1)
    left, top = rand_vec(0, 1-width), rand_vec(0, 1-height)

    right = left + width
    bottom = top + height

    ltrb = tf.concat([left, top, right, bottom], axis=1)

    # Pick one minimum-IoU threshold at random for all passes.
    min_iou = tf.random_shuffle(constants.CROP_MIN_IOU_CHOICES)[0]
    ious = calc_iou_tensor(ltrb, boxes)

    # discard any bboxes whose center not in the cropped image
    # xc / yc: box center coordinates tiled to (NUM_CROP_PASSES, num_boxes).
    xc, yc = [tf.tile(0.5 * (boxes[:, i + 0] + boxes[:, i + 2])[tf.newaxis, :],
                      (constants.NUM_CROP_PASSES, 1)) for i in range(2)]

    # masks[p, b]: center of box b lies strictly inside crop p.
    masks = tf.reduce_all(tf.stack([
        tf.greater(xc, tf.tile(left, (1, num_boxes))),
        tf.less(xc, tf.tile(right, (1, num_boxes))),
        tf.greater(yc, tf.tile(top, (1, num_boxes))),
        tf.less(yc, tf.tile(bottom, (1, num_boxes))),
    ], axis=2), axis=2)

    # Checks of whether a crop is valid.
    # Aspect ratio must lie strictly between 1/2 and 2.
    valid_aspect = tf.logical_and(tf.less(height/width, 2),
                                  tf.less(width/height, 2))
    # Every IoU entry along axis 1 must exceed the chosen threshold.
    valid_ious = tf.reduce_all(tf.greater(ious, min_iou), axis=1,
                               keepdims=True)
    # At least one box center must fall inside the crop.
    valid_masks = tf.reduce_any(masks, axis=1, keepdims=True)

    valid_all = tf.cast(tf.reduce_all(tf.concat(
        [valid_aspect, valid_ious, valid_masks], axis=1), axis=1), tf.int32)

    # One indexed, as zero is needed for the case of no matches.
    index = tf.range(1, 1 + constants.NUM_CROP_PASSES, dtype=tf.int32)

    # Either one-hot, or zeros if there is no valid crop.
    # The max picks the highest-indexed valid pass.
    selection = tf.equal(tf.reduce_max(index * valid_all), index)

    use_crop = tf.reduce_any(selection)
    # Multiplying by the one-hot selection and summing extracts the chosen row.
    output_ltrb = tf.reduce_sum(tf.multiply(ltrb, tf.tile(tf.cast(
        selection, tf.float32)[:, tf.newaxis], (1, 4))), axis=0)
    output_masks = tf.reduce_any(tf.logical_and(masks, tf.tile(
        selection[:, tf.newaxis], (1, num_boxes))), axis=0)

    return use_crop, output_ltrb, output_masks
def knn_random(adj_matrix, max_k=40, k=20):
    """Picks `k` random neighbours among the `max_k` nearest ones.

    Args:
      adj_matrix: pairwise distances, (batch_size, num_points, num_points).
      max_k: size of the nearest-neighbour pool to sample from. The single
        closest entry (presumably the point itself) is excluded from the pool.
      k: number of neighbours to return (k <= max_k).

    Returns:
      nearest neighbors: (batch_size, num_points, k) indices. The k sampled
        ranks are ordered from the farthest to the nearest.
    """
    _, nearest = tf.nn.top_k(-adj_matrix, k=max_k + 1)
    # Drop column 0, the closest entry.
    nearest = nearest[..., 1:max_k + 1]
    # k distinct random ranks in [0, max_k), sorted in descending order.
    ranks = tf.random_shuffle(tf.range(max_k))
    ranks, _ = tf.nn.top_k(ranks[0:k], k=k)
    return tf.gather(nearest, ranks, axis=-1)
def disable_some_bgs():
    """Sets the labels of a random surplus of background proposals to -1.

    Randomly chooses `tf.shape(bg_inds)[0] - max_bg` entries of `bg_inds`
    and replaces the corresponding entries of `proposals_label` with -1.
    Reads `bg_inds`, `max_bg`, `proposals_label`, `proposals_label_shape`
    and `self._seed` from the enclosing scope.
    """
    # Shuffle along dimension 0; the first `surplus` indices get disabled.
    permuted_bg = tf.random_shuffle(bg_inds, seed=self._seed)
    surplus = (tf.shape(bg_inds)[0] - max_bg)
    surplus_check = tf.assert_non_negative(
        surplus,
        message="disable_place in disable_some_bgs is negative.")
    with tf.control_dependencies([surplus_check]):
        to_disable = permuted_bg[:surplus]
    disabled_mask = tf.sparse_to_dense(
        sparse_indices=to_disable,
        sparse_values=True,
        default_value=False,
        output_shape=tf.cast(proposals_label_shape, tf.int64),
        # Shuffled indices are unsorted, so skip index validation.
        validate_indices=False)
    return tf.where(
        condition=disabled_mask,
        x=tf.fill(dims=proposals_label_shape, value=-1.),
        y=proposals_label)
def subsample_positive():
    """Keeps `num_fg` random foreground anchors and marks the others -1.

    Reads `fg_inds`, `num_fg`, `labels` and `self._seed` from the enclosing
    scope.
    """
    permuted = tf.random_shuffle(fg_inds, seed=self._seed)
    # Dropping the first `len(fg_inds) - num_fg` shuffled indices leaves
    # exactly `num_fg` foreground labels.
    n_drop = (tf.shape(fg_inds)[0] - num_fg)
    dropped = permuted[:n_drop]
    # sparse_to_dense wants sorted indices: top_k sorts descending, reverse
    # flips to ascending.
    dropped, _ = tf.nn.top_k(dropped, k=tf.shape(dropped)[-1])
    dropped = tf.reverse(dropped, [0])
    drop_mask = tf.sparse_to_dense(
        dropped, tf.shape(labels, out_type=tf.int64), True,
        default_value=False)
    # -1 marks the selected anchors as ignored.
    return tf.where(
        condition=tf.squeeze(drop_mask),
        x=tf.to_float(tf.fill(tf.shape(labels), -1)),
        y=labels)
def cal_loss(bn_outputs, seq):
    """Computes one sampled loss per output step, with a random 0/1 weight.

    Even steps use `self._rel_w` / `self._rel_b`; odd steps use
    `self._ent_w` / `self._ent_b`. Step i is scored against `seq[:, i + 1]`.
    Reads `batch_size` and `self` from the enclosing scope.

    Returns:
      The per-step losses stacked along axis 1.
    """
    # The 0/1 values are drawn once in NumPy when the graph is built; each
    # run only reshuffles their positions.
    coin_flips = np.random.choice([0., 1.0], size=batch_size, p=[0.5, 0.5])
    weight = tf.random_shuffle(tf.cast(coin_flips, tf.float32))

    per_step = []
    for step, output in enumerate(bn_outputs):
        if step % 2 == 0:
            proj_w, proj_b = self._rel_w, self._rel_b
        else:
            proj_w, proj_b = self._ent_w, self._ent_b
        per_step.append(
            self.sampled_loss(output, seq[:, step + 1], proj_w, proj_b,
                              weight=weight, is_entity=step))
    return tf.stack(per_step, axis=1)
def _face_model_map_fn(example):
    """Prepares one example dict for the face model.

    Optionally shifts and/or shuffles the vertices (re-indexing faces to
    match), maps quantized vertices linearly to floats in [-0.5, 0.5] and
    adds all-ones masks. Reads `apply_random_shift`, `random_shift`,
    `shuffle_vertices` and `quantization_bits` from the enclosing scope.
    """

    def _dequantize_verts(verts, n_bits):
        # Integer bins 0 .. 2**n_bits - 1 map linearly onto [-0.5, 0.5].
        lo, hi = -0.5, 0.5
        n_levels = 2**n_bits - 1
        as_float = tf.cast(verts, tf.float32)
        return as_float * (hi - lo) / n_levels + lo

    verts = example['vertices']

    if apply_random_shift:
        verts = random_shift(verts)
    example['num_vertices'] = tf.shape(verts)[0]

    if shuffle_vertices:
        perm = tf.random_shuffle(tf.range(example['num_vertices']))
        verts = tf.gather(verts, perm)
        # Faces reference vertices with an offset of 2; values 0 and 1 stay
        # fixed (presumably special tokens - confirm). argsort(perm) is the
        # inverse permutation, mapping old vertex ids to new ones.
        face_perm = tf.concat(
            [tf.constant([0, 1], dtype=tf.int32), tf.argsort(perm) + 2],
            axis=0)
        example['faces'] = tf.cast(
            tf.gather(face_perm, example['faces']), tf.int64)

    # The face model consumes float vertices.
    example['vertices'] = _dequantize_verts(verts, quantization_bits)
    example['vertices_mask'] = tf.ones_like(
        example['vertices'][..., 0], dtype=tf.float32)
    example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
    return example
def make_data_tensor(self, train=True):
    """Builds the queue-based TF input pipeline for meta-learning batches.

    First builds, in Python, a flat list of image filenames task by task.
    When `train` is True, how tasks are formed depends on
    `FLAGS.expt_number`. The list is then fed through
    `tf.train.string_input_producer` / `tf.train.batch`. Finally the decoded
    images are regrouped per task, with the class order reshuffled for each
    sample index.

    Args:
      train: If True, use `self.metatrain_character_folders` and the
        experiment-specific task scheme. Otherwise, use
        `self.metaval_character_folders` and 600 randomly sampled tasks (the
        final `else` branch, since every other branch requires `train`).

    Returns:
      all_image_batches: float32 tensor stacked over `self.batch_size`
        tasks; each task has shape
        [num_classes * num_samples_per_class, dim_input].
      all_label_batches: one-hot labels (depth `self.num_classes`) aligned
        with `all_image_batches`.
    """
    if train:
        folders = self.metatrain_character_folders
        # number of tasks, not number of meta-iterations. (divide by metabatch size to measure)
        # Experiments 14a-14h restrict training to the first N * num_classes
        # character folders.
        if FLAGS.expt_number == '14a':
            folders = folders[:(5 * self.num_classes)]
        elif FLAGS.expt_number == '14b':
            folders = folders[:(10 * self.num_classes)]
        elif FLAGS.expt_number == '14c':
            folders = folders[:(25 * self.num_classes)]
        elif FLAGS.expt_number == '14d':
            folders = folders[:(50 * self.num_classes)]
        elif FLAGS.expt_number == '14e':
            folders = folders[:(75 * self.num_classes)]
        elif FLAGS.expt_number == '14f':
            folders = folders[:(100 * self.num_classes)]
        elif FLAGS.expt_number == '14g':
            folders = folders[:(150 * self.num_classes)]
        elif FLAGS.expt_number == '14h':
            folders = folders[:(200 * self.num_classes)]
        # Number of tasks drawn by the default (random-sampling) branch.
        if FLAGS.expt_number == '6' or FLAGS.expt_number == '8':
            print("Inside expt number 6/8")
            num_total_batches = 26400
        elif FLAGS.expt_number == '6a' or FLAGS.expt_number == '8a':
            print("Inside expt number 6a/8a")
            num_total_batches = 1100
        else:
            num_total_batches = 200000
    else:
        folders = self.metaval_character_folders
        num_total_batches = 600

    # make list of files
    print('Generating filenames')
    if FLAGS.expt_number == '2' and train:
        print('Inside expt number 2')
        all_filenames = []
        # One pass over the folders; consecutive groups of num_classes
        # folders form a task, with class order shuffled within the task.
        """ go over the folders once, group the adjacent 5 classes together as one task.
        Non-exclusive """
        for task_count in range(int(len(folders) / self.num_classes)):
            #sampled_character_folders = random.sample(folders, self.num_classes)
            sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
            random.shuffle(sampled_character_folders)
            labels_and_images = get_images(
                sampled_character_folders,
                range(self.num_classes),
                nb_samples=self.num_samples_per_class,
                shuffle=False,
                train=train)
            # make sure the above isn't randomized order
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            all_filenames.extend(filenames)
    elif (FLAGS.expt_number == '9a' or FLAGS.expt_number == '9b' or
          FLAGS.expt_number == '9c' or FLAGS.expt_number == '11a1' or
          FLAGS.expt_number == '11a2' or FLAGS.expt_number == '11a3' or
          ('14' in FLAGS.expt_number)) and train:
        print('Inside expt number 9a/b/c/11a1/11a2/11a3/14s')
        all_filenames = []
        # Repeated passes over the same fixed consecutive groups, with no
        # shuffling of class order, until roughly total_num_tasks tasks exist.
        """ go over the folders multiple times, group the adjacent 5 classes together as one task.
        Non-exclusive """
        if FLAGS.expt_number == '9a' or FLAGS.expt_number == '11a1' or FLAGS.expt_number == '11a2' or FLAGS.expt_number == '11a3' or (
                '14' in FLAGS.expt_number):
            total_num_tasks = 200000
        elif FLAGS.expt_number == '9b':
            total_num_tasks = 26400
        elif FLAGS.expt_number == '9c':
            total_num_tasks = 1100
        for _ in range(
                int(total_num_tasks / (len(folders) / self.num_classes))):  #9b
            #for _ in range(int(num_total_batches/(len(folders)/self.num_classes))): #9a
            for task_count in range(int(len(folders) / self.num_classes)):
                #sampled_character_folders = random.sample(folders, self.num_classes)
                sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
                #random.shuffle(sampled_character_folders)
                labels_and_images = get_images(
                    sampled_character_folders,
                    range(self.num_classes),
                    nb_samples=self.num_samples_per_class,
                    shuffle=False,
                    train=train)
                # make sure the above isn't randomized order
                labels = [li[0] for li in labels_and_images]
                filenames = [li[1] for li in labels_and_images]
                all_filenames.extend(filenames)
    elif (FLAGS.expt_number == '7a' or FLAGS.expt_number == '7b' or
          FLAGS.expt_number == '7c' or FLAGS.expt_number == '11b1' or
          FLAGS.expt_number == '11b2' or FLAGS.expt_number == '11b3') and train:
        print('Inside expt number 7a/b/c/11b1/11b2/11b3')
        all_filenames = []
        # Same as the 9x branch, but the class order inside each task is
        # reshuffled on every pass.
        """ go over the folders multiple times, group the adjacent 5 classes together as one task.
        Non-exclusive """
        if FLAGS.expt_number == '7a' or FLAGS.expt_number == '11b1' or FLAGS.expt_number == '11b2' or FLAGS.expt_number == '11b3':
            total_num_tasks = 200000
        elif FLAGS.expt_number == '7b':
            total_num_tasks = 26400
        elif FLAGS.expt_number == '7c':
            total_num_tasks = 1100
        for _ in range(
                int(total_num_tasks / (len(folders) / self.num_classes))):  #7b
            #for _ in range(int(num_total_batches/(len(folders)/self.num_classes))): #7a
            for task_count in range(int(len(folders) / self.num_classes)):
                #sampled_character_folders = random.sample(folders, self.num_classes)
                sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
                random.shuffle(sampled_character_folders)
                labels_and_images = get_images(
                    sampled_character_folders,
                    range(self.num_classes),
                    nb_samples=self.num_samples_per_class,
                    shuffle=False,
                    train=train)
                # make sure the above isn't randomized order
                labels = [li[0] for li in labels_and_images]
                filenames = [li[1] for li in labels_and_images]
                all_filenames.extend(filenames)
    elif FLAGS.expt_number == '3' and train:
        print('Inside expt number 3')
        all_filenames = []
        # One pass over consecutive groups with class order left as-is.
        """ removed shuffling of classes in the init function. classes from the same alphabet together now.
        go over the folders once, group the adjacent 5 classes together as one task.
        Non-exclusive """
        for task_count in range(int(len(folders) / self.num_classes)):
            #sampled_character_folders = random.sample(folders, self.num_classes)
            sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
            #random.shuffle(sampled_character_folders)
            #print("Task ", task_count, ": ", sampled_character_folders)
            labels_and_images = get_images(
                sampled_character_folders,
                range(self.num_classes),
                nb_samples=self.num_samples_per_class,
                shuffle=False,
                train=train)
            # make sure the above isn't randomized order
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            all_filenames.extend(filenames)
    elif (FLAGS.expt_number == '4' or FLAGS.expt_number == '5' or
          FLAGS.expt_number == '8c') and train:
        print('Inside expt number 4/5/8c')
        all_filenames = []
        # Every ordering (itertools-style permutations) of each consecutive
        # group becomes its own task; the task list is then shuffled.
        """ go over the folders once, group the adjacent 5 classes together as one task.
        get all permutations of that task. Shuffle these tasks """
        task_folders_new = []
        for task_count in range(int(len(folders) / self.num_classes)):
            sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
            task_folders_temp = permutations(sampled_character_folders)
            task_folders_new.extend(task_folders_temp)
        print('total number of tasks: ', len(task_folders_new))
        random.shuffle(task_folders_new)
        for task_count in range(len(task_folders_new)):
            #sampled_character_folders = random.sample(folders, self.num_classes)
            #sampled_character_folders = folders[task_count*self.num_classes: (task_count+1)*self.num_classes]
            sampled_character_folders = task_folders_new[task_count]
            #random.shuffle(sampled_character_folders)
            #print("Task ", task_count, ": ", sampled_character_folders)
            labels_and_images = get_images(
                sampled_character_folders,
                range(self.num_classes),
                nb_samples=self.num_samples_per_class,
                shuffle=False,
                train=train)
            # make sure the above isn't randomized order
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            all_filenames.extend(filenames)
    elif FLAGS.expt_number == '4a' and train:
        print('Inside expt number 4a')
        all_filenames = []
        # Cycle over the consecutive groups (class order reshuffled each
        # time) until exactly 1100 tasks have been generated.
        """ go over all tasks group the same 5 classes together, but shuffle them everytime.
        Do it till you reach 1100 tasks """
        i = 0
        while i < 1100:
            for task_count in range(int(len(folders) / self.num_classes)):
                #sampled_character_folders = random.sample(folders, self.num_classes)
                sampled_character_folders = folders[task_count * self.num_classes: (task_count + 1) * self.num_classes]
                random.shuffle(sampled_character_folders)
                #print("Task ", task_count, ": ", sampled_character_folders)
                labels_and_images = get_images(
                    sampled_character_folders,
                    range(self.num_classes),
                    nb_samples=self.num_samples_per_class,
                    shuffle=False,
                    train=train)
                # make sure the above isn't randomized order
                labels = [li[0] for li in labels_and_images]
                filenames = [li[1] for li in labels_and_images]
                all_filenames.extend(filenames)
                i += 1
                # Breaks the inner loop; the while condition then ends the outer one.
                if i >= 1100:
                    break
    elif (FLAGS.expt_number == '12a' or FLAGS.expt_number == '12b' or
          FLAGS.expt_number == '12c' or FLAGS.expt_number == '12d' or
          FLAGS.expt_number == '12e' or FLAGS.expt_number == '12f') and train:
        print('Inside expt number 12s')
        all_filenames = []
        # Draw n_tasks random tasks once, then cycle through them
        # (task_count % n_tasks) to fill num_total_batches tasks.
        """ make tasks first.
        Then go over the tasks mulitiple times to collect data """
        if FLAGS.expt_number == '12a':
            n_tasks = 220
        elif FLAGS.expt_number == '12b':
            n_tasks = 1100
        elif FLAGS.expt_number == '12c':
            n_tasks = 5500
        elif FLAGS.expt_number == '12d':
            n_tasks = 26400
        elif FLAGS.expt_number == '12e':
            n_tasks = 27500
        elif FLAGS.expt_number == '12f':
            n_tasks = 137500
        task_folders_new = []
        for task_count in range(n_tasks):
            sampled_character_folders = random.sample(
                folders, self.num_classes)
            #task_folders_temp = permutations(sampled_character_folders)
            task_folders_new.append(sampled_character_folders)
        print('total number of tasks: ', len(task_folders_new))
        #random.shuffle(task_folders_new)
        for task_count in range(num_total_batches):
            sampled_character_folders = task_folders_new[task_count % n_tasks]
            labels_and_images = get_images(
                sampled_character_folders,
                range(self.num_classes),
                nb_samples=self.num_samples_per_class,
                shuffle=False,
                train=train)
            # make sure the above isn't randomized order
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            all_filenames.extend(filenames)
    else:
        # Default (and every non-train call): num_total_batches independent
        # random tasks of num_classes folders each.
        all_filenames = []
        print('Inside expt number 1/6/6a/8/8a/11c1/11c2/11c3')
        for _ in range(num_total_batches):
            sampled_character_folders = random.sample(
                folders, self.num_classes)
            random.shuffle(sampled_character_folders)
            labels_and_images = get_images(
                sampled_character_folders,
                range(self.num_classes),
                nb_samples=self.num_samples_per_class,
                shuffle=False,
                train=train)
            # make sure the above isn't randomized order
            labels = [li[0] for li in labels_and_images]
            filenames = [li[1] for li in labels_and_images]
            all_filenames.extend(filenames)

    # make queue for tensorflow to read from
    # shuffle=False keeps the task grouping built above intact.
    filename_queue = tf.train.string_input_producer(
        tf.convert_to_tensor(all_filenames), shuffle=False)
    print('Generating image processing ops')
    image_reader = tf.WholeFileReader()
    _, image_file = image_reader.read(filename_queue)
    if FLAGS.datasource == 'miniimagenet':
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((self.img_size[0], self.img_size[1], 3))
        image = tf.reshape(image, [self.dim_input])
        image = tf.cast(image, tf.float32) / 255.0
    else:
        image = tf.image.decode_png(image_file)
        image.set_shape((self.img_size[0], self.img_size[1], 1))
        image = tf.reshape(image, [self.dim_input])
        image = tf.cast(image, tf.float32) / 255.0
        image = 1.0 - image  # invert
    num_preprocess_threads = 1  # TODO - enable this to be set to >1
    min_queue_examples = 256
    examples_per_batch = self.num_classes * self.num_samples_per_class
    batch_image_size = self.batch_size * examples_per_batch
    print('Batching images')
    # A single reader thread preserves file order, so each consecutive
    # slice of examples_per_batch images is one task.
    images = tf.train.batch(
        [image],
        batch_size=batch_image_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_image_size,
    )
    all_image_batches, all_label_batches = [], []
    print('Manipulating image data to be right shape')
    for i in range(self.batch_size):
        image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch]

        if FLAGS.datasource == 'omniglot':
            # omniglot augments the dataset by rotating digits to create new classes
            # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes)
            rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes)
        # NOTE(review): `labels` is whatever the last generated task left
        # behind; this assumes every task's label list has the same layout
        # (get_images is called with shuffle=False) - confirm.
        label_batch = tf.convert_to_tensor(labels)
        new_list, new_label_list = [], []
        for k in range(self.num_samples_per_class):
            # Fresh random class order for each sample index k.
            class_idxs = tf.range(0, self.num_classes)
            class_idxs = tf.random_shuffle(class_idxs)

            # Images are laid out class-major: class c, sample k sits at
            # c * num_samples_per_class + k.
            true_idxs = class_idxs * self.num_samples_per_class + k
            new_list.append(tf.gather(image_batch, true_idxs))
            if FLAGS.datasource == 'omniglot':  # and FLAGS.train:
                # Rotate every image of a class by that class's sampled
                # multiple of 90 degrees.
                new_list[-1] = tf.stack([
                    tf.reshape(
                        tf.image.rot90(tf.reshape(
                            new_list[-1][ind],
                            [self.img_size[0], self.img_size[1], 1]),
                                       k=tf.cast(
                                           rotations[0, class_idxs[ind]],
                                           tf.int32)), (self.dim_input, ))
                    for ind in range(self.num_classes)
                ])
            new_label_list.append(tf.gather(label_batch, true_idxs))
        new_list = tf.concat(
            new_list, 0
        )  # has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
        new_label_list = tf.concat(new_label_list, 0)
        all_image_batches.append(new_list)
        all_label_batches.append(new_label_list)
    all_image_batches = tf.stack(all_image_batches)
    all_label_batches = tf.stack(all_label_batches)
    all_label_batches = tf.one_hot(all_label_batches, self.num_classes)
    return all_image_batches, all_label_batches
def _update_mask(self, weights, threshold, gradients):  # pylint: disable=unused-argument
  """Updates the mask for a given weight tensor.

  Every entry of `weights` gets a salience score, and the
  `round(size(weights) * (1 - sparsity))` highest-scoring entries are kept
  while the rest are pruned. Before sorting, the flattened scores are moved
  to randomly permuted positions. `tf.nn.top_k` breaks ties in favour of the
  lower index, so this makes the tie-breaking between equal scores uniformly
  random rather than positional.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: Not read by this implementation (hence the pylint
      suppression). NOTE(review): presumably kept so the signature matches
      other `_update_mask` variants; confirm against callers.
    gradients: Tensor multiplied element-wise with `abs(weights)` to form the
      salience score when `self._spec.prune_option` is
      'first_order_gradient' or 'second_order_gradient'. It is ignored for
      'weight'.

  Returns:
    new_threshold: Always `tf.constant(0, tf.float32)`, because no threshold
      is computed here.
    new_mask: A float32 tensor with the same shape as `weights`. It holds 1.0
      for kept entries and 0.0 for pruned entries.

  Raises:
    ValueError: if `self._sparsity` is None, if a gradient-based prune option
      is selected but `gradients` is None, or if `self._spec.prune_option` is
      not one of 'weight', 'first_order_gradient', 'second_order_gradient'.
  """
  if self._sparsity is None:
    raise ValueError('Sparsity variable undefined')

  # Target sparsity for this particular weight, looked up by its op name.
  sparsity = self._get_sparsity(weights.op.name)
  with tf.name_scope(weights.op.name + '_pruning_ops'):
    tf.logging.info('Applying option %s pruning', self._spec.prune_option)
    # Salience score per entry; the highest scores survive pruning.
    if self._spec.prune_option == 'weight':
      abs_weights = tf.abs(weights)
    elif self._spec.prune_option in ('first_order_gradient',
                                     'second_order_gradient'):
      if gradients is None:
        raise ValueError('gradient tensor cannot be None.')
      # gradient variable stores absolute value already
      abs_weights = tf.multiply(tf.abs(weights), gradients)
    else:
      raise ValueError('undefined option')

    # Number of entries to keep (mask value 1).
    k = tf.cast(
        tf.round(
            tf.cast(tf.size(abs_weights), tf.float32) * (1 - sparsity)),
        tf.int32)

    # Generate a random shuffling of the weights s.t. the tie-breaker on
    # weight magnitude is random uniform.
    shuffling = tf.random_shuffle(tf.range(tf.size(abs_weights)))
    # Column vector of indices, as required by scatter_nd / gather_nd.
    shuffling = tf.reshape(shuffling, [-1, 1])

    # Flatten the weights and scatter the values randomly: flat entry i is
    # written to position shuffling[i].
    abs_weights = tf.reshape(abs_weights, [-1])
    abs_weights = tf.scatter_nd(shuffling, abs_weights,
                                tf.shape(abs_weights))

    # Sort the entire array. `indices` lists positions in the shuffled vector
    # in order of decreasing score.
    _, indices = tf.nn.top_k(abs_weights, k=tf.size(abs_weights))

    # `k` is how many non-zero weights we're going to have. Create a new
    # mask where the first `k` elements are set to one and all others are
    # set to zero.
    mask_staging = tf.range(tf.size(abs_weights))
    mask_staging = tf.cast(tf.less(mask_staging, k), tf.float32)

    # Scatter the mask back into the proper positions for the weight matrix.
    # The result is still expressed in the shuffled ordering.
    indices = tf.reshape(indices, [-1, 1])
    new_mask = tf.scatter_nd(indices, mask_staging, tf.shape(mask_staging))

    # Un-shuffle the newly created mask: element i reads position
    # shuffling[i], which undoes the scatter above. Then restore the
    # original shape.
    new_mask = tf.reshape(tf.gather_nd(new_mask, shuffling), tf.shape(weights))

    return tf.constant(0, tf.float32), new_mask
def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0):
  """Distorts the color of the image (jittering order is random).

  Four jitter operations are applied: brightness, contrast, saturation and
  hue. Their order is a uniformly random permutation, and the image is
  clipped to [0, 1] after every operation.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.

  Returns:
    A tuple `(image, theta_color)`:
      image: The distorted image tensor.
      theta_color: float32 Tensor of shape [4]. It holds the sampled factors
        in the fixed order [brightness, contrast, saturation, hue], whatever
        order the operations were actually applied in. An operation whose
        strength argument is 0 is skipped and reports a factor of 0.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x):
      """Apply the i-th transformation; returns `(new_x, sampled_factor)`."""

      def brightness_foo():
        if brightness == 0:
          return x, tf.constant(0, dtype=tf.float32)
        # Multiplicative brightness change; the lower bound is clamped at 0.
        brightness_factor = tf.random_uniform(
            [], tf.maximum(1.0 - brightness, 0), 1.0 + brightness)
        return x * brightness_factor, brightness_factor

      def contrast_foo():
        if contrast == 0:
          return x, tf.constant(0, dtype=tf.float32)
        contrast_factor = tf.random_uniform(
            [], minval=1 - contrast, maxval=1 + contrast, dtype=tf.float32)
        return tf.image.adjust_contrast(x, contrast_factor), contrast_factor

      def saturation_foo():
        if saturation == 0:
          return x, tf.constant(0, dtype=tf.float32)
        saturation_factor = tf.random_uniform(
            [], minval=1 - saturation, maxval=1 + saturation,
            dtype=tf.float32)
        return (tf.image.adjust_saturation(x, saturation_factor),
                saturation_factor)

      def hue_foo():
        if hue == 0:
          return x, tf.constant(0, dtype=tf.float32)
        hue_factor = tf.random_uniform([], -hue, hue)
        return tf.image.adjust_hue(x, delta=hue_factor), hue_factor

      # Binary dispatch on i: 0 -> brightness, 1 -> contrast,
      # 2 -> saturation, 3 -> hue.
      return tf.cond(
          tf.less(i, 2),
          lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
          lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))

    perm = tf.random_shuffle(tf.range(4))
    theta_color = []
    for i in range(4):
      image, factor = apply_transform(perm[i], image)
      image = tf.clip_by_value(image, 0., 1.)
      theta_color.append(factor)
    theta_color = tf.cast(tf.stack(theta_color), tf.float32)
    # theta_color[i] is the factor of operation perm[i]. Gathering with the
    # inverse permutation puts operation j's factor in slot j. Gathering with
    # `perm` itself is only correct when perm happens to be its own inverse.
    theta_color = tf.gather(theta_color, tf.invert_permutation(perm))
    return image, theta_color
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Samples a chunk-wise factorization order and derives its masks.

  The sequence is viewed as consecutive chunks of `perm_size` positions, and
  every chunk receives the same random reordering of its offsets.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
      `seq_len` must be a multiple of it (ids are reshaped to
      [-1, perm_size]).
    seq_len: int, sequence length.

  Returns:
    A 5-tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
      perm_mask: float32 [seq_len, seq_len]. 1 means position i may NOT
        attend to position j, 0 means it may.
      new_targets: `inputs[0:1]` followed by `targets[:-1]`.
      target_mask: float32 [seq_len]. 1 for masked, non-functional tokens.
      inputs_k: `inputs`, unchanged.
      inputs_q: the same tensor as `target_mask`.
  """
  # Shuffling the rows of the transposed [perm_size, -1] view gives every
  # chunk of `perm_size` consecutive positions the same offset reordering.
  order = tf.range(seq_len, dtype=tf.int64)
  order = tf.transpose(tf.reshape(order, [-1, perm_size]))
  order = tf.random_shuffle(order)
  order = tf.reshape(tf.transpose(order), [-1])

  # [SEP] and [CLS] are the functional tokens.
  is_func = tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID))
  non_func_tokens = tf.logical_not(is_func)
  # Context tokens are neither masked nor functional.
  context_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(context_tokens)

  # Context tokens get the smallest rank (-1). Every position may see them,
  # and they can never see a masked position, so nothing leaks.
  rev_index = tf.where(context_tokens,
                       -tf.ones([seq_len], dtype=tf.int64),
                       order)

  # Prediction targets are the masked, non-functional tokens.
  # 1 means the token is fed as a mask and scored; 0 means it is fed as
  # itself and not scored.
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Non-targets get their query rank bumped by one, so they can see
  # themselves. Targets keep their own rank and therefore cannot.
  query_rank = tf.where(target_tokens, rev_index, rev_index + 1)

  # Blocked iff rank(i) <= rank(j) and j is masked or functional.
  blocked = tf.logical_and(query_rank[:, None] <= rev_index[None, :],
                           masked_or_func_tokens)
  perm_mask = tf.cast(blocked, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)

  # inputs_k is the raw input ids; inputs_q is the target mask itself.
  return perm_mask, new_targets, target_mask, inputs, target_mask
def make_data_tensor(self, train=True):
    """Builds the TF1 queue-based input pipeline for batches of N-way tasks.

    The pipeline works in three stages:
      1. Task folders are sampled in Python and each task's image filenames
         are listed up front.
      2. The files are read, decoded and batched by a `tf.train` queue.
      3. Each batch is split per task. For every sample slot the class order
         is reshuffled, with labels gathered using the same indices.
    For omniglot, each class is also rotated by a random multiple of 90
    degrees.

    Args:
        train: if True, tasks are drawn from `self.metatrain_character_folders`;
            otherwise from `self.metaval_character_folders`.

    Returns:
        A tuple `(all_image_batches, all_label_batches)`:
            all_image_batches: float32 Tensor of shape
                [self.batch_size, num_classes * num_samples_per_class,
                self.dim_input]. Pixels are scaled to [0, 1] and inverted for
                non-miniimagenet data sources.
            all_label_batches: labels one-hot encoded with depth
                `self.num_classes`, gathered with the same indices as the
                images.

    Raises:
        ValueError: if no task could be generated from the selected folders
            (for example, fewer folders than `self.num_classes`).
    """
    import itertools
    import random

    if train:
        folders = self.metatrain_character_folders
        # number of tasks, not number of meta-iterations. (divide by metabatch size to measure)
        num_total_batches = 200000
        if FLAGS.task_type == "ne":
            print("Inside ne train")
            folders = self.metatrain_character_folders[:50]  # 10 classification tasks
            num_total_batches = 5000
    else:
        folders = self.metaval_character_folders
        num_total_batches = 600
        if FLAGS.task_type == "ne":
            print("inside ne val")
            # NOTE(review): both branches of the former `if FLAGS.test_set`
            # selected these same 15 folders, so that flag has no effect here.
            folders = self.metaval_character_folders[:15]  # 3 classification tasks
            num_total_batches = 60

    # make list of files
    print('Generating filenames')
    all_filenames = []
    labels = None  # labels of the most recently generated task

    def add_task(task_folders):
        """Appends one task's filenames to all_filenames; returns its labels."""
        labels_and_images = get_images(task_folders, range(self.num_classes),
                                       nb_samples=self.num_samples_per_class,
                                       shuffle=False)
        # make sure the above isn't randomized order
        task_labels = [li[0] for li in labels_and_images]
        filenames = [li[1] for li in labels_and_images]
        all_filenames.extend(filenames)
        return task_labels

    # NOTE(review): the folder selection above keys off FLAGS.task_type,
    # while this branch keys off FLAGS.task_setting. Confirm both flags are
    # intended.
    if FLAGS.task_setting == 'ne' and train:
        print("len folders", len(folders))
        # Shuffle a copy. When the slice above was not taken, `folders` is the
        # instance's own list and must not be reordered in place.
        folders = list(folders)
        random.shuffle(folders)
        # Split the folders into disjoint groups of num_classes and use every
        # ordering of each group as a separate task.
        task_folders_new = []
        for i in range(len(folders) // self.num_classes):
            group = folders[i * self.num_classes:(i + 1) * self.num_classes]
            task_folders_new.extend(itertools.permutations(group))
        print("len of task_folders_new", len(task_folders_new))
        random.shuffle(task_folders_new)
        for sampled_character_folders in task_folders_new:
            labels = add_task(sampled_character_folders)
    else:
        for _ in range(num_total_batches):
            sampled_character_folders = random.sample(folders, self.num_classes)
            random.shuffle(sampled_character_folders)
            labels = add_task(sampled_character_folders)

    if labels is None:
        raise ValueError('No tasks were generated from %d folders with '
                         'num_classes=%d.' % (len(folders), self.num_classes))

    # make queue for tensorflow to read from
    filename_queue = tf.train.string_input_producer(
        tf.convert_to_tensor(all_filenames), shuffle=False)
    print('Generating image processing ops')
    image_reader = tf.WholeFileReader()
    _, image_file = image_reader.read(filename_queue)
    if FLAGS.datasource == 'miniimagenet':
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((self.img_size[0], self.img_size[1], 3))
        image = tf.reshape(image, [self.dim_input])
        image = tf.cast(image, tf.float32) / 255.0
    else:
        image = tf.image.decode_png(image_file)
        image.set_shape((self.img_size[0], self.img_size[1], 1))
        image = tf.reshape(image, [self.dim_input])
        image = tf.cast(image, tf.float32) / 255.0
        image = 1.0 - image  # invert
    num_preprocess_threads = 1  # TODO - enable this to be set to >1
    min_queue_examples = 256
    examples_per_batch = self.num_classes * self.num_samples_per_class
    batch_image_size = self.batch_size * examples_per_batch
    print('Batching images')
    print("batch_image_size", batch_image_size)
    images = tf.train.batch(
        [image],
        batch_size=batch_image_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_image_size,
    )
    print("len images", images.shape)
    all_image_batches, all_label_batches = [], []
    print('Manipulating image data to be right shape')
    for i in range(self.batch_size):
        image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch]

        if FLAGS.datasource == 'omniglot':
            # omniglot augments the dataset by rotating digits to create new classes
            # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes)
            rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes)
        label_batch = tf.convert_to_tensor(labels)
        new_list, new_label_list = [], []
        # shuffles the data within a batch, class labels remain fixed
        for k in range(self.num_samples_per_class):
            class_idxs = tf.range(0, self.num_classes)
            class_idxs = tf.random_shuffle(class_idxs)
            true_idxs = class_idxs * self.num_samples_per_class + k
            new_list.append(tf.gather(image_batch, true_idxs))
            if FLAGS.datasource == 'omniglot':  # and FLAGS.train:
                new_list[-1] = tf.stack([
                    tf.reshape(
                        tf.image.rot90(
                            tf.reshape(new_list[-1][ind],
                                       [self.img_size[0], self.img_size[1], 1]),
                            k=tf.cast(rotations[0, class_idxs[ind]], tf.int32)),
                        (self.dim_input,))
                    for ind in range(self.num_classes)
                ])
            new_label_list.append(tf.gather(label_batch, true_idxs))
        new_list = tf.concat(new_list, 0)  # has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
        new_label_list = tf.concat(new_label_list, 0)
        all_image_batches.append(new_list)
        all_label_batches.append(new_label_list)
    all_image_batches = tf.stack(all_image_batches)
    all_label_batches = tf.stack(all_label_batches)
    print("all_image_batches", all_image_batches)
    all_label_batches = tf.one_hot(all_label_batches, self.num_classes)
    return all_image_batches, all_label_batches
def _static_subsample(self, indicator, batch_size, labels):
  """Returns subsampled minibatch.

  Args:
    indicator: boolean tensor of shape [N] whose True entries can be sampled.
      N should be a compile time constant.
    batch_size: desired batch size. This scalar cannot be None.
    labels: boolean tensor of shape [N] denoting positive(=True) and negative
      (=False) examples. N should be a compile time constant.

  Returns:
    sampled_idx_indicator: boolean tensor of shape [N], True for entries which
      are sampled. It ensures the length of output of the subsample is always
      batch_size, even when number of examples set to True in indicator is
      less than batch_size.

  Raises:
    ValueError: if `indicator` or `labels` does not have a fully defined
      static shape, or if `batch_size` is not a Python int.
  """
  # Check if indicator and labels have a static size.
  if not indicator.shape.is_fully_defined():
    raise ValueError(
        'indicator must be static in shape when is_static is True')
  if not labels.shape.is_fully_defined():
    raise ValueError(
        'labels must be static in shape when is_static is True')
  if not isinstance(batch_size, int):
    raise ValueError(
        'batch_size has to be an integer when is_static is True.')

  input_length = tf.shape(indicator)[0]

  # Set the number of examples set True in indicator to be at least
  # batch_size. The earliest False entries are promoted until enough are True.
  num_true_sampled = tf.reduce_sum(tf.cast(indicator, tf.float32))
  additional_false_sample = tf.less_equal(
      tf.cumsum(tf.cast(tf.logical_not(indicator), tf.float32)),
      batch_size - num_true_sampled)
  indicator = tf.logical_or(indicator, additional_false_sample)

  # Shuffle indicator and label. Need to store the permutation to restore the
  # order post sampling.
  permutation = tf.random_shuffle(tf.range(input_length))
  indicator = ops.matmul_gather_on_zeroth_axis(
      tf.cast(indicator, tf.float32), permutation)
  labels = ops.matmul_gather_on_zeroth_axis(
      tf.cast(labels, tf.float32), permutation)

  # index (starting from 1) when indicator is True, 0 when False
  indicator_idx = tf.where(
      tf.cast(indicator, tf.bool), tf.range(1, input_length + 1),
      tf.zeros(input_length, tf.int32))

  # Replace -1 for negative, +1 for positive labels
  signed_label = tf.where(
      tf.cast(labels, tf.bool), tf.ones(input_length, tf.int32),
      tf.scalar_mul(-1, tf.ones(input_length, tf.int32)))
  # negative of index for negative label, positive index for positive label,
  # 0 when indicator is False.
  signed_indicator_idx = tf.multiply(indicator_idx, signed_label)
  sorted_signed_indicator_idx = tf.nn.top_k(
      signed_indicator_idx, input_length, sorted=True).values

  [num_positive_samples,
   num_negative_samples] = self._get_num_pos_neg_samples(
       sorted_signed_indicator_idx, batch_size)

  sampled_idx = self._get_values_from_start_and_end(
      sorted_signed_indicator_idx, num_positive_samples,
      num_negative_samples, batch_size)

  # Shift the indices to start from 0 and remove any samples that are set as
  # False.
  sampled_idx = tf.abs(sampled_idx) - tf.ones(batch_size, tf.int32)
  sampled_idx = tf.multiply(
      tf.cast(tf.greater_equal(sampled_idx, tf.constant(0)), tf.int32),
      sampled_idx)

  sampled_idx_indicator = tf.cast(
      tf.reduce_sum(tf.one_hot(sampled_idx, depth=input_length), axis=0),
      tf.bool)

  # project back the order based on stored permutations: output position
  # permutation[i] receives shuffled entry i.
  reprojections = tf.one_hot(permutation, depth=input_length,
                             dtype=tf.float32)
  return tf.cast(
      tf.tensordot(tf.cast(sampled_idx_indicator, tf.float32),
                   reprojections, axes=[0, 0]),
      tf.bool)
def get_doc_rep_with_masked_sent(input_sent_reps_doc,
                                 sent_mask_embedding,
                                 input_mask_doc_level,
                                 batch_size_static=32,
                                 max_masked_sent_per_doc=2,
                                 loop_sent_number_per_doc=32):
  """Get the document representations with masked sentences.

  For every document in the batch, `max_masked_sent_per_doc` distinct
  positions are drawn uniformly from [0, loop_sent_number_per_doc). The
  sentence embeddings at those positions are overwritten with
  `sent_mask_embedding`.

  Args:
    input_sent_reps_doc: float Tensor. The independent sentence embeddings
      without masks for the sentences in the current document. The shape is
      [batch, loop_sent_number_per_doc, hidden].
    sent_mask_embedding: float Tensor. The sentence embedding vector for the
      masked position. The shape is [hidden].
    input_mask_doc_level: int Tensor. The input masks on the document level
      to identify whether a location is a real sentence (mask = 1) or a padded
      sentence (mask = 0). The shape is [batch, loop_sent_number_per_doc].
    batch_size_static: scalar. The static batch size depending on the training
      or the evaluation mode.
    max_masked_sent_per_doc: scalar. The maximum number of masked sentences
      per document. Values below 2 are raised to 2.
    loop_sent_number_per_doc: scalar. The number of looped sentences per
      document.

  Returns:
    A tuple `(doc_reps, masked_sent_index, masked_sent_weight)`:
      doc_reps: `input_sent_reps_doc` with the sampled positions replaced by
        `sent_mask_embedding`.
      masked_sent_index: Tensor [batch, max_masked_sent_per_doc] holding the
        sampled positions of each document, sorted ascending.
      masked_sent_weight: float32 Tensor [batch, max_masked_sent_per_doc].
        Each entry is 1.0 when the sampled position is smaller than the
        document's number of real sentences (the row sum of
        `input_mask_doc_level`), and 0.0 otherwise.
  """
  # We at least mask two sentences to build a candidate sentence pool for
  # negative sentence sampling. We generate the masked_sent_index and
  # masked_sent_weight for each document. Note that we do not add any word
  # or sentence level masks during prediction or inference stage.
  max_masked_sent_per_doc = max(max_masked_sent_per_doc, 2)
  input_sent_reps_doc_list = tf.unstack(
      input_sent_reps_doc, num=batch_size_static)
  real_sent_number_per_doc = tf.unstack(
      tf.reduce_sum(input_mask_doc_level, 1), num=batch_size_static)
  masked_sent_index_list = []
  masked_sent_weight_list = []

  # The values written into the masked slots do not depend on the document.
  # Build the [max_masked_sent_per_doc, hidden] update once, instead of
  # re-creating identical tile/reshape ops for every example in the batch.
  updates = tf.reshape(
      tf.tile(sent_mask_embedding, [max_masked_sent_per_doc]),
      [max_masked_sent_per_doc, -1])

  # For each example in the current batch, we randomly sample
  # max_masked_sent_per_doc positions to mask the sentences. For each masked
  # sentence position, the sentence in the current position is the positive
  # example. The other co-masked sentences are the negative examples.
  # The sampled sentence indexes will not be duplicated.
  for batch_i in range(batch_size_static):
    # Since everything in TPU must have a fixed shape, here the max sampled
    # sentence index can be as large as loop_sent_number_per_doc. We will
    # generate the corresponding sentence LM weights to reduce the impact
    # on the final masked sentence LM loss following a similar way with the
    # handling of masked word LM loss and masked word LM weights.
    real_sent_number = real_sent_number_per_doc[batch_i]
    sampled_sent_index = tf.slice(
        tf.random_shuffle(tf.range(loop_sent_number_per_doc)), [0],
        [max_masked_sent_per_doc])
    sampled_sent_index = tf.sort(sampled_sent_index)
    masked_sent_index_list.append(sampled_sent_index)
    # Weight 1 for sampled real sentences, 0 for sampled padding slots.
    sample_sent_weight = tf.cast(
        tf.less(sampled_sent_index, real_sent_number), tf.float32)
    masked_sent_weight_list.append(sample_sent_weight)

    # [max_masked_sent_per_doc, 1] scatter indices into the sentence axis.
    indices = tf.reshape(sampled_sent_index, [max_masked_sent_per_doc, -1])
    input_sent_reps_doc_list[batch_i] = tf.tensor_scatter_update(
        input_sent_reps_doc_list[batch_i], indices, updates)

  # masked_sent_index_list and masked_sent_weight_list are lists of tensors,
  # one per document. Stacking gives [batch, max_masked_sent_per_doc].
  return (tf.stack(input_sent_reps_doc_list),
          tf.stack(masked_sent_index_list),
          tf.stack(masked_sent_weight_list))