def eval(self, eval_input_fn):
    """Run distributed eval on Mask RCNN model."""

    output_dir = os.path.join(self._runtime_config.model_dir, 'eval')
    tf.io.gfile.makedirs(output_dir)

    # Summary writer writes out eval metrics.
    run_config = self.build_strategy_configuration('eval')
    eval_params = self.build_model_parameters('eval')
    eval_estimator = self.build_mask_rcnn_estimator(eval_params, run_config, 'eval')

    logging.info('Starting to evaluate.')

    last_ckpt = self.get_last_checkpoint_path()

    if last_ckpt is not None:
        logging.info("Restoring parameters from %s\n" % last_ckpt)
        current_step = int(os.path.basename(last_ckpt).split('-')[1])

    else:
        logging.warning(
            "Could not find trained model in model_dir: `%s`, running initialization to predict\n"
            % self._runtime_config.model_dir
        )
        current_step = 0

    eval_results, predictions = evaluation.evaluate(
        eval_estimator,
        eval_input_fn,
        self._runtime_config.eval_samples,
        self._runtime_config.eval_batch_size,
        self._runtime_config.include_mask,
        self._runtime_config.val_json_file,
        checkpoint_path=last_ckpt
    )

    self._write_summary(output_dir, eval_results, predictions, current_step)

    if current_step >= self._runtime_config.total_steps:
        logging.info('Evaluation finished after training step %d' % current_step)

    return eval_results
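
# Usage sketch (illustrative only, not part of this module): it assumes an
# executer instance exposing the `eval()` method above and a COCO-style eval
# input function. The `EstimatorExecuter` class name and the
# `make_eval_input_fn` helper below are hypothetical placeholders.
#
#     executer = EstimatorExecuter(runtime_config, model_params)    # hypothetical name
#     eval_input_fn = dataset.make_eval_input_fn(runtime_config)    # hypothetical helper
#     eval_results = executer.eval(eval_input_fn)
#     logging.info("Eval results: %s", eval_results)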
def get_trainable_variables():
    """Get a list of trainable TensorFlow variables.

    Returns
    -------
    list of Tensor
        A list of trainable TensorFlow variables.
    """
    if KERAS_MODELS or LooseVersion(tf.__version__) >= LooseVersion("2.0.0"):
        logging.warning(
            "In TF2.x, only trainable variables created with Keras Models are captured for logging.\n"
            "In TF1.x, if any Keras Model is defined, only variables created inside Keras Models will be logged."
        )

        var_list = list()

        for model in KERAS_MODELS:
            var_list.extend(model.trainable_variables)

        # Keep only a list of unique variables (remove potential duplicates).
        var_list = list(set(var_list))

        # Clear the list of Keras Models to avoid memory leaks.
        KERAS_MODELS.clear()

        return sorted(var_list, key=lambda v: v.name)

    else:
        # `tf.trainable_variables()` is deprecated in TF2.x.
        from tensorflow.python.keras.backend import get_graph
        return get_graph().get_collection('trainable_variables')
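
# Usage sketch (illustrative only): variables are captured only for Keras models
# that were previously registered in the module-level `KERAS_MODELS` list; the
# `model` object below is a hypothetical tf.keras.Model instance.
#
#     KERAS_MODELS.append(model)
#     for var in get_trainable_variables():
#         logging.info("Trainable variable: %s - shape: %s", var.name, var.get_shape())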
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
    """Creates an operation to assign specific variables from a checkpoint.

    Args:
        model_path: The full path to the model checkpoint. To get latest checkpoint
            use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`
        var_list: A list of (possibly partitioned) `Variable` objects or a
            dictionary mapping names in the checkpoint to the corresponding
            variables or list of variables to initialize from that checkpoint value.
            For partitioned Variables, the name in the checkpoint must be the full
            variable, not the name of the partitioned variable, eg. "my_var" rather
            than "my_var/part_4". If empty, returns no_op(), {}.
        ignore_missing_vars: Boolean, if True ignore variables missing in the
            checkpoint with a warning instead of failing.

    Returns:
        the restore_op and the feed_dict that need to be run to restore var_list.

    Raises:
        ValueError: If `ignore_missing_vars` is False and the checkpoint specified
            at `model_path` is missing one of the variables in `var_list`.
    """
    # Normalize var_list into a dictionary mapping names in the
    # checkpoint to the list of variables to initialize from that
    # checkpoint variable. Sliced (including partitioned) variables will
    # end up under the same key.
    grouped_vars = {}

    if isinstance(var_list, (tuple, list)):
        for var in var_list:
            ckpt_name = get_variable_full_name(var)
            if ckpt_name not in grouped_vars:
                grouped_vars[ckpt_name] = []
            grouped_vars[ckpt_name].append(var)

    else:
        for ckpt_name, value in var_list.items():
            if isinstance(value, (tuple, list)):
                grouped_vars[ckpt_name] = value
            else:
                grouped_vars[ckpt_name] = [value]

    # Read each checkpoint entry. Create a placeholder variable and
    # add the (possibly sliced) data from the checkpoint to the feed_dict.
    reader = tf.compat.v1.train.NewCheckpointReader(model_path)

    feed_dict = {}
    assign_ops = []

    for ckpt_name in grouped_vars:
        if not reader.has_tensor(ckpt_name):
            log_str = 'Checkpoint is missing variable [%s]' % ckpt_name

            if ignore_missing_vars:
                logging.warning(log_str)
                continue
            else:
                raise ValueError(log_str)

        ckpt_value = reader.get_tensor(ckpt_name)

        for var in grouped_vars[ckpt_name]:
            placeholder_tensor = tf.compat.v1.placeholder(
                dtype=var.dtype.base_dtype,
                shape=var.get_shape(),
                name='placeholder/' + var.op.name
            )

            assign_ops.append(var.assign(placeholder_tensor))

            if not var._save_slice_info:
                if var.get_shape() != ckpt_value.shape:
                    raise ValueError(
                        'Total size of new array must be unchanged for %s '
                        'lh_shape: [%s], rh_shape: [%s]' % (
                            ckpt_name, str(ckpt_value.shape), str(var.get_shape())
                        )
                    )

                feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)

            else:
                slice_dims = zip(var._save_slice_info.var_offset, var._save_slice_info.var_shape)
                slice_dims = [(start, start + size) for (start, size) in slice_dims]
                slice_dims = [slice(*x) for x in slice_dims]
                slice_value = ckpt_value[slice_dims]
                slice_value = slice_value.reshape(var._save_slice_info.var_shape)
                feed_dict[placeholder_tensor] = slice_value

    print_op = tf.print(
        "[GPU %02d] Restoring pretrained weights (%d Tensors) from: %s" % (
            MPI_rank(), len(assign_ops), model_path
        ),
        output_stream=sys.stdout
    )

    with tf.control_dependencies([print_op]):
        assign_op = tf.group(*assign_ops)

    return assign_op, feed_dict
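
# Usage sketch (assumes a TF1-style graph/session workflow; `checkpoint_dir` is a
# hypothetical path). The returned op must be run together with the returned
# feed_dict, after the variables themselves have been initialized:
#
#     model_path = tf.train.latest_checkpoint(checkpoint_dir)
#     restore_op, restore_feed_dict = assign_from_checkpoint(
#         model_path, tf.compat.v1.trainable_variables(), ignore_missing_vars=True)
#
#     with tf.compat.v1.Session() as sess:
#         sess.run(tf.compat.v1.global_variables_initializer())
#         sess.run(restore_op, feed_dict=restore_feed_dict)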
def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.

    The `run_context` argument is the same one sent to the `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.

    If `session.run()` raises any exceptions then `after_run()` is not called.

    Args:
        run_context: A `SessionRunContext` object.
        run_values: A `SessionRunValues` object.
    """
    batch_time = time.time() - self._step_t0

    _global_step = run_values.results["global_step"]

    if self._is_training and self._AMP_steps_since_last_loss_scale is not None:
        try:
            AMP_steps_since_last_loss_scale = run_values.results["AMP"]["steps_since_last_loss_scale"]
            AMP_loss_scale = run_values.results["AMP"]["current_loss_scale"]
        except KeyError:
            AMP_steps_since_last_loss_scale = None
            AMP_loss_scale = None

        if AMP_steps_since_last_loss_scale is not None:
            # Step has been skipped
            if _global_step != (self._amp_steps_non_skipped + 1):
                logging.warning(
                    "AMP - Training iteration `#{step}` has been skipped and loss rescaled. "
                    "New Loss Scale: {loss_scale}\n".format(
                        step=self._current_step, loss_scale=AMP_loss_scale
                    )
                )

            else:
                self._amp_steps_non_skipped += 1

            if AMP_steps_since_last_loss_scale == 0:
                logging.warning(
                    "AMP - Training iteration `#{step}` - Loss scale has been automatically increased. "
                    "New Loss Scale: {loss_scale}\n".format(
                        step=self._current_step, loss_scale=AMP_loss_scale
                    )
                )

    else:
        AMP_steps_since_last_loss_scale = None
        AMP_loss_scale = None

    def get_model_throughput():
        gpu_batch_size = run_values.results["batch_size"]
        return gpu_batch_size / batch_time * self._n_gpus

    # def get_model_stats():
    #     return get_tf_model_statistics(batch_size=run_values.results["batch_size"], scope_name=None)
    #
    # if self._model_stats is None:
    #     self._model_stats = get_model_stats()

    is_log_step = self._current_step % self._log_every_n_steps == 0

    if is_log_step:
        if self._current_step > self._warmup_steps:
            try:
                model_throughput = self._model_throughput.read()
            except ValueError:
                model_throughput = get_model_throughput()
        else:
            model_throughput = get_model_throughput()

        self._logging_proxy.log_step(
            iteration=self._current_step,
            throughput=model_throughput,
            gpu_stats=[]
        )

        self._logging_proxy.log_amp_runtime(
            current_loss_scale=AMP_loss_scale,
            steps_non_skipped=_global_step,
            steps_since_last_scale=AMP_steps_since_last_loss_scale,
        )

        metric_data = dict()

        for name, value in sorted(run_values.results["metrics"].items(), key=operator.itemgetter(0)):
            self._metrics[name]["aggregator"].record(value)
            metric_data[name] = self._metrics[name]["aggregator"].read()

        self._logging_proxy.log_metrics(
            metric_data=metric_data,
            iteration=self._current_step,
            runtime_mode=self._runtime_mode
        )

        print()  # Visual spacing

    elif self._current_step > self._warmup_steps:
        # Do not store the throughput on log steps due to the additional fetches.
        self._model_throughput.record(get_model_throughput())
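
# Usage sketch: `after_run` is a method of a `tf.estimator.SessionRunHook`
# subclass, so the enclosing hook is attached via the estimator's training
# hooks. The `logging_hook` instance and the surrounding names below are
# hypothetical placeholders, not part of this module.
#
#     train_estimator.train(
#         input_fn=train_input_fn,
#         max_steps=runtime_config.total_steps,
#         hooks=[logging_hook])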
def log_warning(self, message):
    logging.warning("%s%s" % (self.LOGGING_PREFIX, message))