Example #1
    def eval(self, eval_input_fn):
        """Run distributed eval on Mask RCNN model."""

        # The summary writer writes eval metrics to this directory.
        output_dir = os.path.join(self._runtime_config.model_dir, 'eval')
        tf.io.gfile.makedirs(output_dir)

        run_config = self.build_strategy_configuration('eval')
        eval_params = self.build_model_parameters('eval')
        eval_estimator = self.build_mask_rcnn_estimator(
            eval_params, run_config, 'eval')

        logging.info('Starting to evaluate.')

        last_ckpt = self.get_last_checkpoint_path()

        if last_ckpt is not None:
            logging.info("Restoring parameters from %s\n" % last_ckpt)
            current_step = int(os.path.basename(last_ckpt).split('-')[1])

        else:
            logging.warning(
                "Could not find trained model in model_dir: `%s`, running initialization to predict\n"
                % self._runtime_config.model_dir)
            current_step = 0

        eval_results, predictions = evaluation.evaluate(
            eval_estimator,
            eval_input_fn,
            self._runtime_config.eval_samples,
            self._runtime_config.eval_batch_size,
            self._runtime_config.include_mask,
            self._runtime_config.val_json_file,
            checkpoint_path=last_ckpt)

        self._write_summary(output_dir, eval_results, predictions,
                            current_step)

        if current_step >= self._runtime_config.total_steps:
            logging.info('Evaluation finished after training step %d' %
                         current_step)

        return eval_results
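As a point of reference, a hedged sketch of how a driver might invoke this method; the executer class name, its constructor arguments, and `eval_input_fn` below are assumptions for illustration, not part of the code above.

# Hypothetical driver sketch; class and argument names are assumed.
executer = EstimatorExecuter(runtime_config=runtime_config, model_fn=model_fn)

eval_results = executer.eval(eval_input_fn=eval_input_fn)

# eval_results is whatever evaluation.evaluate() returned, e.g. a dict of
# COCO-style metrics keyed by metric name.
for name, value in sorted(eval_results.items()):
    print(name, value)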
def get_trainable_variables():
    """Get a list of trainable TensorFlow variables.

    Parameters
    ----------
    train_only : boolean
        If True, only get the trainable variables.

    Returns
    -------
    list of Tensor
        A list of trainable TensorFlow variables

    Examples
    --------

    """
    if KERAS_MODELS or LooseVersion(tf.__version__) >= LooseVersion("2.0.0"):
        logging.warning(
            "In TF2.x, only trainable variables created with Keras Models are captured for logging.\n"
            "In TF1.x, if any keras model is defined. Only variables created inside Keras Models will be logged."
        )

        var_list = list()

        for model in KERAS_MODELS:
            var_list.extend(model.trainable_variables)

        # Keep only a list of unique variables (remove potential duplicates)
        var_list = list(set(var_list))

        # Clear the list of Keras Models to avoid memory leaks
        KERAS_MODELS.clear()

        return sorted(var_list, key=lambda v: v.name)

    else:
        # return tf.trainable_variables()  # deprecated in TF2.x
        from tensorflow.python.keras.backend import get_graph
        return get_graph().get_collection('trainable_variables')
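A minimal usage sketch, assuming `KERAS_MODELS` is the module-level list of tracked Keras models that this helper reads from; the toy model below is illustrative only.

# Hypothetical usage; KERAS_MODELS is assumed to be the module-level registry
# consumed (and cleared) by get_trainable_variables().
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
KERAS_MODELS.append(model)  # register the model so its variables are captured

variables = get_trainable_variables()
print([v.name for v in variables])  # kernel and bias of the Dense layer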
def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
    """Creates an operation to assign specific variables from a checkpoint.
    Args:
    model_path: The full path to the model checkpoint. To get latest checkpoint
      use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`
    var_list: A list of (possibly partitioned) `Variable` objects or a
      dictionary mapping names in the checkpoint to the corresponding variables
      or list of variables to initialize from that checkpoint value. For
      partitioned Variables, the name in the checkpoint must be the full
      variable, not the name of the partitioned variable, eg. "my_var" rather
      than "my_var/part_4". If empty, returns no_op(), {}.
    ignore_missing_vars: Boolean, if True ignore variables missing in the
      checkpoint with a warning instead of failing.
    Returns:
    the restore_op and the feed_dict that need to be run to restore var_list.
    Raises:
    ValueError: If `ignore_missing_vars` is False and the checkpoint specified
        at `model_path` is missing one of the variables in `var_list`.
  """
    # Normalize var_list into a dictionary mapping names in the
    # checkpoint to the list of variables to initialize from that
    # checkpoint variable. Sliced (including partitioned) variables will
    # end up under the same key.
    grouped_vars = {}
    if isinstance(var_list, (tuple, list)):
        for var in var_list:
            ckpt_name = get_variable_full_name(var)
            if ckpt_name not in grouped_vars:
                grouped_vars[ckpt_name] = []
            grouped_vars[ckpt_name].append(var)

    else:
        for ckpt_name, value in var_list.items():
            if isinstance(value, (tuple, list)):
                grouped_vars[ckpt_name] = value
            else:
                grouped_vars[ckpt_name] = [value]

    # Read each checkpoint entry. Create a placeholder variable and
    # add the (possibly sliced) data from the checkpoint to the feed_dict.
    reader = tf.compat.v1.train.NewCheckpointReader(model_path)
    feed_dict = {}
    assign_ops = []
    for ckpt_name in grouped_vars:
        if not reader.has_tensor(ckpt_name):
            log_str = 'Checkpoint is missing variable [%s]' % ckpt_name
            if ignore_missing_vars:
                logging.warning(log_str)
                continue
            else:
                raise ValueError(log_str)
        ckpt_value = reader.get_tensor(ckpt_name)

        for var in grouped_vars[ckpt_name]:
            placeholder_tensor = tf.compat.v1.placeholder(
                dtype=var.dtype.base_dtype,
                shape=var.get_shape(),
                name='placeholder/' + var.op.name
            )

            assign_ops.append(var.assign(placeholder_tensor))

            if not var._save_slice_info:
                if var.get_shape() != ckpt_value.shape:
                    raise ValueError(
                        'Total size of new array must be unchanged for %s '
                        'lh_shape: [%s], rh_shape: [%s]' %
                        (ckpt_name, str(ckpt_value.shape), str(var.get_shape())))

                feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)

            else:
                slice_dims = zip(var._save_slice_info.var_offset,
                                 var._save_slice_info.var_shape)

                slice_dims = [(start, start + size) for (start, size) in slice_dims]
                slice_dims = [slice(*x) for x in slice_dims]

                # Index with a tuple of slices (indexing with a list of slices is
                # deprecated/removed in recent NumPy versions).
                slice_value = ckpt_value[tuple(slice_dims)]
                slice_value = slice_value.reshape(var._save_slice_info.var_shape)

                feed_dict[placeholder_tensor] = slice_value

    print_op = tf.print(
        "[GPU %02d] Restoring pretrained weights (%d Tensors) from: %s" % (
            MPI_rank(),
            len(assign_ops),
            model_path
        ),
        output_stream=sys.stdout
    )

    with tf.control_dependencies([print_op]):
        assign_op = tf.group(*assign_ops)

    return assign_op, feed_dict
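A graph-mode (TF1-style) sketch of how the returned op/feed_dict pair is typically consumed; the checkpoint directory and the choice of variables below are assumptions.

# Hypothetical usage; the path and variable selection are illustrative only.
checkpoint_dir = '/tmp/pretrained_model'
model_path = tf.train.latest_checkpoint(checkpoint_dir)
var_list = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)

assign_op, feed_dict = assign_from_checkpoint(
    model_path, var_list, ignore_missing_vars=True)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # Feeding the checkpoint values through the placeholders performs the restore.
    sess.run(assign_op, feed_dict=feed_dict)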
    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        """Called after each call to run().
        The `run_values` argument contains results of requested ops/tensors by
        `before_run()`.
        The `run_context` argument is the same one sent to the `before_run` call.
        `run_context.request_stop()` can be called to stop the iteration.
        If `session.run()` raises any exceptions then `after_run()` is not called.
        Args:
          run_context: A `SessionRunContext` object.
          run_values: A `SessionRunValues` object.
        """

        batch_time = time.time() - self._step_t0

        _global_step = run_values.results["global_step"]

        if self._is_training and self._AMP_steps_since_last_loss_scale is not None:

            try:
                AMP_steps_since_last_loss_scale = run_values.results["AMP"][
                    "steps_since_last_loss_scale"]
                AMP_loss_scale = run_values.results["AMP"][
                    "current_loss_scale"]

            except KeyError:
                AMP_steps_since_last_loss_scale = None
                AMP_loss_scale = None

            if AMP_steps_since_last_loss_scale is not None:

                # Step has been skipped
                if _global_step != (self._amp_steps_non_skipped + 1):
                    logging.warning(
                        "AMP - Training iteration `#{step}` has been skipped and loss rescaled. "
                        "New Loss Scale: {loss_scale}\n".format(
                            step=self._current_step,
                            loss_scale=AMP_loss_scale))

                else:
                    self._amp_steps_non_skipped += 1

                    if AMP_steps_since_last_loss_scale == 0:
                        logging.warning(
                            "AMP - Training iteration `#{step}` - Loss scale has been automatically increased. "
                            "New Loss Scale: {loss_scale}\n".format(
                                step=self._current_step,
                                loss_scale=AMP_loss_scale))

        else:
            AMP_steps_since_last_loss_scale = None
            AMP_loss_scale = None

        def get_model_throughput():
            gpu_batch_size = run_values.results["batch_size"]
            return gpu_batch_size / batch_time * self._n_gpus

        # def get_model_stats():
        #     return get_tf_model_statistics(batch_size=run_values.results["batch_size"], scope_name=None)
        #
        # if self._model_stats is None:
        #     self._model_stats = get_model_stats()

        is_log_step = self._current_step % self._log_every_n_steps == 0

        if is_log_step:

            if self._current_step > self._warmup_steps:

                try:
                    model_throughput = self._model_throughput.read()
                except ValueError:
                    model_throughput = get_model_throughput()

            else:
                model_throughput = get_model_throughput()

            self._logging_proxy.log_step(iteration=self._current_step,
                                         throughput=model_throughput,
                                         gpu_stats=[])

            self._logging_proxy.log_amp_runtime(
                current_loss_scale=AMP_loss_scale,
                steps_non_skipped=_global_step,
                steps_since_last_scale=AMP_steps_since_last_loss_scale,
            )

            metric_data = dict()

            for name, value in sorted(run_values.results["metrics"].items(),
                                      key=operator.itemgetter(0)):
                self._metrics[name]["aggregator"].record(value)

                metric_data[name] = self._metrics[name]["aggregator"].read()

            self._logging_proxy.log_metrics(metric_data=metric_data,
                                            iteration=self._current_step,
                                            runtime_mode=self._runtime_mode)

            print()  # Visual Spacing

        elif self._current_step > self._warmup_steps:
            # Do not store speed for log step due to additional fetches
            self._model_throughput.record(get_model_throughput())
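For orientation, a hedged sketch of how a hook like this is usually attached; `LoggingHook` stands in for the class that owns `after_run` above, and its constructor arguments are assumptions. Only the `hooks=[...]` mechanism is standard `tf.estimator` usage.

# Hypothetical wiring sketch; the hook class name and its arguments are assumed.
logging_hook = LoggingHook(log_every_n_steps=10, warmup_steps=100)

estimator.train(
    input_fn=train_input_fn,  # assumed training input function
    max_steps=total_steps,
    hooks=[logging_hook]      # before_run/after_run wrap each session.run() call
)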
Example #5
    def log_warning(self, message):
        logging.warning("%s%s" % (self.LOGGING_PREFIX, message))