def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
  """Add all summaries.

  Writes `tf_loss` and every entry of `scalars` as `contrib_summary.scalar`
  summaries. When `FLAGS.use_tpu` is set the summary ops are built via
  `tpu.outside_compilation`; otherwise they are built directly.

  Args:
    lowering: Object whose `export_to_tf_tensor` converts the values of
      `scalars` that are not already `tf.Tensor`s into TensorFlow tensors.
    train_or_eval: String used as the summary-name prefix
      (`'<prefix>_loss'`, `'<prefix>_<description>'`).
    tf_loss: Loss tensor to summarize.
    scalars: Dict mapping a description to a scalar value. Values that are
      not `tf.Tensor`s are exported through `lowering` and cast to
      `tf.float32`. The caller's dict is left unmodified.
    global_step: Global step tensor used as the summary step.

  Returns:
    `tf.identity(tf_loss)`, with control dependencies on all summary ops so
    that evaluating it also writes the summaries.
  """
  # Convert into a private dict rather than overwriting the caller's values.
  tf_scalars = {}
  for description, value in scalars.items():
    if not isinstance(value, tf.Tensor):
      value = tf.cast(lowering.export_to_tf_tensor(value), tf.float32)
    tf_scalars[description] = value

  def _host_loss_summary(global_step, tf_loss, **scalars):
    """Add summary.scalar in host side."""
    gs = tf.cast(global_step, tf.int64)
    sum_loss = contrib_summary.scalar(
        '{}_loss'.format(train_or_eval), tf_loss, step=gs)
    sum_ops = [sum_loss.op]
    # `.items()` works on Python 2 and 3; `.iteritems()` is Python 2 only.
    for description, tf_metric in scalars.items():
      sum_metric = contrib_summary.scalar(
          '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
      sum_ops.append(sum_metric)
    with tf.control_dependencies(sum_ops):
      return tf.identity(tf_loss)

  if FLAGS.use_tpu:
    # Cast the global step to tf.int32, since
    # outside_compilation does not support tf.int64.
    tf_loss = tpu.outside_compilation(
        _host_loss_summary, tf.cast(global_step, tf.int32), tf_loss,
        **tf_scalars)
  else:
    tf_loss = _host_loss_summary(
        tf.cast(global_step, tf.int32), tf_loss, **tf_scalars)
  return tf_loss
def compute_eval_dict(features, labels):
  """Compute the evaluation result on an image.

  Args:
    features: Dict of feature tensors, indexed by `fields.InputDataFields`
      keys and `inputs.HASH_KEY`.
    labels: Dict of groundtruth tensors, indexed by `fields.InputDataFields`
      keys.

  Returns:
    A tuple `(eval_dict, losses_dict, class_agnostic)`:
    - `eval_dict` is the output of
      `eval_util.result_dict_for_batched_example`.
    - `losses_dict` comes from `_compute_losses_and_predictions_dicts`.
    - `class_agnostic` is True when the detections carry no
      `detection_classes` entry.
  """
  true_shape_key = fields.InputDataFields.true_image_shape

  # For evaling on train data, it is necessary to check whether groundtruth
  # must be unpadded (only off-TPU, batch size 1, known box dimension).
  box_dims = labels[
      fields.InputDataFields.groundtruth_boxes].get_shape().as_list()
  should_unpad = (box_dims[1] is not None) and (not use_tpu) and (
      batch_size == 1)
  unstacked_labels = model_lib.unstack_batch(
      labels, unpad_groundtruth_tensors=should_unpad)

  losses_dict, prediction_dict = _compute_losses_and_predictions_dicts(
      detection_model, features, unstacked_labels, add_regularization_loss)

  def _postprocess(packed):
    # `packed` is (prediction_dict, true_image_shape).
    return detection_model.postprocess(packed[0], packed[1])

  postprocess_inputs = (prediction_dict, features[true_shape_key])
  # TODO(kaftan): Depending on how postprocessing will work for TPUS w/
  # TPUStrategy, may be good to move wrapping to a utility method.
  if use_tpu and postprocess_on_cpu:
    detections = contrib_tpu.outside_compilation(_postprocess,
                                                 postprocess_inputs)
  else:
    detections = _postprocess(postprocess_inputs)

  class_agnostic = (
      fields.DetectionResultFields.detection_classes not in detections)

  # TODO(kaftan) (or anyone): move `_prepare_groundtruth_for_eval to eval_util
  # and call this from there.
  groundtruth = model_lib._prepare_groundtruth_for_eval(  # pylint: disable=protected-access
      detection_model, class_agnostic, eval_input_config.max_number_of_boxes)

  original_image_key = fields.InputDataFields.original_image
  if original_image_key not in features:
    eval_images = features[fields.InputDataFields.image]
    true_image_shapes = None
    original_image_spatial_shapes = None
  else:
    eval_images = features[original_image_key]
    true_image_shapes = tf.slice(features[true_shape_key], [0, 0], [-1, 3])
    original_image_spatial_shapes = features[
        fields.InputDataFields.original_image_spatial_shape]

  eval_dict = eval_util.result_dict_for_batched_example(
      eval_images,
      features[inputs.HASH_KEY],
      detections,
      groundtruth,
      class_agnostic=class_agnostic,
      scale_to_absolute=True,
      original_image_spatial_shapes=original_image_spatial_shapes,
      true_image_shapes=true_image_shapes)
  return eval_dict, losses_dict, class_agnostic
def _model_fn(features, labels, mode, params, model, use_tpu_estimator_spec,
              variable_filter_fn=None):
  """Model definition for the RetinaNet model based on ResNet.

  In PREDICT mode this always returns a `TPUEstimatorSpec` with
  post-processed predictions. In TRAIN/EVAL mode it builds the loss
  (detection loss plus L2 weight decay on non-batch-norm variables), the
  train op (TRAIN only) and the eval metrics (EVAL only).

  Args:
    features: the input image tensor with shape
      [batch_size, height, width, 3]. The height and width are fixed and
      equal. In PREDICT mode this may instead be a dict holding the images
      under 'inputs' (plus 'image_info' and optionally 'labels').
    labels: the input labels in a dictionary. The labels include class
      targets and box targets which are dense label maps. The labels are
      generated from get_input_fn function in dataloader.py
    mode: the mode of TPUEstimator/Estimator including TRAIN, EVAL, and
      PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the RetinaNet model outputs class logits and box regression
      outputs.
    use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or EstimatorSpec
      for TRAIN/EVAL. PREDICT mode always returns a TPUEstimatorSpec.
    variable_filter_fn: the filter function that takes trainable_variables
      and returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec (or EstimatorSpec when
      `use_tpu_estimator_spec` is False outside PREDICT) to run training,
      evaluation, or prediction.
  """
  # In predict mode features is a dict with input as value of the 'inputs'.
  image_info = None
  if (mode == tf.estimator.ModeKeys.PREDICT and isinstance(features, dict) and
      'inputs' in features):
    # NOTE(review): 'image_info' is read unconditionally here; a dict with
    # 'inputs' but no 'image_info' raises KeyError.
    image_info = features['image_info']
    labels = None
    if 'labels' in features:
      labels = features['labels']
    features = features['inputs']

  def _model_outputs():
    # Anchors per location: assuming params['aspect_ratios'] is a list,
    # len(list * n) == len(list) * n.
    return model(
        features,
        min_level=params['min_level'],
        max_level=params['max_level'],
        num_classes=params['num_classes'],
        num_anchors=len(params['aspect_ratios'] * params['num_scales']),
        resnet_depth=params['resnet_depth'],
        is_training_bn=params['is_training_bn'])

  if params['use_bfloat16']:
    with contrib_tpu.bfloat16_scope():
      cls_outputs, box_outputs = _model_outputs()
      levels = cls_outputs.keys()
      # Cast per-level outputs back to float32 so losses/postprocessing run
      # in full precision.
      for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
  else:
    cls_outputs, box_outputs = _model_outputs()
    # NOTE(review): `levels` is not used after this point in this function.
    levels = cls_outputs.keys()

  # First check if it is in PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    # Postprocess on host; memory layout for NMS on TPU is very inefficient.
    def _predict_postprocess_wrapper(args):
      return _predict_postprocess(*args)

    predictions = contrib_tpu.outside_compilation(
        _predict_postprocess_wrapper,
        (cls_outputs, box_outputs, labels, params))

    # Include resizing information on prediction output to help bbox drawing.
    if image_info is not None:
      predictions.update({
          'image_info': tf.identity(image_info, 'ImageInfo'),
      })

    return contrib_tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT,
                                        predictions=predictions)

  # Load pretrained model from checkpoint.
  if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      # Map every checkpoint variable ('/') into the 'resnet<depth>/' scope.
      tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
          '/': 'resnet%s/' % params['resnet_depth'],
      })
      return tf.train.Scaffold()
  else:
    scaffold_fn = None

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_global_step()
  learning_rate = learning_rate_schedule(
      params['adjusted_learning_rate'], params['lr_warmup_init'],
      params['lr_warmup_step'], params['first_lr_drop_step'],
      params['second_lr_drop_step'], global_step)
  # cls_loss and box_loss are for logging. only total_loss is optimized.
  total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                  labels, params)
  # L2 weight decay on every trainable variable except batch-norm ones.
  total_loss += _WEIGHT_DECAY * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=params['momentum'])
    if params['use_tpu']:
      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
    else:
      if params['auto_mixed_precision']:
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer)

    # Batch norm requires `update_ops` to be executed alongside `train_op`.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # None means "optimize all trainable variables".
    var_list = variable_filter_fn(
        tf.trainable_variables(),
        params['resnet_depth']) if variable_filter_fn else None

    minimize_op = optimizer.minimize(total_loss, global_step,
                                     var_list=var_list)
    train_op = tf.group(minimize_op, update_ops)
  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      batch_size = params['batch_size']
      eval_anchors = anchors.Anchors(params['min_level'],
                                     params['max_level'],
                                     params['num_scales'],
                                     params['aspect_ratios'],
                                     params['anchor_scale'],
                                     params['image_size'])
      anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                             params['num_classes'])
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                    params['val_json_file'], **kwargs)

      # Add metrics to output.
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

    # Repeat the scalar losses into [batch_size, 1] tensors so they can be
    # passed to metric_fn alongside the other per-example inputs.
    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'source_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
    }
    # Presumably adds model-output entries to metric_fn_inputs in place
    # (its return value is unused) — confirm in add_metric_fn_inputs.
    add_metric_fn_inputs(params, cls_outputs, box_outputs, metric_fn_inputs)
    eval_metrics = (metric_fn, metric_fn_inputs)

  if use_tpu_estimator_spec:
    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=total_loss,
                                        train_op=train_op,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)
  else:
    # NOTE(review): this path passes neither eval_metrics nor the scaffold.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        # TODO(rostam): Fix bug to get scaffold working.
        # scaffold=scaffold_fn(),
        train_op=train_op)
def model_fn(features, labels, mode, config, params):
  """Estimator model function.

  Builds a `tpu_estimator.TPUEstimatorSpec` for PREDICT, TRAIN or EVAL mode.
  Training optimizes the cross-entropy loss (`XENT_loss`) returned by
  `model_params.model()` with Adafactor.

  Args:
    features: Dict of input tensors fed to the model.
    labels: Unused (deleted immediately).
    mode: A `tf.estimator.ModeKeys` value.
    config: Unused (deleted immediately).
    params: Unused (deleted immediately).

  Returns:
    A `TPUEstimatorSpec` for PREDICT/TRAIN/EVAL, or None for any other mode.
  """
  # The Estimator API supplies these arguments, but this function reads what
  # it needs (model_params, use_tpu, ...) from names outside its own body.
  del labels
  del config
  del params
  tf.get_variable_scope().set_initializer(
      tf.variance_scaling_initializer(
          1.0, mode="fan_avg", distribution="uniform"))

  # PREDICT: run the prediction fn and optionally decode ids into strings.
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions, _, _ = model_params.estimator_prediction_fn(features)
    if include_features_in_predictions:
      predictions.update(features)
    if decode_keys:
      # Decode the raw ids into strings in prediction.
      def decode_host_call(tensor_dict):
        for key in decode_keys:
          predictions[key] = public_parsing_ops.decode(
              tensor_dict[key], model_params.vocab_filename,
              model_params.encoder_type)
        return tensor_dict

      contrib_tpu.outside_compilation(decode_host_call, predictions)
    return tpu_estimator.TPUEstimatorSpec(mode=mode, predictions=predictions)

  training = mode == tf.estimator.ModeKeys.TRAIN
  if use_tpu and model_params.use_bfloat16:
    with contrib_tpu.bfloat16_scope():
      # Bind to XENT_loss on this path too: the TRAIN/EVAL branches below
      # read XENT_loss (binding `loss` here previously raised NameError).
      XENT_loss, outputs = model_params.model()(features, training)
  else:
    XENT_loss, outputs = model_params.model()(features, training)

  # TPU requires outputs all have batch dimension and doesn't handle scalar.
  # Tile all scalars to 1 dimension vector.
  outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

  # TRAIN: create the optimizer and learning-rate schedule.
  if mode == tf.estimator.ModeKeys.TRAIN:
    init_lr = model_params.learning_rate
    global_step = tf.train.get_global_step()
    # Equals init_lr for the first 10000 steps, then decays as
    # 1/sqrt(global_step).
    lr = init_lr / 0.01 * tf.rsqrt(
        tf.maximum(tf.to_float(global_step), 10000))
    if train_init_checkpoint:
      # Linear warmup over train_warmup_steps when starting from a checkpoint.
      lr = tf.minimum(
          tf.to_float(global_step + 1) / train_warmup_steps * init_lr, lr)

    optimizer = adafactor.AdafactorOptimizer(
        learning_rate=lr,
        decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
        beta1=0.0)
    if use_tpu:
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # NOTE: experimental RL objectives (REINFORCE, REINFORCE with baseline /
    # threshold, expected risk minimisation, mixed loss, RELAX) were kept
    # here as commented-out code and have been removed from this function;
    # recover them from version control if needed. Only XENT_loss is
    # optimized.
    list_of_gradient_variable_pairs = optimizer.compute_gradients(XENT_loss)
    train_op = optimizer.apply_gradients(
        list_of_gradient_variable_pairs, global_step=global_step)

    tf.logging.set_verbosity(tf.logging.INFO)
    logging_hook = tf.train.LoggingTensorHook(
        {
            "loss": XENT_loss,
            "learning_rate": lr,
            "global_step": global_step,
        },
        every_n_iter=5)

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=XENT_loss,
        train_op=train_op,
        training_hooks=[logging_hook],
        scaffold_fn=_load_vars_from_checkpoint(use_tpu, train_init_checkpoint),
        host_call=add_scalars_to_summary(model_dir, {"learning_rate": lr}))

  # EVAL: compute the model's eval metrics.
  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics = model_params.estimator_eval_metrics_fn(features, outputs)
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, loss=XENT_loss, eval_metrics=eval_metrics)
def model_fn(features, labels, mode, config, params):
  """Estimator model function.

  Args:
    features: Input feature tensors handed to the model.
    labels: Ignored.
    mode: `tf.estimator.ModeKeys` value selecting PREDICT, TRAIN or EVAL.
    config: Ignored.
    params: Ignored.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` for the requested mode, or None when
    `mode` is none of PREDICT/TRAIN/EVAL.
  """
  # Only `features` and `mode` are consumed; the rest is discarded.
  del labels, config, params
  scope_initializer = tf.variance_scaling_initializer(
      1.0, mode="fan_avg", distribution="uniform")
  tf.get_variable_scope().set_initializer(scope_initializer)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = model_params.estimator_prediction_fn(features)
    if include_features_in_predictions:
      predictions.update(features)
    if decode_keys:

      def decode_host_call(tensor_dict):
        # Decode the raw ids into strings, writing them into `predictions`.
        for decode_key in decode_keys:
          predictions[decode_key] = public_parsing_ops.decode(
              tensor_dict[decode_key], model_params.vocab_filename,
              model_params.encoder_type)
        return tensor_dict

      contrib_tpu.outside_compilation(decode_host_call, predictions)
    return tpu_estimator.TPUEstimatorSpec(mode=mode, predictions=predictions)

  is_training = mode == tf.estimator.ModeKeys.TRAIN
  if use_tpu and model_params.use_bfloat16:
    with contrib_tpu.bfloat16_scope():
      loss, outputs = model_params.model()(features, is_training)
  else:
    loss, outputs = model_params.model()(features, is_training)

  # TPU needs every output to carry a batch dimension and cannot handle
  # scalars, so scalars are tiled into 1-D vectors.
  outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metrics=model_params.estimator_eval_metrics_fn(features,
                                                            outputs))

  if mode != tf.estimator.ModeKeys.TRAIN:
    return None

  base_lr = model_params.learning_rate
  step = tf.train.get_global_step()
  # Constant base_lr for the first 10000 steps, then 1/sqrt(step) decay.
  learning_rate = base_lr / 0.01 * tf.rsqrt(
      tf.maximum(tf.to_float(step), 10000))
  if train_init_checkpoint:
    warmup_lr = tf.to_float(step + 1) / train_warmup_steps * base_lr
    learning_rate = tf.minimum(warmup_lr, learning_rate)

  opt = adafactor.AdafactorOptimizer(
      learning_rate=learning_rate,
      decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
      beta1=0.0)
  if use_tpu:
    opt = tpu_optimizer.CrossShardOptimizer(opt)

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=opt.minimize(loss, global_step=step),
      scaffold_fn=_load_vars_from_checkpoint(use_tpu, train_init_checkpoint),
      host_call=add_scalars_to_summary(model_dir,
                                       {"learning_rate": learning_rate}))
def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Builds the detection graph for one Estimator mode and wraps it in an
    estimator spec. Besides the arguments below, this function reads
    `use_tpu`, `postprocess_on_cpu`, `detection_model_fn`, `hparams`,
    `configs`, `train_config`, `eval_config` and `eval_input_config`, none of
    which are defined in this function.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator. Normalized to
        `{}` when None but not otherwise read here.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
      configurations: a `contrib_tpu.TPUEstimatorSpec` when `use_tpu` is set
      and mode is not EVAL, otherwise a `tf.estimator.EstimatorSpec`.
      TRAIN fills `loss`/`train_op`; EVAL fills `loss`/`predictions`/
      `eval_metric_ops`; PREDICT fills `predictions`/`export_outputs`.

    Side effects:
      Sets the Keras learning phase. With TPU + bfloat16, sets the global
      Keras mixed-precision policy. In TRAIN with a fine-tune checkpoint, may
      write `train_config.fine_tune_checkpoint_type`. On the
      `tf.estimator.EstimatorSpec` path without a moving-average scaffold,
      adds a Saver to the `tf.GraphKeys.SAVERS` collection.
    """
    params = params or {}
    # Outputs that only some modes fill in; unset ones reach the spec as None.
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    # Set policy for mixed-precision training with Keras-based models.
    if use_tpu and train_config.use_bfloat16:
        from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
        # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
        base_layer_utils.enable_v2_dtype_behavior()
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            'mixed_bfloat16')
    # Model-level summaries are disabled when running on TPU.
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    # Only set on the TPU fine-tuning path below.
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Training unpads groundtruth according to the training config.
        labels = unstack_batch(labels,
                               unpad_groundtruth_tensors=train_config.
                               unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # For evaling on train data, it is necessary to check whether groundtruth
        # must be unpadded.
        # Unpad only when dimension 1 of the groundtruth boxes is statically
        # known and we are not running on TPU.
        boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                       get_shape().as_list())
        unpad_groundtruth_tensors = boxes_shape[
            1] is not None and not use_tpu
        labels = unstack_batch(
            labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        # Hand the unstacked groundtruth to the model (presumably consumed by
        # detection_model.loss() further below).
        provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]

    # Extra model inputs derived from `features`, forwarded to predict() as
    # keyword arguments.
    side_inputs = detection_model.get_side_inputs(features)

    if use_tpu and train_config.use_bfloat16:
        with contrib_tpu.bfloat16_scope():
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape],
                **side_inputs)
            # Convert the nested prediction outputs back to float32 (per the
            # helper's name) before loss / postprocessing use them.
            prediction_dict = ops.bfloat16_to_float32_nested(
                prediction_dict)
    else:
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape],
            **side_inputs)

    # Single-argument adapter: both the outside_compilation call and the
    # direct call below pass one (prediction_dict, true_image_shape) tuple.
    def postprocess_wrapper(args):
        return detection_model.postprocess(args[0], args[1])

    # Detections are only built for EVAL and PREDICT; in TRAIN they stay None.
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
        if use_tpu and postprocess_on_cpu:
            # Run postprocessing outside the TPU-compiled computation.
            detections = contrib_tpu.outside_compilation(
                postprocess_wrapper,
                (prediction_dict,
                 features[fields.InputDataFields.true_image_shape]))
        else:
            detections = postprocess_wrapper(
                (prediction_dict,
                 features[fields.InputDataFields.true_image_shape]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Warm-start only when a fine-tune checkpoint is configured AND
        # hparams.load_pretrained is truthy (missing hparams means no warm start).
        load_pretrained = hparams.load_pretrained if hparams else False
        if train_config.fine_tune_checkpoint and load_pretrained:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, set train_config.fine_tune_checkpoint_type
                # based on train_config.from_detection_checkpoint.
                # NOTE: this writes into train_config in place.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            # Ask the model which variables to restore for this checkpoint
            # type, then keep only the entries present in the checkpoint
            # (global step excluded).
            asg_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    asg_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            if use_tpu:
                # On TPU, checkpoint initialization is deferred into
                # scaffold_fn, which is handed to the TPUEstimatorSpec below.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(
                    train_config.fine_tune_checkpoint, available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        if (mode == tf.estimator.ModeKeys.EVAL and
                eval_config.use_dummy_loss_in_eval):
            # EVAL can skip the real loss; a constant 1.0 is reported instead.
            total_loss = tf.constant(1.0)
            losses_dict = {'Loss/total_loss': total_loss}
        else:
            losses_dict = detection_model.loss(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])
            # NOTE(review): equivalent to list(losses_dict.values()).
            losses = [loss_tensor for loss_tensor in losses_dict.values()]
            if train_config.add_regularization_loss:
                regularization_losses = detection_model.regularization_losses(
                )
                if use_tpu and train_config.use_bfloat16:
                    regularization_losses = ops.bfloat16_to_float32_nested(
                        regularization_losses)
                # Only add the term when the model reports any
                # regularization losses.
                if regularization_losses:
                    regularization_loss = tf.add_n(
                        regularization_losses, name='regularization_loss')
                    losses.append(regularization_loss)
                    losses_dict[
                        'Loss/regularization_loss'] = regularization_loss
            # total_loss sums every model loss plus the optional
            # regularization term.
            total_loss = tf.add_n(losses, name='total_loss')
            losses_dict['Loss/total_loss'] = total_loss

        # Optional graph rewrite; runs after the loss is built and before the
        # global step / optimizer are created.
        if 'graph_rewriter_config' in configs:
            graph_rewriter_fn = graph_rewriter_builder.build(
                configs['graph_rewriter_config'], is_training=is_training)
            graph_rewriter_fn()

        # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
        # can write learning rate summaries on TPU without host calls.
        global_step = tf.train.get_or_create_global_step()
        training_optimizer, optimizer_summary_vars = optimizer_builder.build(
            train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            # Wrap the optimizer for cross-shard gradient aggregation on TPU.
            training_optimizer = contrib_tpu.CrossShardOptimizer(
                training_optimizer)

        # Optionally freeze some layers by setting their gradients to be zero.
        # In practice this restricts the variable list passed to optimize_loss.
        # NOTE(review): this None is immediately overwritten below.
        trainable_variables = None
        # Empty include/exclude settings are normalized to None.
        include_variables = (train_config.update_trainable_variables
                             if train_config.update_trainable_variables else None)
        exclude_variables = (train_config.freeze_variables
                             if train_config.freeze_variables else None)
        trainable_variables = contrib_framework.filter_variables(
            tf.trainable_variables(),
            include_patterns=include_variables,
            exclude_patterns=exclude_variables)

        # Gradient-norm clipping is enabled only for a positive threshold.
        clip_gradients_value = None
        if train_config.gradient_clipping_by_norm > 0:
            clip_gradients_value = train_config.gradient_clipping_by_norm

        # Optimizer scalars are written as TF summaries only off-TPU.
        if not use_tpu:
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)
        # TPU starts with an empty summary list (None otherwise);
        # summarize_gradients overrides this choice on any device.
        summaries = [] if use_tpu else None
        if train_config.summarize_gradients:
            summaries = [
                'gradients', 'gradient_norm', 'global_gradient_norm'
            ]
        # learning_rate=None: an already-built optimizer instance is passed in.
        train_op = contrib_layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=None,
            clip_gradients=clip_gradients_value,
            optimizer=training_optimizer,
            update_ops=detection_model.updates(),
            variables=trainable_variables,
            summaries=summaries,
            name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Build the exported output tensors via exporter_lib and expose them
        # under the PREDICT_METHOD_NAME signature key.
        exported_output = exporter_lib.add_output_tensor_nodes(detections)
        export_outputs = {
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
        }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
        # Detections without a detection_classes entry are treated as
        # class-agnostic.
        class_agnostic = (fields.DetectionResultFields.detection_classes
                          not in detections)
        groundtruth = _prepare_groundtruth_for_eval(
            detection_model, class_agnostic,
            eval_input_config.max_number_of_boxes)
        # Use the original images when `features` provides them, otherwise
        # fall back to the model input images.
        use_original_images = fields.InputDataFields.original_image in features
        if use_original_images:
            eval_images = features[fields.InputDataFields.original_image]
            # Keep the first three columns of every row of true_image_shape.
            true_image_shapes = tf.slice(
                features[fields.InputDataFields.true_image_shape], [0, 0],
                [-1, 3])
            original_image_spatial_shapes = features[
                fields.InputDataFields.original_image_spatial_shape]
        else:
            eval_images = features[fields.InputDataFields.image]
            true_image_shapes = None
            original_image_spatial_shapes = None

        eval_dict = eval_util.result_dict_for_batched_example(
            eval_images,
            features[inputs.HASH_KEY],
            detections,
            groundtruth,
            class_agnostic=class_agnostic,
            scale_to_absolute=True,
            original_image_spatial_shapes=original_image_spatial_shapes,
            true_image_shapes=true_image_shapes)

        # Pass any additional image channels through to the eval results.
        if fields.InputDataFields.image_additional_channels in features:
            eval_dict[fields.InputDataFields.
                      image_additional_channels] = features[
                          fields.InputDataFields.image_additional_channels]

        if class_agnostic:
            category_index = label_map_util.create_class_agnostic_category_index(
            )
        else:
            category_index = label_map_util.create_category_index_from_labelmap(
                eval_input_config.label_map_path)
        # Visualization metrics are only built off-TPU and only when the
        # original images are available.
        vis_metric_ops = None
        if not use_tpu and use_original_images:
            eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                category_index,
                max_examples_to_draw=eval_config.num_visualizations,
                max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                min_score_thresh=eval_config.min_score_threshold,
                use_normalized_coordinates=False)
            vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                eval_dict)

        # Eval metrics on a single example.
        eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, list(category_index.values()), eval_dict)
        # Each loss is reported as a streaming mean (tf.metrics.mean).
        # NOTE(review): the iter() wrapper is redundant.
        for loss_key, loss_tensor in iter(losses_dict.items()):
            eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
        # Optimizer scalars are reported as (value, no-op update) pairs.
        for var in optimizer_summary_vars:
            eval_metric_ops[var.op.name] = (var, tf.no_op())
        if vis_metric_ops is not None:
            eval_metric_ops.update(vis_metric_ops)
        # Normalize every metric key to str.
        eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

        # Optionally evaluate with moving-average weights: the scaffold's
        # saver restores ExponentialMovingAverage.variables_to_restore().
        if eval_config.use_moving_averages:
            variable_averages = tf.train.ExponentialMovingAverage(0.0)
            variables_to_restore = variable_averages.variables_to_restore()
            keep_checkpoint_every_n_hours = (
                train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                variables_to_restore,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
            )
            scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
        # eval_metric_ops is always None here: it is only populated in EVAL.
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            scaffold_fn=scaffold_fn,
                                            predictions=detections,
                                            loss=total_loss,
                                            train_op=train_op,
                                            eval_metrics=eval_metric_ops,
                                            export_outputs=export_outputs)
    else:
        # Without a moving-average scaffold, build a default sharded saver,
        # register it in the SAVERS collection and wrap it in a Scaffold.
        if scaffold is None:
            keep_checkpoint_every_n_hours = (
                train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                sharded=True,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            scaffold = tf.train.Scaffold(saver=saver)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=detections,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs,
                                          scaffold=scaffold)