def update_masks():
    """Return a boolean tensor telling whether masks should be updated now.

    The result is True only when both hold:
      * the global step is inside the pruning window, i.e. at or after
        ``begin_pruning_step`` and either at or before ``end_pruning_step`` or
        ``end_pruning_step`` is negative (a negative end step means pruning
        never stops); and
      * at least ``pruning_frequency`` steps have elapsed since
        ``self._last_update_step``.
    """
    spec = self._spec
    step = self._global_step
    with tf.name_scope(spec.name):
        started = tf.greater_equal(step, spec.begin_pruning_step)
        # A negative end step keeps the pruning window open forever.
        not_finished = tf.logical_or(
            tf.less_equal(step, spec.end_pruning_step),
            tf.less(spec.end_pruning_step, 0))
        in_window = tf.logical_and(started, not_finished)

        next_update_step = tf.add(self._last_update_step,
                                  spec.pruning_frequency)
        due = tf.less_equal(next_update_step, step)
        return tf.logical_and(in_window, due)
def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
    """Computes the RPN score (objectness) loss.

    Each entry of ``score_targets`` is one of:
      * 1  -> positive anchor,
      * 0  -> negative anchor,
      * -1 -> "don't care"; it gets zero weight and does not contribute.

    Args:
      score_outputs: logits fed to the sigmoid cross-entropy.
      score_targets: target labels with values in {1, 0, -1}.
      normalizer: the summed loss is divided by this value.

    Returns:
      Sum of the sigmoid cross-entropy over non-ignored anchors, divided by
      ``normalizer``.
    """
    with tf.name_scope('rpn_score_loss'):
        is_positive = tf.equal(score_targets, 1)
        is_negative = tf.equal(score_targets, 0)
        valid = tf.logical_or(is_positive, is_negative)
        # Map the ignore label (-1) to 0 so every label is a valid binary
        # target; ignored anchors are zero-weighted anyway.
        labels = tf.maximum(score_targets, tf.zeros_like(score_targets))
        summed = tf.losses.sigmoid_cross_entropy(
            labels,
            score_outputs,
            weights=valid,
            reduction=tf.losses.Reduction.SUM)
        return summed / normalizer
def while_body(t, z, accept):
    """One iteration of truncated rejection sampling.

    Draws ``num_samples`` fresh proposals. Each proposal is accepted when a
    uniform draw is <= ``exp(accept_fn(proposal - data_mean))``, squeezed on
    the last axis. On the final allowed iteration (``t >= T - 1``) every
    proposal is force-accepted. A sample that was already accepted in an
    earlier iteration keeps its current value.

    Args:
      t: iteration counter.
      z: current samples.
      accept: boolean vector, True where a sample has been accepted before.

    Returns:
      ``(t + 1, updated samples, updated acceptance vector)``.
    """
    proposal = self.proposal.sample(num_samples)
    accept_score = self.accept_fn(proposal - self.data_mean)
    p_accept = tf.squeeze(tf.exp(accept_score), axis=-1)
    uniform = tf.random_uniform(shape=[num_samples], minval=0., maxval=1.)
    accepted_now = tf.math.less_equal(uniform, p_accept)

    # Truncation: on the last iteration everything is accepted.
    is_last_try = tf.math.greater_equal(
        tf.to_float(t), tf.to_float(self.T) - 1.)
    accepted_now = tf.math.logical_or(accepted_now, is_last_try)

    ever_accepted = tf.logical_or(accept, accepted_now)
    # Only overwrite samples that are accepted for the first time.
    first_time = tf.math.logical_and(tf.math.logical_not(accept), accepted_now)
    z = tf.where(first_time, proposal, z)
    return t + 1, z, ever_accepted
def maybe_split_sequence_lengths(sequence_length, num_splits, total_length):
    """Validates and splits `sequence_length`, if necessary.

    Returned value must be used in graph for all validations to be executed.

    Args:
      sequence_length: A batch of sequence lengths, either sized `[batch_size]`
        and equal to either 0 or `total_length`, or sized
        `[batch_size, num_splits]`.
      num_splits: The scalar number of splits of the full sequences.
      total_length: The scalar total sequence length (potentially padded).

    Returns:
      sequence_length: If input shape was `[batch_size, num_splits]`, returns
        the same Tensor. Otherwise, returns a Tensor of that shape with each
        input length in the batch divided by `num_splits`.

    Raises:
      ValueError: If `sequence_length` is not shaped `[batch_size]` or
        `[batch_size, num_splits]`, or if it is shaped `[batch_size]` and
        `total_length` is not evenly divisible by `num_splits`.
      tf.errors.InvalidArgumentError: If `sequence_length` is shaped
        `[batch_size]` and all values are not either 0 or `total_length`, or
        if it is shaped `[batch_size, num_splits]` and some value exceeds
        `total_length // num_splits`.
    """
    if sequence_length.shape.ndims == 1:
        if total_length % num_splits != 0:
            raise ValueError(
                '`total_length` must be evenly divisible by `num_splits`.')
        with tf.control_dependencies([
                tf.Assert(
                    tf.reduce_all(
                        tf.logical_or(tf.equal(sequence_length, 0),
                                      tf.equal(sequence_length, total_length))),
                    data=[sequence_length])
        ]):
            # Every full-length sequence is split into `num_splits` equal
            # segments; empty sequences stay 0 in every segment.
            sequence_length = (
                tf.tile(tf.expand_dims(sequence_length, axis=1),
                        [1, num_splits]) // num_splits)
    elif sequence_length.shape.ndims == 2:
        # Build the bound in the same dtype as `sequence_length`. A hard-coded
        # int32 constant would make tf.assert_less_equal fail on a dtype
        # mismatch for int64 lengths, which the 1-D branch accepts.
        max_segment_length = tf.constant(total_length // num_splits,
                                         dtype=sequence_length.dtype)
        with tf.control_dependencies([
                tf.assert_less_equal(
                    sequence_length,
                    max_segment_length,
                    message='Segment length cannot be more than '
                            '`total_length / num_splits`.')
        ]):
            sequence_length = tf.identity(sequence_length)
        sequence_length.set_shape([sequence_length.shape[0], num_splits])
    else:
        raise ValueError(
            'Sequence lengths must be given as a vector or a 2D Tensor whose '
            'second dimension size matches its initial hierarchical split. Got '
            'shape: %s' % sequence_length.shape.as_list())
    return sequence_length
def _distorted_crop_window(image_shape,
                           min_object_covered=0.1,
                           aspect_ratio_range=(0.75, 1.33),
                           area_range=(0.08, 1.0),
                           max_attempts=100):
    """Computes a sampled distorted crop window from an input image shape.

    The whole image is the reference bounding box for
    `tf.image.sample_distorted_bounding_box`. Sampling can fail, in which case
    the op returns the full image. That case is detected and replaced by the
    window produced by `_center_crop_window(image_shape, crop_frac=1.)`.

    Args:
      image_shape: The shape of the image, expressed as a Tensor of shape [3],
        an iterable of length 3, or a tf.Shape with rank 3.
      min_object_covered: See `tf.image.sample_distorted_bounding_box`.
      aspect_ratio_range: See `tf.image.sample_distorted_bounding_box`.
      area_range: See `tf.image.sample_distorted_bounding_box`.
      max_attempts: See `tf.image.sample_distorted_bounding_box`.

    Returns:
      A Tensor of shape [6] describing the crop box as
      [offset_height, offset_width, offset_channel, crop_height, crop_width,
      channels]. `offset_channel` is always 0.
    """
    with tf.name_scope('distorted_crop_window'):
        # No boxes plus use_image_if_no_bounding_boxes=True means the reference
        # box is the entire image.
        box_begin, box_size, _ = tf.image.sample_distorted_bounding_box(
            image_shape,
            bounding_boxes=tf.zeros(shape=[1, 0, 4]),
            min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range,
            area_range=area_range,
            max_attempts=max_attempts,
            use_image_if_no_bounding_boxes=True)
        top, left, _ = tf.unstack(box_begin)
        crop_h, crop_w, _ = tf.unstack(box_size)
        num_channels = image_shape[2]

        # A window that spans the full input height and width signals a
        # failed sample; fall back to the central crop in that case.
        sampled = tf.logical_or(tf.not_equal(crop_h, image_shape[0]),
                                tf.not_equal(crop_w, image_shape[1]))

        def _sampled_window():
            return tf.stack([top, left, 0, crop_h, crop_w, num_channels])

        def _fallback_window():
            return _center_crop_window(image_shape, crop_frac=1.)

        return tf.cond(sampled, _sampled_window, _fallback_window)
def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
    """Builds the inference graph: mask the input, inpaint it, composite.

    Args:
      FLAGS: config object. This block reads `guided`, `edge_threshold`,
        `padding` and `viz_max_out`, and also passes FLAGS to `bbox2mask` and
        `brush_stroke_mask`.
      batch_data: image batch, or an `(images, edges)` pair when
        `FLAGS.guided` is set. Images are mapped with `x / 127.5 - 1.`, which
        presumes pixel values in [0, 255] (TODO confirm). The
        `[:, :, :, 0:1]` slice and `axis=3` concat imply rank-4,
        channels-last tensors.
      bbox: box passed to `bbox2mask` to build the rectangular hole mask.
      name: prefix for the image summary tag.

    Returns:
      `batch_complete`: the network output (`x2`, presumably the refined
      second-stage result) inside the mask, and the normalized input outside
      it.
    """
    if FLAGS.guided:
        batch_data, edge = batch_data
        # Keep only the first edge channel, scale to [0, 1] and binarize.
        edge = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
    # The hole mask is the union of a box mask and a free-form stroke mask.
    regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
    irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
    mask = tf.cast(
        tf.logical_or(
            tf.cast(irregular_mask, tf.bool),
            tf.cast(regular_mask, tf.bool),
        ),
        tf.float32
    )

    batch_pos = batch_data / 127.5 - 1.
    # Zero out the pixels inside the hole.
    batch_incomplete = batch_pos*(1.-mask)
    if FLAGS.guided:
        # Edges are only provided inside the hole, as an extra input channel.
        edge = edge * mask
        xin = tf.concat([batch_incomplete, edge], axis=3)
    else:
        xin = batch_incomplete
    # inpaint (reuse=True: shares weights with a previously built network)
    x1, x2, offset_flow = self.build_inpaint_net(
        xin, mask, reuse=True,
        training=False,
        padding=FLAGS.padding)
    batch_predicted = x2
    # apply mask and reconstruct: prediction inside the hole, input outside
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
    # global image visualization (images concatenated along width, axis 2)
    if FLAGS.guided:
        viz_img = [
            batch_pos,
            batch_incomplete + edge,
            batch_complete]
    else:
        viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.compat.v1.image.resize_bilinear))
    images_summary(
        tf.concat(viz_img, axis=2),
        name+'_raw_incomplete_complete', FLAGS.viz_max_out)
    return batch_complete
def image_corruption(image, label, reso, image_corrupt_ratio_mean,
                     image_corrupt_ratio_stddev):
    """Randomly drop non-lesion pixels.

    When both ratio parameters are (effectively) zero, the image is returned
    unchanged. Otherwise a random keep-mask is drawn with `_gen_rand_mask`,
    using keep ratio `1 - image_corrupt_ratio_mean` and scale `reso // 3`.
    Lesion pixels (`label > 1.5`) are always kept, and dropped pixels are set
    to 0.
    """
    eps = 0.000001
    if image_corrupt_ratio_mean < eps and image_corrupt_ratio_stddev < eps:
        return image

    keep_ratio = 1 - image_corrupt_ratio_mean
    keep_mask = _gen_rand_mask(keep_ratio, image_corrupt_ratio_stddev,
                               reso // 3, image.shape)
    # Never corrupt lesion pixels.
    is_lesion = tf.greater(label, 1.5)
    keep_mask = tf.logical_or(is_lesion, keep_mask)
    image *= tf.cast(keep_mask, tf.float32)
    return image
def loss(self, inputs):
    """L2 loss on velocity.

    The model is trained to predict the normalized per-step velocity change
    (`target|velocity - velocity`, passed through `self._output_normalizer`).
    Only NORMAL and OUTFLOW nodes contribute, and the loss is the mean over
    those nodes of the squared error summed over axis 1.
    """
    graph = self._build_graph(inputs, is_training=True)
    prediction = self._learned_model(graph)

    velocity_delta = inputs['target|velocity'] - inputs['velocity']
    target = self._output_normalizer(velocity_delta)

    node_type = inputs['node_type'][:, 0]
    is_normal = tf.equal(node_type, common.NodeType.NORMAL)
    is_outflow = tf.equal(node_type, common.NodeType.OUTFLOW)
    loss_mask = tf.logical_or(is_normal, is_outflow)

    squared_error = tf.reduce_sum((target - prediction)**2, axis=1)
    return tf.reduce_mean(tf.boolean_mask(squared_error, loss_mask))
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    """Clip `input_mat` with `clipoutNeg` and optionally report a bad ratio.

    Args:
      input_mat: tensor of (eigen)values to clip.
      var: object whose `.name` is included in the debug message.
      threshold: passed to `clipoutNeg`.
      name: label included in the debug message.
      debug: when True, a `tf.Print` fires if max/min is neither > 0 nor
        < -500, i.e. lies in [-500, 0] or is NaN.

    Returns:
      The clipped tensor.
    """
    smallest = tf.reduce_min(input_mat)
    largest = tf.reduce_max(input_mat)
    ratio = largest / smallest
    clipped = clipoutNeg(input_mat, threshold)
    if not debug:
        return clipped

    ratio_ok = tf.logical_or(tf.greater(ratio, 0.), tf.less(ratio, -500))

    def _report():
        message = [
            tf.convert_to_tensor('screwed ratio ' + name + ' eigen values!!!'),
            tf.convert_to_tensor(var.name),
            smallest,
            largest,
            ratio,
        ]
        return tf.Print(clipped, message)

    return tf.cond(ratio_ok, lambda: clipped, _report)
def _rollout(model, initial_state, num_steps):
    """Rolls out a model trajectory.

    Only NORMAL and OUTFLOW nodes take the model's prediction. Every other
    node keeps its current velocity.

    Returns:
      A tensor stacking `num_steps` velocities. Entry `i` is the velocity
      before the i-th model update, so entry 0 is the initial velocity.
    """
    node_type = initial_state['node_type'][:, 0]
    updatable = tf.logical_or(tf.equal(node_type, NodeType.NORMAL),
                              tf.equal(node_type, NodeType.OUTFLOW))

    def _body(step, velocity, trajectory):
        predicted = model({**initial_state, 'velocity': velocity})
        # Boundary nodes keep their velocity.
        advanced = tf.where(updatable, predicted, velocity)
        trajectory = trajectory.write(step, velocity)
        return step + 1, advanced, trajectory

    def _keep_going(step, unused_velocity, unused_trajectory):
        return tf.less(step, num_steps)

    start = (0, initial_state['velocity'],
             tf.TensorArray(tf.float32, num_steps))
    _, _, trajectory = tf.while_loop(
        cond=_keep_going,
        body=_body,
        loop_vars=start,
        parallel_iterations=1)
    return trajectory.stack()
def _sequence_correct(labels: decode_utils.LabelsDict,
                      predictions: decode_utils.PredictionsDict):
    """Computes a per-example sequence accuracy.

    Returns a float32 tensor holding 1.0 for examples where every decode step
    within `target_len - 1` matches the prediction, and 0.0 otherwise.
    Positions beyond that length are ignored.
    """
    targets = decode_utils.decode_steps_from_labels(
        labels, trim_start_symbol=True)
    predicted = decode_utils.decode_steps_from_predictions(predictions)
    decode_utils.assert_shapes_match(targets, predicted)
    token_match = decode_utils.compare_decode_steps(targets, predicted)

    # Presumably `target_len` counts the start symbol trimmed above.
    valid_len = labels["target_len"] - 1
    in_sequence = tf.sequence_mask(
        lengths=tf.to_int32(valid_len),
        maxlen=tf.to_int32(tf.shape(token_match)[1]))
    # Padded positions count as matches so they never fail a sequence.
    token_match = tf.logical_or(token_match, tf.logical_not(in_sequence))
    return tf.cast(tf.reduce_all(token_match, 1), tf.float32)
def zero_out_clipped_grads(grad, x, clip_min, clip_max):
    """
    Zero the gradient entries whose update would be undone by clipping.

    An entry is zeroed when `x` already sits at or beyond a bound of
    [clip_min, clip_max] and the sign of the gradient points further out:
    at the lower bound with a negative gradient, or at the upper bound with a
    positive gradient.

    :param grad: The gradient
    :param x: The current input
    :param clip_min: Minimum input component value (cast to `x.dtype`)
    :param clip_max: Maximum input component value (cast to `x.dtype`)
    :return: `grad` with the affected entries replaced by `mul(grad, 0)`
    """
    direction = tf.sign(grad)

    at_lower = tf.less_equal(x, tf.cast(clip_min, x.dtype))
    pushes_down = tf.less(direction, 0)
    stuck_low = tf.logical_and(at_lower, pushes_down)

    at_upper = tf.greater_equal(x, tf.cast(clip_max, x.dtype))
    pushes_up = tf.greater(direction, 0)
    stuck_high = tf.logical_and(at_upper, pushes_up)

    stuck = tf.logical_or(stuck_low, stuck_high)
    return tf.where(stuck, mul(grad, 0), grad)
def compare_generating_steps(target_decode_steps, predicted_decode_steps):
    """Compare generating steps only but ignoring target copying steps.

    Args:
      target_decode_steps: Target DecodeSteps, Each tensor is expected to be
        shape [batch_size, output_length].
      predicted_decode_steps: Predicted DecodeSteps, Each tensor is expected
        to be shape [batch_size, output_length].

    Returns:
      A boolean tensor. Where the target action type is not
      `constants.GENERATE_ACTION` the value is True. Elsewhere it is True only
      if both the action type and the action id of the prediction match the
      target.
    """
    target_types = target_decode_steps.action_types
    not_generating = tf.not_equal(target_types, constants.GENERATE_ACTION)
    same_type = tf.equal(target_types, predicted_decode_steps.action_types)
    same_id = tf.equal(target_decode_steps.action_ids,
                       predicted_decode_steps.action_ids)
    return tf.logical_or(not_generating, tf.logical_and(same_type, same_id))
def _update(self, rs, ps):
    """Builds the ops for one conjugate-gradient step in the trust region.

    Args:
      rs: residual variables, one per entry of `self._dws`; each is updated
        in place with `r -= alpha * Hz`.
      ps: current search directions, one per entry of `self._dws`.

    Returns:
      A `tf.group` of all assignments made by this step. If `p` has
      non-positive curvature or a full step would leave the trust region
      (`||dw + alpha * p|| > self._radius_placeh`), then `dw` and `r` are
      left unchanged (alpha is replaced by zero). In that case
      `self._follow_to_boundary` is set to True, and the final `dw` is
      determined later in the 'solve' method.
    """
    ops = []

    # Curvature along the search direction: p^T H p.
    pTHp = tf.zeros(shape=[], dtype=ps[0].dtype)
    for p, Hz in zip(ps, self._hessians):
        # Recall that p has already been assigned to z, and hence Hz = Hp.
        pTHp += tf.reduce_sum(p * Hz)

    # Compute the coefficient for the update.
    alpha = self._rTr / pTHp

    # Norm of the iterate after the update, computed without modifying dw.
    norm_dw_new = tf.zeros(shape=[], dtype=self._norm_dw.dtype)
    for dw, p in zip(self._dws, ps):
        dw_new = dw + alpha * p
        norm_dw_new += tf.reduce_sum(dw_new * dw_new)
    norm_dw_new = tf.sqrt(norm_dw_new)

    # Follow p to the trust-region boundary if p is a direction of
    # indefiniteness (non-positive curvature) or if dw + alpha * p would lie
    # outside the trust region.
    follow_to_boundary = tf.logical_or(pTHp <= 0.0,
                                       norm_dw_new > self._radius_placeh)
    # NOTE(review): a new variable is created on every call to _update.
    self._follow_to_boundary = tf.Variable(False)
    ops.append(tf.assign(self._follow_to_boundary, follow_to_boundary))

    # If we follow p up to the boundary, we do not update dw here. Use
    # zeros_like so both branches share alpha's dtype; a bare Python 0.0
    # becomes float32 and tf.cond rejects mismatched branch dtypes when the
    # solver runs in float64.
    alpha_or_zero = tf.cond(follow_to_boundary,
                            lambda: tf.zeros_like(alpha),
                            lambda: alpha)

    # Update the solution and residual.
    for dw, r, p, Hz in zip(self._dws, rs, ps, self._hessians):
        ops.append(tf.assign_add(dw, alpha_or_zero * p))
        ops.append(tf.assign_sub(r, alpha_or_zero * Hz))

    return tf.group(ops)
def maybe_update_alpha():
    """Operator to update alpha.

    Returns a boolean tensor that is True when all of these hold (the global
    step is cast to int32 for every comparison):
      * global_step >= begin_compression_step;
      * global_step <= end_compression_step, or end_compression_step < 0
        (a negative end step keeps the window open forever);
      * last_alpha_update_step + compression_frequency <= global_step.
    """
    spec = self._spec
    with tf.compat.v1.name_scope(spec.name):
        after_start = tf.greater_equal(
            tf.cast(self._global_step, tf.int32),
            spec.begin_compression_step)
        before_end = tf.logical_or(
            tf.less_equal(tf.cast(self._global_step, tf.int32),
                          spec.end_compression_step),
            tf.less(spec.end_compression_step, 0))
        in_window = tf.logical_and(after_start, before_end)

        next_update_step = tf.add(self._last_alpha_update_step,
                                  spec.compression_frequency)
        due = tf.less_equal(next_update_step,
                            tf.cast(self._global_step, tf.int32))
        return tf.logical_and(in_window, due)
def _online_sample_masks(inputs, tgt_len, num_predict, boundary=None, stride=1):
    """Sample target positions to predict.

    The strategy comes from `FLAGS.sample_strategy`: "single_token",
    "whole_word", "token_span" or "word_span". Any other value raises
    NotImplementedError. For the span/word strategies, the result is topped
    up with `_single_token_mask` over the remaining
    `num_predict - <number already masked>` budget.

    Returns:
      `(is_target, target_mask)` as produced by the chosen mask helper(s).
    """
    strategy = FLAGS.sample_strategy
    tf.logging.info("Online sample with strategy: `%s`.", strategy)

    if strategy == "single_token":
        return _single_token_mask(inputs, tgt_len, num_predict)

    if strategy == "whole_word":
        assert boundary is not None, "whole word sampling requires `boundary`"
        is_target, target_mask = _whole_word_mask(
            inputs, tgt_len, num_predict, boundary)
    elif strategy == "token_span":
        is_target, target_mask = _token_span_mask(
            inputs, tgt_len, num_predict, stride=stride)
    elif strategy == "word_span":
        assert boundary is not None, "word span sampling requires `boundary`"
        is_target, target_mask = _word_span_mask(
            inputs, tgt_len, num_predict, boundary, stride=stride)
    else:
        raise NotImplementedError

    # Fill in single tokens if the budget is not used up.
    num_masked = tf.reduce_sum(tf.cast(is_target, tf.int64))
    extra_is_target, extra_target_mask = _single_token_mask(
        inputs, tgt_len, num_predict - num_masked, is_target)
    merged_is_target = tf.logical_or(is_target, extra_is_target)
    return merged_is_target, target_mask + extra_target_mask
def _build(self, x, state):
    """Applies dropout whose keep mask is partly carried over between calls.

    Args:
      x: input tensor.
      state: the keep mask from the previous call (presumably shaped like
        `x`).

    Returns:
      `x * keep_mask / keep_prob * scaler`. `keep_mask` is also stored on
      `self._keep_mask`; per element, it takes either the previous mask or a
      freshly sampled one. `self._time_step` is incremented.
    """
    previous_mask = state
    x_shape = tf.shape(x)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    fresh_mask = tf.floor(
        self._keep_prob + tf.random_uniform(x_shape, dtype=x.dtype))
    # Each element switches to the fresh mask with probability flip_prob.
    flip = tf.less(tf.random_uniform(x_shape, dtype=x.dtype), self._flip_prob)

    # KLUDGE(melisgl): The client has to pass the last keep_mask from
    # a batch to the next so the mask may end up next to some
    # recurrent cell state. This state is often zero at the beginning
    # and may be periodically zeroed (per example) during training.
    # While zeroing LSTM state is okay, zeroing the dropout mask is
    # not. So instead of forcing every client to deal with this common
    # (?) case, if an all zero mask is detected, then regenerate a
    # fresh mask. This is of course a major hack and won't help with
    # learnt initial states, for example.
    mask_total = tf.reduce_sum(previous_mask, 1, keepdims=True)
    all_zero = tf.equal(mask_total, 0.0)
    use_fresh = tf.logical_or(flip, all_zero)
    self._keep_mask = tf.where(use_fresh, fresh_mask, previous_mask)
    self._time_step += 1
    return x * self._keep_mask / self._keep_prob * self._scaler
def _define_collect(batch_env, ppo_hparams, scope, frame_stack_size, eval_phase,
                    sampling_temp, force_beginning_resets, distributional_size=1):
    """Collect trajectories.

    Builds the graph that steps `batch_env` and stores every step in
    non-trainable `collect_memory_*` variables. The loop runs for
    `epoch_length` steps or, in the eval phase, until the number of finished
    episodes reaches `num_agents`.

    Args:
      batch_env: Batch environment. It is wrapped in `StackWrapper` and then
        `_MemoryWrapper` before use.
      ppo_hparams: PPO hparams, defined in tensor2tensor.models.research.rl.
        This function reads `epoch_length` and `effective_num_agents`.
      scope: var scope.
      frame_stack_size: Number of last observations to feed into the policy
        (passed to `StackWrapper` as `history`).
      eval_phase: Python bool, also converted to a tensor. When true, the loop
        stops once `num_agents` episodes have finished instead of after
        `epoch_length` steps, and the `effective_num_agents` reshaping is
        skipped.
      sampling_temp: Sampling temperature for the policy.
      force_beginning_resets: Whether to reset all environments at the start
        of every collection run. When False they are reset only on the first
        run (guarded by `should_reset_var`).
      distributional_size: optional, number of buckets in distributional RL.

    Returns:
      A tuple `(memory, summaries, initialization_lambda)`:
        memory: list of tensors (observations, rewards, dones, actions, pdfs,
          value functions) read from the collect-memory variables, possibly
          reshaped for `effective_num_agents`.
        summaries: merged summary with the mean score and the number of
          finished episodes for this iteration.
        initialization_lambda: callable taking a session; it calls
          `initialize(sess)` on the original and every wrapped environment.
    """
    epoch_length = ppo_hparams.epoch_length

    to_initialize = []
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        num_agents = batch_env.batch_size

        to_initialize.append(batch_env)
        wrappers = [(StackWrapper, {
            "history": frame_stack_size
        }), (_MemoryWrapper, {})]
        rollout_metadata = None
        speculum = None
        for w in wrappers:
            tf.logging.info("Applying wrapper %s(%s) to env %s." % (str(
                w[0]), str(w[1]), str(batch_env)))
            batch_env = w[0](batch_env, **w[1])
            to_initialize.append(batch_env)

        rollout_metadata = _rollout_metadata(batch_env, distributional_size)
        # The speculum is the queue filled by _MemoryWrapper with
        # (obs, reward, done, action) tuples; it is dequeued in `step` below.
        speculum = batch_env.speculum

        def initialization_lambda(sess):
            for batch_env in to_initialize:
                batch_env.initialize(sess)

        # One [epoch_length] + shape variable per rollout_metadata entry.
        memory = [
            tf.get_variable(  # pylint: disable=g-complex-comprehension
                "collect_memory_%d_%s" % (epoch_length, name),
                shape=[epoch_length] + shape,
                dtype=dtype,
                initializer=tf.zeros_initializer(),
                trainable=False) for (shape, dtype, name) in rollout_metadata
        ]

        # Running per-agent return of the episode currently in progress.
        cumulative_rewards = tf.get_variable(
            "cumulative_rewards", len(batch_env), trainable=False)

        eval_phase_t = tf.convert_to_tensor(eval_phase)
        # True until the first run has reset all environments.
        should_reset_var = tf.Variable(True, trainable=False)
        zeros_tensor = tf.zeros(len(batch_env))

    force_beginning_resets = tf.convert_to_tensor(force_beginning_resets)

    def reset_ops_group():
        return tf.group(
            batch_env.reset(tf.range(len(batch_env))),
            tf.assign(cumulative_rewards, zeros_tensor))

    # Reset every environment on the first run, or on every run when
    # force_beginning_resets is set.
    reset_op = tf.cond(
        tf.logical_or(should_reset_var.read_value(), force_beginning_resets),
        reset_ops_group, tf.no_op)

    with tf.control_dependencies([reset_op]):
        reset_once_op = tf.assign(should_reset_var, False)

    with tf.control_dependencies([reset_once_op]):

        def step(index, scores_sum, scores_num):
            """Single step."""
            index %= epoch_length  # Only needed in eval runs.
            # Note - the only way to ensure making a copy of tensor is to run
            # simple operation. We are waiting for tf.copy:
            # https://github.com/tensorflow/tensorflow/issues/11186
            obs_copy = batch_env.observ + 0
            value_fun_shape = (num_agents,)
            if distributional_size > 1:
                value_fun_shape = (num_agents, distributional_size)

            def env_step(arg1, arg2, arg3):  # pylint: disable=unused-argument
                """Step of the environment."""

                (logits, value_function) = get_policy(
                    obs_copy, ppo_hparams, batch_env.action_space,
                    distributional_size)
                action = common_layers.sample_with_temperature(
                    logits, sampling_temp)
                action = tf.cast(action, tf.int32)
                action = tf.reshape(action, shape=(num_agents,))

                reward, done = batch_env.simulate(action)

                # Probability of the sampled action under the policy.
                pdf = tfp.distributions.Categorical(logits=logits).prob(action)
                pdf = tf.reshape(pdf, shape=(num_agents,))
                value_function = tf.reshape(value_function,
                                            shape=value_fun_shape)
                done = tf.reshape(done, shape=(num_agents,))

                with tf.control_dependencies([reward, done]):
                    return tf.identity(pdf), tf.identity(value_function), \
                        tf.identity(done)

            # Step the environment only while the speculum queue is empty.
            # TODO(piotrmilos): while_body is executed at most once,
            # thus should be replaced with tf.cond
            pdf, value_function, top_level_done = tf.while_loop(
                lambda _1, _2, _3: tf.equal(speculum.size(), 0),
                env_step,
                [
                    tf.constant(0.0, shape=(num_agents,)),
                    tf.constant(0.0, shape=value_fun_shape),
                    tf.constant(False, shape=(num_agents,))
                ],
                parallel_iterations=1,
                back_prop=False,
            )

            with tf.control_dependencies([pdf, value_function]):
                obs, reward, done, action = speculum.dequeue()

                # Order must match rollout_metadata / `memory`.
                to_save = [obs, reward, done, action, pdf, value_function]
                save_ops = [
                    tf.scatter_update(memory_slot, index, value)
                    for memory_slot, value in zip(memory, to_save)
                ]
                cumulate_rewards_op = cumulative_rewards.assign_add(reward)

                agent_indices_to_reset = tf.where(top_level_done)[:, 0]
            with tf.control_dependencies([cumulate_rewards_op]):
                # TODO(piotrmilos): possibly we need cumulative_rewards.read_value()
                # Returns of the episodes that just finished.
                scores_sum_delta = tf.reduce_sum(
                    tf.gather(cumulative_rewards.read_value(),
                              agent_indices_to_reset))
                scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
            with tf.control_dependencies(save_ops +
                                         [scores_sum_delta, scores_num_delta]):
                reset_env_op = batch_env.reset(agent_indices_to_reset)
                reset_cumulative_rewards_op = tf.scatter_update(
                    cumulative_rewards, agent_indices_to_reset,
                    tf.gather(zeros_tensor, agent_indices_to_reset))
            with tf.control_dependencies(
                    [reset_env_op, reset_cumulative_rewards_op]):
                return [
                    index + 1, scores_sum + scores_sum_delta,
                    scores_num + scores_num_delta
                ]

        def stop_condition(i, _, resets):
            # Eval: run until num_agents episodes have finished.
            # Train: run for exactly epoch_length steps.
            return tf.cond(eval_phase_t, lambda: resets < num_agents,
                           lambda: i < epoch_length)

        init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
        index, scores_sum, scores_num = tf.while_loop(
            stop_condition,
            step,
            init,
            parallel_iterations=1,
            back_prop=False)

    # We handle force_beginning_resets differently. We assume that all envs
    # are reset at the end of the episode (though the reset actually happens
    # at the beginning of the next run), so the in-progress episodes are
    # counted as finished here.
    scores_num = tf.cond(force_beginning_resets,
                         lambda: scores_num + len(batch_env),
                         lambda: scores_num)

    with tf.control_dependencies([scores_sum]):
        scores_sum = tf.cond(
            force_beginning_resets,
            lambda: scores_sum + tf.reduce_sum(
                cumulative_rewards.read_value()), lambda: scores_sum)

    mean_score = tf.cond(tf.greater(scores_num, 0),
                         lambda: scores_sum / tf.cast(scores_num, tf.float32),
                         lambda: 0.)
    printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
    with tf.control_dependencies([index, printing]):
        memory = [mem.read_value() for mem in memory]
        # When generating real data together with PPO training we must use
        # single agent. For PPO to work we reshape the history, as if it was
        # generated by real_ppo_effective_num_agents.
        if ppo_hparams.effective_num_agents is not None and not eval_phase:
            new_memory = []
            effective_num_agents = ppo_hparams.effective_num_agents
            # NOTE(review): the two literals below concatenate without a
            # space ("...amongsteffective_num_agents...").
            assert epoch_length % ppo_hparams.effective_num_agents == 0, (
                "The rollout of ppo_hparams.epoch_length will be distributed amongst"
                "effective_num_agents of agents")
            new_epoch_length = int(epoch_length / effective_num_agents)
            for mem, info in zip(memory, rollout_metadata):
                shape, _, name = info
                new_shape = [effective_num_agents, new_epoch_length
                            ] + shape[1:]
                # Swap the first two axes, split the time axis into
                # (effective_num_agents, new_epoch_length), then swap back so
                # time is first again.
                perm = list(range(len(shape) + 1))
                perm[0] = 1
                perm[1] = 0
                mem = tf.transpose(mem, perm=perm)
                mem = tf.reshape(mem, shape=new_shape)
                mem = tf.transpose(mem, perm=perm,
                                   name="collect_memory_%d_%s"
                                   % (new_epoch_length, name))
                new_memory.append(mem)
            memory = new_memory

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            # `str` (returning "") is the false branch, so no mean-score
            # summary is emitted when no episode finished.
            mean_score_summary = tf.cond(
                tf.greater(scores_num, 0),
                lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
                str)
            summaries = tf.summary.merge([
                mean_score_summary,
                tf.summary.scalar("episodes_finished_this_iter", scores_num)
            ])
            return memory, summaries, initialization_lambda
def maybe_gen_fake_data_based_on_real_data(image, label, reso, min_fake_lesion_ratio, gen_fake_probability): """Remove real lesion and synthesize lesion.""" # TODO(lehou): Replace magic numbers with flag variables. gen_prob_indicator = tf.random_uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32) background_mask = tf.less(label, 0.5) lesion_mask = tf.greater(label, 1.5) liver_mask = tf.logical_not(tf.logical_or(background_mask, lesion_mask)) liver_intensity = tf.boolean_mask(image, liver_mask) lesion_intensity = tf.boolean_mask(image, lesion_mask) intensity_diff = tf.reduce_mean(liver_intensity) - ( tf.reduce_mean(lesion_intensity)) intensity_diff *= 1.15 intensity_diff = tf.cond(tf.is_nan(intensity_diff), lambda: 0.0, lambda: intensity_diff) lesion_liver_ratio = 0.0 lesion_liver_ratio += tf.random.normal(shape=[], mean=0.01, stddev=0.01) lesion_liver_ratio += tf.random.normal(shape=[], mean=0.0, stddev=0.05) lesion_liver_ratio = tf.clip_by_value(lesion_liver_ratio, min_fake_lesion_ratio, min_fake_lesion_ratio + 0.20) fake_lesion_mask = tf.logical_and( _gen_rand_mask(ratio_mean=lesion_liver_ratio, ratio_stddev=0.0, scale=reso // 32, shape=label.shape, smoothness=reso // 32), tf.logical_not(background_mask)) liver_mask = tf.logical_not( tf.logical_or(background_mask, fake_lesion_mask)) # Blur the masks lesion_mask_blur = tf.squeeze( tf.nn.conv3d(tf.expand_dims( tf.expand_dims(tf.cast(lesion_mask, tf.float32), -1), 0), filter=tf.ones([reso // 32] * 3 + [1, 1], tf.float32) / (reso // 32)**3, strides=[1, 1, 1, 1, 1], padding='SAME')) fake_lesion_mask_blur = tf.squeeze( tf.nn.conv3d(tf.expand_dims( tf.expand_dims(tf.cast(fake_lesion_mask, tf.float32), -1), 0), filter=tf.ones([reso // 32] * 3 + [1, 1], tf.float32) / (reso // 32)**3, strides=[1, 1, 1, 1, 1], padding='SAME')) # Remove real lesion and add fake lesion. # If the intensitify is too small (maybe no liver or lesion region labeled), # do not generate fake data. 
gen_prob_indicator = tf.cond(tf.greater(intensity_diff, 0.0001), lambda: gen_prob_indicator, lambda: 0.0) # pylint: disable=g-long-lambda image = tf.cond( tf.greater(gen_prob_indicator, 1 - gen_fake_probability), lambda: image + intensity_diff * lesion_mask_blur \ - intensity_diff * fake_lesion_mask_blur, lambda: image) label = tf.cond( tf.greater(gen_prob_indicator, 1 - gen_fake_probability), lambda: tf.cast(background_mask, tf.float32) * 0 + \ tf.cast(liver_mask, tf.float32) * 1 + \ tf.cast(fake_lesion_mask, tf.float32) * 2, lambda: label) # pylint: enable=g-long-lambda return image, label
def has_nan(self):
    """Returns a boolean tensor that is True where `self.x` or `self.y` is NaN."""
    x_is_nan = tf.math.is_nan(self.x)
    y_is_nan = tf.math.is_nan(self.y)
    return tf.logical_or(x_is_nan, y_is_nan)
def logical_or(self, x, y):
    """Element-wise logical OR of `x` and `y`, delegated to `tf.logical_or`."""
    either = tf.logical_or(x, y)
    return either
def parse1_func(filename):
    """Reads one image file and returns a (possibly augmented) patch.

    Relies on names from the enclosing scope that are not visible here:
    `channels`, `is_training`, `is_testing`, `config`, `patch_height`,
    `patch_width` and `data_format`.

    Steps:
      1. Decode the file with `channels` channels.
      2. When `config.pre_down` is set, pick a factor `dscale` from the image
         size. This factor only enlarges the crop
         (`patch_* * dscale`); no resizing happens in this function.
      3. Reflect-pad the image up to the crop size if it is smaller.
      4. Crop: random crop plus random flips when training, center crop when
         testing. When neither flag is set, no crop is taken.
      5. Convert to float32 with `tf.image.convert_image_dtype` (for uint8
         input this rescales to [0, 1]), then optionally apply random color
         augmentation when training.
      6. Transpose to CHW when `data_format == 'NCHW'`.
    """
    # read data
    dtype = tf.float32
    image = tf.read_file(filename)
    image = tf.image.decode_image(image, channels=channels)
    shape = tf.shape(image)
    height = shape[-3]
    width = shape[-2]
    # pre down-scale for high resolution image
    dscale = 1
    if is_training and config.pre_down:
        '''
        if (width >= 3072 and height >= 1536) or (width >= 1536 and height >= 3072):
            dscale = 3
        elif (width >= 1024 and height >= 512) or (width >= 512 and height >= 1024):
            dscale = 2
        '''
        # Graph-mode version of the pseudocode above: True when the image is
        # at least const1 x const2 in either orientation.
        def c_t(const1, const2, true_fn, false_fn):
            return tf.cond(tf.logical_or(
                tf.logical_and(
                    tf.greater_equal(width, const1), tf.greater_equal(height, const2)
                ),
                tf.logical_and(
                    tf.greater_equal(width, const2), tf.greater_equal(height, const1)
                )
            ), true_fn, false_fn)
        dscale = c_t(3072, 1536, lambda: 3,
            lambda: c_t(1024, 512, lambda: 2,
                lambda: 1)
        )
    elif is_testing and config.pre_down:
        '''
        if (width >= 3072 and height >= 3072):
            dscale = 4
        elif (width >= 2048 and height >= 2048):
            dscale = 3
        elif (width >= 1024 and height >= 1024):
            dscale = 2
        '''
        # Graph-mode version of the pseudocode above: True when both sides
        # are at least const1.
        def c_t(const1, true_fn, false_fn):
            return tf.cond(tf.logical_and(
                tf.greater_equal(width, const1), tf.greater_equal(height, const1)
            ), true_fn, false_fn)
        dscale = c_t(3072, lambda: 4,
            lambda: c_t(2048, lambda: 3,
                lambda: c_t(1024, lambda: 2,
                    lambda: 1)
            )
        )
    # padding
    # cropped_* is a Python int when dscale == 1, otherwise an int32 tensor.
    cropped_height = patch_height * dscale
    cropped_width = patch_width * dscale
    '''
    if cropped_height > height or cropped_width > width:
        pad_height = cropped_height - height
        pad_width = cropped_width - width
        if pad_height > 0:
            pad_height = [pad_height // 2, pad_height - pad_height // 2]
            height = cropped_height
        else:
            pad_height = [0, 0]
        if pad_width > 0:
            pad_width = [pad_width // 2, pad_width - pad_width // 2]
            width = cropped_width
        else:
            pad_width = [0, 0]
        block = tf.pad(image, [pad_height, pad_width, [0, 0]], mode='REFLECT')
    else:
        block = image
    '''
    # Graph-mode version of the pseudocode above.
    cond_height = tf.greater(cropped_height, height)
    cond_width = tf.greater(cropped_width, width)
    def c_f1():
        # Split the missing rows/columns between both sides (extra one on the
        # bottom/right) and reflect-pad.
        def _1():
            ph = cropped_height - height
            return [ph // 2, ph - ph // 2]
        pad_height = tf.cond(cond_height, _1, lambda: [0, 0])
        def _2():
            pw = cropped_width - width
            return [pw // 2, pw - pw // 2]
        pad_width = tf.cond(cond_width, _2, lambda: [0, 0])
        return tf.pad(image, [pad_height, pad_width, [0, 0]], mode='REFLECT')
    block = tf.cond(tf.logical_or(cond_height, cond_width), c_f1, lambda: image)
    # Size of `block` after the optional padding.
    height = tf.maximum(cropped_height, height)
    width = tf.maximum(cropped_width, width)
    # cropping
    if is_training:
        block = tf.random_crop(block, [cropped_height, cropped_width, channels])
        block = tf.image.random_flip_up_down(block)
        block = tf.image.random_flip_left_right(block)
    elif is_testing:
        offset_height = (height - cropped_height) // 2
        offset_width = (width - cropped_width) // 2
        block = tf.image.crop_to_bounding_box(block, offset_height, offset_width,
            cropped_height, cropped_width)
    # convert dtype
    block = tf.image.convert_image_dtype(block, dtype, saturate=False)
    # random color augmentation
    if is_training and config.color_augmentation > 0:
        block = tf.image.random_saturation(block, 1 - config.color_augmentation, 1 + config.color_augmentation)
        block = tf.image.random_brightness(block, config.color_augmentation)
        block = tf.image.random_contrast(block, 1 - config.color_augmentation, 1 + config.color_augmentation)
    # data format conversion
    block.set_shape([None, None, channels])
    if data_format == 'NCHW':
        block = tf.transpose(block, (2, 0, 1))
    # return
    return block
def prepare_encoder_input(features, hparams, embed_scope=None,
                          embed_token_fn=common_embed.embed_tokens):
    """Prepares the input for the screen encoder.

    NOTE(review): when `hparams.synthetic_screen_noise > 0`, this function
    overwrites `features["obj_text"]` in the caller's dict with the noised
    tokens.

    Args:
      features: the feature dict. Ranks are asserted: `obj_text` 4,
        `obj_type` 3, `obj_clickable` 3, `obj_screen_pos` 4.
      hparams: the hyperparameter.
      embed_scope: the embedding variable scope.
      embed_token_fn: the function for embedding tokens.

    Returns:
      object_embedding: a Tensor of shape [batch_size, num_steps,
        max_object_count, embed_depth]; zeroed for padded objects.
      object_mask: a float32 0/1 tensor of shape [batch_size, num_steps,
        max_object_count]; 1 where `obj_type != -1`.
      nonpadding_bias: a Tensor of shape [batch_size, num_steps,
        max_object_count] equal to `(1 - object_mask)` times a large negative
        value, i.e. an additive attention bias masking padded objects.
    """
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_text"]), 4)]):

        if hparams.get("synthetic_screen_noise", 0.) > 0.:
            num_objects = tf.shape(features["obj_text"])[2]
            # [batch, length, num_objects]
            target_obj_mask = tf.cast(
                tf.one_hot(features["objects"], depth=num_objects), tf.bool)
            num_tokens = tf.shape(features["obj_text"])[-1]
            # Broadcast the per-object target flag to every token.
            target_obj_mask = tf.tile(
                tf.expand_dims(target_obj_mask, 3),
                [1, 1, 1, num_tokens])
            # Randomly keep tokens
            keep_mask = tf.greater_equal(
                tf.random_uniform(shape=tf.shape(features["obj_text"])),
                hparams.synthetic_screen_noise)
            # Keep paddings
            keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0),
                                      keep_mask)
            # Keep targets
            target_obj_mask = tf.logical_or(target_obj_mask, keep_mask)
            # Tokens not kept are replaced by random ids in [0, 50000).
            features["obj_text"] = tf.where(
                target_obj_mask, features["obj_text"],
                tf.random_uniform(shape=tf.shape(features["obj_text"]),
                                  maxval=50000, dtype=tf.int32))

        text_embeddings, _ = embed_token_fn(
            features["obj_text"],
            hparams.task_vocab_size,
            hparams.hidden_size, hparams,
            embed_scope=embed_scope)

        # Aggregate token embeddings into one embedding per object.
        with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE):
            if hparams.obj_text_aggregation == "max":
                # Token ids < 2 get a large negative bias so they never win
                # the max.
                embed_bias = tf.cast(tf.less(features["obj_text"], 2),
                                     tf.float32) * -1e7
                with tf.control_dependencies(
                    [tf.assert_equal(tf.rank(embed_bias), 4)]):
                    text_embeddings = tf.reduce_max(
                        text_embeddings + tf.expand_dims(embed_bias, 4), -2)
                    # Learned floor for objects without real text tokens.
                    no_txt_embed = tf.get_variable(
                        name="no_txt_embed", shape=[hparams.hidden_size])
                    shape = common_layers.shape_list(text_embeddings)
                    no_txt_embed = tf.tile(
                        tf.reshape(no_txt_embed, [1, 1, 1, hparams.hidden_size]),
                        [shape[0], shape[1], shape[2], 1])
                    text_embeddings = tf.maximum(text_embeddings, no_txt_embed)
            elif hparams.obj_text_aggregation == "sum":
                # [batch, step, #max_obj, #max_token] 0 for padded tokens
                real_objects = tf.cast(
                    tf.greater_equal(features["obj_text"], 2), tf.float32)
                # [batch, step, #max_obj, hidden] 0s for padded objects
                text_embeddings = tf.reduce_sum(
                    text_embeddings * tf.expand_dims(real_objects, 4), -2)
            elif hparams.obj_text_aggregation == "mean":
                shape_list = common_layers.shape_list(text_embeddings)
                embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:])
                # Tokens whose embedding is all zeros are treated as padding.
                emb_sum = tf.reduce_sum(tf.abs(embeddings), axis=-1)
                non_paddings = tf.not_equal(emb_sum, 0.0)
                embeddings = common_embed.average_bag_of_embeds(
                    embeddings, non_paddings, use_bigrams=True,
                    bigram_embed_scope=embed_scope, append_start_end=True)
                text_embeddings = tf.reshape(
                    embeddings, shape_list[:3] + [hparams.hidden_size])
            else:
                raise ValueError("Unrecognized token aggregation %s" % (
                    hparams.obj_text_aggregation))

    with tf.control_dependencies([
        tf.assert_equal(tf.rank(features["obj_type"]), 3),
        tf.assert_equal(tf.rank(features["obj_clickable"]), 3)]):
        with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE):
            # obj_type is clamped at 0 so the padding id -1 stays a valid
            # lookup index.
            type_embedding = tf.nn.embedding_lookup(
                params=tf.get_variable(
                    name="embed_type_w", shape=[hparams.get("num_types", 100),
                                                hparams.hidden_size]),
                ids=tf.maximum(features["obj_type"], 0))
            clickable_embedding = tf.nn.embedding_lookup(
                params=tf.get_variable(
                    name="embed_clickable_w", shape=[2, hparams.hidden_size]),
                ids=features["obj_clickable"])

    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]):
        def _create_embed(feature_name, vocab_size, depth):
            """Embed a position feature."""
            # One embedding table per component of the last axis; the
            # component embeddings are summed.
            pos_embedding_list = []
            with tf.variable_scope("encode_object_" + feature_name,
                                   reuse=tf.AUTO_REUSE):
                num_featues = common_layers.shape_list(
                    features[feature_name])[-1]
                for i in range(num_featues):
                    pos_embedding_list.append(
                        tf.nn.embedding_lookup(
                            params=tf.get_variable(
                                name=feature_name + "_embed_w_%d" % i,
                                shape=[vocab_size, depth]),
                            ids=features[feature_name][:, :, :, i]))
                pos_embedding = tf.add_n(pos_embedding_list)
                return pos_embedding
        pos_embedding = _create_embed("obj_screen_pos",
                                      hparams.max_pixel_pos,
                                      hparams.hidden_size)
    if "all" == hparams.screen_embedding_feature or (
        "dom" in hparams.screen_embedding_feature):
        dom_embedding = _create_embed("obj_dom_pos", hparams.max_dom_pos,
                                      hparams.hidden_size)
    object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32)
    # NOTE(review): this is an if/elif chain, so only the FIRST matching
    # feature name is added (e.g. "text_pos" yields text only), and "all"
    # excludes clickable_embedding. Confirm this is intended.
    if hparams.screen_embedding_feature == "all":
        object_embed = (text_embeddings + type_embedding + pos_embedding +
                        dom_embedding)
    elif "text" in hparams.screen_embedding_feature:
        object_embed += text_embeddings
    elif "type" in hparams.screen_embedding_feature:
        object_embed += type_embedding
    elif "pos" in hparams.screen_embedding_feature:
        object_embed += pos_embedding
    elif "dom" in hparams.screen_embedding_feature:
        object_embed += dom_embedding
    elif "click" in hparams.screen_embedding_feature:
        object_embed += clickable_embedding
    # obj_type == -1 marks padded objects.
    object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32)
    object_embed = object_embed * tf.expand_dims(object_mask, 3)
    att_bias = (1. - object_mask) * common_attention.large_compatible_negative(
        object_embed.dtype)
    return object_embed, object_mask, att_bias
def trilerp_gather(vol, inds, bad_inds=None):
  """Trilinear interpolation dense gather from volume at query inds.

  Args:
    vol: Tensor indexed as `vol[b, y, x, z, channel]` (rank 5, channels last).
    inds: Float tensor whose last axis holds `(b, x, y, z)` query coordinates.
      The leading axes may have any rank; every query yields one interpolated
      channel vector.
    bad_inds: Optional boolean tensor with shape `inds.shape[:-1]`; queries
      flagged True are treated as out of bounds.

  Returns:
    Tensor of shape `inds.shape[:-1] + [channels]`. Queries whose 2x2x2
    neighbourhood is not fully inside `vol` (or that are flagged in
    `bad_inds`) are set to zero.
  """
  vol_shape = tf.shape(vol)

  def _axis_neighbours(coord, axis_size):
    """Returns ((lower, upper), (w_lower, w_upper), invalid) along one axis."""
    lower = tf.floor(coord)
    upper = lower + 1
    # Record out-of-bounds queries before clipping so they can be zeroed later.
    invalid = tf.logical_or(
        tf.less(lower, 0.0),
        tf.greater(upper, tf.to_float(axis_size - 1)))
    # Clip so gather_nd never reads outside the volume; for queries that are
    # not flagged invalid this clipping is a no-op.
    lower = tf.clip_by_value(lower, 0.0, tf.to_float(axis_size - 2))
    upper = tf.clip_by_value(upper, 0.0, tf.to_float(axis_size - 1))
    w_lower = 1.0 - (coord - lower)
    w_upper = 1.0 - w_lower
    return (lower, upper), (w_lower, w_upper), invalid

  inds_b = inds[Ellipsis, 0]
  x_pos, x_w, invalid_x = _axis_neighbours(inds[Ellipsis, 1], vol_shape[2])
  y_pos, y_w, invalid_y = _axis_neighbours(inds[Ellipsis, 2], vol_shape[1])
  z_pos, z_w, invalid_z = _axis_neighbours(inds[Ellipsis, 3], vol_shape[3])

  invalid_inds = tf.logical_or(tf.logical_or(invalid_x, invalid_y), invalid_z)
  if bad_inds is not None:
    invalid_inds = tf.logical_or(invalid_inds, bad_inds)

  # The 8 corners, each given as (y, x, z) offsets. A fixed accumulation order
  # and a fixed y * x * z weight product keep the float32 result deterministic.
  corners = ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
             (1, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1))
  out_vol = None
  for iy, ix, iz in corners:
    weight = y_w[iy] * x_w[ix] * z_w[iz]
    corner_inds = tf.to_int32(
        tf.stack([inds_b, y_pos[iy], x_pos[ix], z_pos[iz]], axis=-1))
    contrib = tf.gather_nd(vol, corner_inds) * weight[Ellipsis, tf.newaxis]
    out_vol = contrib if out_vol is None else out_vol + contrib

  # tf.where (v1) needs matching shapes, so expand the per-query mask over the
  # channel axis. Broadcasting to out_vol's shape works for any number of
  # leading query axes, unlike a fixed-rank tile.
  invalid_inds = tf.broadcast_to(invalid_inds[Ellipsis, tf.newaxis],
                                 tf.shape(out_vol))
  return tf.where(invalid_inds, tf.zeros_like(out_vol), out_vol)
def bilerp_gather(img, inds):
  """Bilinearly samples `img` at dense floating-point query locations.

  Args:
    img: Tensor indexed as `img[b, y, x, channel]`.
    inds: Float tensor whose leading shape equals `img`'s first three dims
      (`[B, H, W]`) and whose last axis holds `(x, y)` sample coordinates.
      Each query samples from the batch entry given by its own position
      along axis 0.

  Returns:
    Tensor of shape `[B, H, W, channels]`; samples whose 2x2 neighbourhood
    is not fully inside the image are zero.
  """
  img_shape = tf.shape(img)
  batch_grid = tf.meshgrid(
      tf.range(img_shape[0]),
      tf.range(img_shape[1]),
      tf.range(img_shape[2]),
      indexing='ij')[0]
  batch_coord = tf.to_float(batch_grid)

  def _taps(coord, axis_size):
    """Floor/ceil neighbours with their weights, plus an out-of-range flag."""
    near = tf.floor(coord)
    far = near + 1
    # Flag out-of-range samples before clipping so they can be zeroed below.
    outside = tf.logical_or(
        tf.less(near, 0.0), tf.greater(far, tf.to_float(axis_size - 1)))
    near = tf.clip_by_value(near, 0.0, tf.to_float(axis_size - 2))
    far = tf.clip_by_value(far, 0.0, tf.to_float(axis_size - 1))
    near_weight = 1.0 - (coord - near)
    return ((near, near_weight), (far, 1.0 - near_weight)), outside

  x_taps, outside_x = _taps(inds[Ellipsis, 0], img_shape[2])
  y_taps, outside_y = _taps(inds[Ellipsis, 1], img_shape[1])

  # Visit the corners as (y, x) selections in a fixed order and accumulate.
  result = None
  for y_sel, x_sel in ((0, 0), (1, 0), (0, 1), (1, 1)):
    y_at, y_weight = y_taps[y_sel]
    x_at, x_weight = x_taps[x_sel]
    gather_at = tf.to_int32(tf.stack([batch_coord, y_at, x_at], axis=-1))
    term = tf.gather_nd(img, gather_at) * (
        y_weight * x_weight)[Ellipsis, tf.newaxis]
    result = term if result is None else result + term

  # Expand the per-pixel mask across channels, then zero invalid samples.
  outside = tf.logical_or(outside_x, outside_y)
  outside = tf.tile(outside[Ellipsis, tf.newaxis], [1, 1, 1, img_shape[3]])
  return tf.where(outside, tf.zeros_like(result), result)
def parser(value):
  """Parses one serialized ImageNet example into an `(image, label)` pair.

  Args:
    value: Scalar string tensor holding one serialized `tf.train.Example`.

  Returns:
    `(image, label)`: `image` is the result of `preprocess(decoded, bbox)`
    cast to float32, and `label` is an int32 tensor of shape `[1]`.
  """
  feature_spec = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/class/label': tf.FixedLenFeature([], dtype=tf.int64,
                                              default_value=-1),
      'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
                                             default_value=''),
      'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
  }
  example = tf.parse_single_example(value, feature_spec)

  encoded_image = tf.reshape(
      example['image/encoded'], shape=[], name='encoded_image')
  image_format = example['image/format']

  def _coord_row(key):
    """Dense values of one sparse bbox coordinate, shaped [1, num_boxes]."""
    return tf.expand_dims(example[key].values, 0)

  xmin = _coord_row('image/object/bbox/xmin')
  ymin = _coord_row('image/object/bbox/ymin')
  xmax = _coord_row('image/object/bbox/xmax')
  ymax = _coord_row('image/object/bbox/ymax')

  # Rows are stacked in (y, x) order to give [4, num_boxes]. A leading axis is
  # then added and the result transposed into [1, num_boxes, coords].
  bbox = tf.transpose(
      tf.expand_dims(tf.concat([ymin, xmin, ymax, xmax], 0), 0), [0, 2, 1])

  # PNG records (either case) go through decode_png; anything else is JPEG.
  is_png = tf.logical_or(
      tf.equal(image_format, 'png'), tf.equal(image_format, 'PNG'))
  image = tf.case(
      [(is_png, lambda: tf.image.decode_png(encoded_image, 3))],
      default=lambda: tf.image.decode_jpeg(encoded_image, 3),
      exclusive=True)
  image.set_shape([None, None, 3])
  image = preprocess(image, bbox)

  label = tf.cast(
      tf.reshape(example['image/class/label'], shape=[]),
      dtype=tf.int32,
      name='cast_label')
  label = tf.reshape(label, [1])
  return tf.cast(image, tf.float32), label
def should_log(params):
  """Decides whether values should be logged at the current global step.

  Args:
    params: Object exposing a `log_every` attribute (presumably a positive
      int; TODO confirm with callers).

  Returns:
    A scalar boolean `tf.Tensor` that is True at global step 1 and at every
    step divisible by `params.log_every`.
  """
  step = tf.train.get_or_create_global_step()
  is_first_step = tf.equal(step, 1)
  is_period_boundary = tf.equal(tf.floormod(step, params.log_every), 0)
  return tf.logical_or(is_first_step, is_period_boundary)
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
  """Does top-p (nucleus) sampling.

  If `ignore_ids` is given, those vocabulary entries get their logits lowered
  by 1e10. This is used for both the returned probabilities and the sampling,
  so they are never predicted.

  Args:
    logits: [batch_size, vocab_size] tensor.
    ignore_ids: Optional [vocab_size] one-hot representation of the indices
      that should never be predicted (e.g. padding).
    num_samples: Number of samples to draw per batch row.
    p: Top-p threshold, either a Python float or a [batch_size] vector.

  Returns:
    A dict with:
      'probs': [batch_size, vocab_size] softmax of the (masked) logits.
      'sample': [batch_size, num_samples] int32 sampled vocabulary ids.
  """
  # TODO: figure out how to do this on TPUs; it is very slow right now,
  # probably because of the argsort.
  with tf.variable_scope('top_p_sample'):
    _, vocab_size = get_shape_list(logits, expected_rank=2)

    # Compute the ignore-masked logits once; they feed both the returned
    # probabilities and the sampling distribution.
    if ignore_ids is None:
      masked_logits = logits
    else:
      masked_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
    probs = tf.nn.softmax(masked_logits, axis=-1)

    if isinstance(p, float) and p > 0.999999:
      # Don't do top-p sampling in this case.
      print("Top-p sampling DISABLED", flush=True)
      return {
          'probs': probs,
          'sample': tf.random.categorical(
              logits=masked_logits, num_samples=num_samples, dtype=tf.int32),
      }

    # [batch_size, vocab_perm]: vocabulary ids sorted by descending prob.
    indices = tf.argsort(probs, direction='DESCENDING')
    cumulative_probabilities = tf.math.cumsum(
        tf.batch_gather(probs, indices), axis=-1, exclusive=False)

    # Exclude sorted positions whose cumulative mass has reached p. Always
    # keep the top-1 position so at least one token survives.
    p_expanded = p if isinstance(p, float) else p[:, None]
    exclude_mask = tf.logical_not(
        tf.logical_or(cumulative_probabilities < p_expanded,
                      tf.range(vocab_size)[None] < 1))

    # Sample in the sorted space, then map back to vocabulary ids. Use the
    # masked logits here as well so ignored ids stay unsampleable even if
    # they fall inside the nucleus.
    logits_to_use = tf.batch_gather(
        masked_logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
    sample_perm = tf.random.categorical(
        logits=logits_to_use, num_samples=num_samples)
    sample = tf.batch_gather(indices, sample_perm)
    return {
        'probs': probs,
        'sample': sample,
    }
def assign_and_sample_proposals(proposed_boxes,
                                gt_boxes,
                                gt_classes,
                                gt_attributes,
                                num_samples_per_image=512,
                                mix_gt_boxes=True,
                                fg_fraction=0.25,
                                fg_iou_thresh=0.5,
                                bg_iou_thresh_hi=0.5,
                                bg_iou_thresh_lo=0.0):
  """Assigns the proposals with groundtruth classes and performs subsampling.

  Given `proposed_boxes`, `gt_boxes`, `gt_classes` and `gt_attributes`, the
  function uses the following algorithm to generate the final
  `num_samples_per_image` RoIs.
    1. Calculates the IoU between each proposal box and each gt_boxes.
    2. Assigns each proposed box with a groundtruth class and box by choosing
       the largest IoU overlap.
    3. Samples `num_samples_per_image` boxes from all proposed boxes, and
       returns box_targets, class_targets, and RoIs.

  Args:
    proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
      of proposals before groundtruth assignment. The last dimension is the
      box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
      format.
    gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
      The coordinates of gt_boxes are in the pixel coordinates of the scaled
      image. This tensor might have padding of values -1 indicating the
      invalid box coordinates.
    gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
      tensor might have paddings with values of -1 indicating the invalid
      classes.
    gt_attributes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES,
      num_attributes]. This tensor might have paddings with values of -1
      indicating the invalid attributes.
    num_samples_per_image: an integer represents RoI minibatch size per image.
    mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes
      before sampling proposals.
    fg_fraction: a float represents the target fraction of RoI minibatch that
      is labeled foreground (i.e., class > 0).
    fg_iou_thresh: a float represents the IoU overlap threshold for an RoI to
      be considered foreground (if >= fg_iou_thresh).
    bg_iou_thresh_hi: a float represents the IoU overlap threshold for an RoI
      to be considered background (class = 0 if overlap in [LO, HI)).
    bg_iou_thresh_lo: a float represents the IoU overlap threshold for an RoI
      to be considered background (class = 0 if overlap in [LO, HI)).

  Returns:
    sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
      coordinates of the sampled RoIs, where K is the number of the sampled
      RoIs, i.e. K = num_samples_per_image.
    sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
      box coordinates of the matched groundtruth boxes of the sampled RoIs.
    sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
      classes of the matched groundtruth boxes of the sampled RoIs.
    sampled_gt_attributes: a tensor of shape of [batch_size, K,
      num_attributes], storing the attributes of the matched groundtruth
      attributes of the sampled RoIs.
    sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
      indices of the sampled groundtruth boxes in the original `gt_boxes`
      tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
      For RoIs matched as background the index is reset to 0.
  """
  with tf.name_scope('sample_proposals'):
    if mix_gt_boxes:
      boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
    else:
      boxes = proposed_boxes

    (matched_gt_boxes, matched_gt_classes, matched_gt_attributes,
     matched_gt_indices, matched_iou, _) = box_matching(
         boxes, gt_boxes, gt_classes, gt_attributes)

    # Foreground is IoU >= fg_iou_thresh, per the docstring. With the default
    # thresholds (fg = HI = 0.5) this and the half-open background range
    # [LO, HI) leave no IoU value that is neither foreground nor background.
    positive_match = tf.greater_equal(matched_iou, fg_iou_thresh)
    negative_match = tf.logical_and(
        tf.greater_equal(matched_iou, bg_iou_thresh_lo),
        tf.less(matched_iou, bg_iou_thresh_hi))
    ignored_match = tf.less(matched_iou, 0.0)

    # re-assign negatively matched boxes to the background class.
    matched_gt_classes = tf.where(negative_match,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)
    matched_gt_indices = tf.where(negative_match,
                                  tf.zeros_like(matched_gt_indices),
                                  matched_gt_indices)

    sample_candidates = tf.logical_and(
        tf.logical_or(positive_match, negative_match),
        tf.logical_not(ignored_match))

    sampler = (
        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
            positive_fraction=fg_fraction, is_static=True))

    # The per-image loop needs a statically known batch size.
    batch_size, _ = sample_candidates.get_shape().as_list()
    sampled_indicators = []
    for i in range(batch_size):
      sampled_indicator = sampler.subsample(sample_candidates[i],
                                            num_samples_per_image,
                                            positive_match[i])
      sampled_indicators.append(sampled_indicator)
    sampled_indicators = tf.stack(sampled_indicators)
    # top_k on the 0/1 indicators yields the positions of the sampled boxes.
    _, sampled_indices = tf.nn.top_k(tf.cast(sampled_indicators,
                                             dtype=tf.int32),
                                     k=num_samples_per_image,
                                     sorted=True)

    sampled_indices_shape = tf.shape(sampled_indices)
    batch_indices = (
        tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
        tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
    gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)

    sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
    sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
    sampled_gt_classes = tf.gather_nd(matched_gt_classes, gather_nd_indices)
    sampled_gt_attributes = tf.gather_nd(matched_gt_attributes,
                                         gather_nd_indices)
    sampled_gt_indices = tf.gather_nd(matched_gt_indices, gather_nd_indices)
    return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
            sampled_gt_attributes, sampled_gt_indices)
def get_retrieval_examples(serialized_example, mask_rate, bert_hub_module_path,
                           query_seq_len, block_seq_len):
  """Turns one serialized document into a pseudo-query retrieval example.

  One sentence is chosen uniformly at random and used as the query. With
  probability `mask_rate` that sentence is also cut out of the block.

  Args:
    serialized_example: Serialized example with int64 sequence features
      `title_ids`, `token_ids` and `sentence_starts` (start offsets of each
      sentence within `token_ids`).
    mask_rate: Probability of removing the query sentence from the block.
    bert_hub_module_path: Passed to `bert_utils.get_tokenizer`; the tokenizer
      supplies the `[CLS]` and `[SEP]` ids.
    query_seq_len: Target length given to `bert_utils.pad_or_truncate` for
      the query.
    block_seq_len: Target length given to `bert_utils.pad_or_truncate_pair`
      for the (title, block) pair.

  Returns:
    Dict with `keep_example`, `mask_query`, `query_ids`, `query_mask`,
    `block_ids`, `block_mask` and `block_segment_ids`. `keep_example` is False
    exactly when the query was masked out of a single-sentence document.
  """
  spec = dict(title_ids=tf.FixedLenSequenceFeature([], tf.int64, True),
              token_ids=tf.FixedLenSequenceFeature([], tf.int64, True),
              sentence_starts=tf.FixedLenSequenceFeature([], tf.int64, True))
  parsed = tf.parse_single_example(serialized_example, spec)
  parsed = {name: tf.cast(tensor, tf.int32) for name, tensor in parsed.items()}
  title_ids = parsed["title_ids"]
  doc_tokens = parsed["token_ids"]
  starts = parsed["sentence_starts"]
  # Sentence i spans doc_tokens[starts[i]:ends[i]]; the last one runs to the
  # end of the document.
  ends = tf.concat([starts[1:], [tf.size(doc_tokens)]], 0)

  tokenizer = bert_utils.get_tokenizer(bert_hub_module_path)
  cls_id, sep_id = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"])

  # Randomly choose a sentence and pretend that it is a query.
  pick = tf.random.uniform(shape=[], minval=0, maxval=tf.size(starts),
                           dtype=tf.int32)
  q_start = starts[pick]
  q_end = ends[pick]
  raw_query = doc_tokens[q_start:q_end]

  mask_query = tf.less(tf.random.uniform([]), mask_rate)
  raw_block = tf.cond(
      pred=mask_query,
      true_fn=lambda: tf.concat([doc_tokens[:q_start], doc_tokens[q_end:]], 0),
      false_fn=lambda: doc_tokens)

  query_ids, query_mask = bert_utils.pad_or_truncate(
      token_ids=raw_query,
      sequence_length=query_seq_len,
      cls_id=cls_id,
      sep_id=sep_id)
  block_ids, block_mask, block_segment_ids = bert_utils.pad_or_truncate_pair(
      token_ids_a=title_ids,
      token_ids_b=raw_block,
      sequence_length=block_seq_len,
      cls_id=cls_id,
      sep_id=sep_id)

  # Masked examples for single-sentence blocks don't make any sense.
  keep_example = tf.logical_or(tf.logical_not(mask_query),
                               tf.greater(tf.size(starts), 1))
  return {
      "keep_example": keep_example,
      "mask_query": mask_query,
      "query_ids": query_ids,
      "query_mask": query_mask,
      "block_ids": block_ids,
      "block_mask": block_mask,
      "block_segment_ids": block_segment_ids,
  }