示例#1
0
    def __init__(
        self,
        args,
        dataset_split_name: str,
        is_training: bool,
        name: str,
    ):
        """Store the configuration and log which dataset split is used.

        Args:
            args: Parsed options; ``shuffle`` and ``dataset_path`` are read here.
            dataset_split_name: Split name appended to ``args.dataset_path``.
            is_training: Stored as-is on the instance.
            name: Logger name; it is also the first line that gets logged.
        """
        self.name, self.args = name, args
        self.dataset_split_name, self.is_training = dataset_split_name, is_training

        self.shuffle = args.shuffle

        logger = utils.get_logger(name)
        self.log = logger
        self.timer = utils.Timer(logger)

        root = Path(args.dataset_path)
        split_dir = root / dataset_split_name
        self.dataset_path = root
        self.dataset_path_with_split_name = split_dir

        with utils.format_text("yellow", ["underline"]) as fmt:
            header_lines = (
                f"dataset_path_with_split_name: {split_dir}",
                f"dataset_split_name: {dataset_split_name}",
            )
            logger.info(name)
            for line in header_lines:
                logger.info(fmt(line))
示例#2
0
    def __init__(
        self,
        args,
        dataset_split_name: str,
        is_training: bool,
        name: str,
    ):
        """Store the configuration and log which dataset split is used.

        Args:
            args: Parsed options; ``inference``, ``shuffle`` and
                ``dataset_path`` are read here.
            dataset_split_name: Split name appended to ``args.dataset_path``.
            is_training: Stored as-is on the instance.
            name: Logger name; it is also the first line that gets logged.
        """
        self.name, self.args = name, args
        self.dataset_split_name, self.is_training = dataset_split_name, is_training

        # Labels are required unless running in inference mode
        # (per the original note, args.inference is False by default).
        self.need_label = not args.inference
        self.shuffle = args.shuffle
        self.supported_extensions = [".jpg", ".JPEG", ".png"]

        logger = utils.get_logger(name, None)
        self.log = logger
        self.timer = utils.Timer(logger)

        root = Path(args.dataset_path)
        split_dir = root / dataset_split_name
        self.dataset_path = root
        self.dataset_path_with_split_name = split_dir

        with utils.format_text("yellow", ["underline"]) as fmt:
            header_lines = (
                f"dataset_path_with_split_name: {split_dir}",
                f"dataset_split_name: {dataset_split_name}",
            )
            logger.info(name)
            for line in header_lines:
                logger.info(fmt(line))
示例#3
0
    def inference(self, checkpoint_path):
        """Run inference over the (non-shuffled) dataset and save the result.

        Args:
            checkpoint_path: Checkpoint restored via ``self.ckpt_loader`` and
                passed to ``build_evaluation_step`` / ``build_checkpoint_paths``.
                The path returned by the latter is the one handed to
                ``save_inference_result``.

        Raises:
            AssertionError: if ``args.shuffle`` is set.
            ValueError: if ``args.inference_output`` is not listed in
                ``self._available_inference_output``.
        """
        assert not self.args.shuffle, "Current implementation of `inference` requires non-shuffled dataset"
        _data_num = self.dataset.num_samples
        # Use the same divisor for the check and for the reported count, so the
        # warning always matches the condition that triggers it.
        num_skipped = _data_num % self.args.batch_size
        if num_skipped != 0:
            with format_text("red", attrs=["bold"]) as fmt:
                self.log.warning(
                    fmt(f"Among {_data_num} data, last {num_skipped} items will not"
                        f" be processed during inferential procedure."))

        if self.args.inference_output not in self._available_inference_output:
            raise ValueError(
                f"Inappropriate inference_output type for "
                f"{self.__class__.__name__}: {self.args.inference_output}.\n"
                f"Available outputs are {self._available_inference_output}")

        self.log.info("Inference started")
        self.setup_dataset_iterator()
        self.ckpt_loader.load(checkpoint_path)

        step = self.build_evaluation_step(checkpoint_path)
        # Only the resolved checkpoint path is needed; the glob is unused here.
        _, checkpoint_path = self.build_checkpoint_paths(checkpoint_path)
        self.session.run(tf.local_variables_initializer())

        eval_dict = self.run_inference(step, is_training=False, do_eval=False)

        self.save_inference_result(eval_dict, checkpoint_path)
示例#4
0
文件: trainer.py 项目: zbxzc35/MMNet
    def log_step_message(self,
                         header,
                         losses,
                         speeds,
                         comparative_loss,
                         batch_size,
                         is_training,
                         tag=""):
        """Log a one-line step summary: duration, speed, losses and model size.

        The loss values are colored red when ``comparative_loss`` exceeds the
        value recorded for ``tag`` on the previous call, green otherwise; on the
        first call for a tag the previous value is ``comparative_loss`` itself.
        The whole line is styled blue/bold when training and
        yellow/underline otherwise.

        Args:
            header: Passed to ``build_duration_step_message``.
            losses: Passed to ``build_info_step_message`` with "{:7.4f}".
            speeds: Passed to ``build_info_step_message`` with "{:4.0f}".
            comparative_loss: Value compared against the previous one for ``tag``
                and then stored in ``self.last_loss[tag]``.
            batch_size: Shown next to ``args.width``/``args.height``.
            is_training: Selects the line style.
            tag: Key for ``self.last_loss``; also printed as a prefix.
        """
        previous_loss = self.last_loss.setdefault(tag, comparative_loss)
        loss_color = "red" if previous_loss < comparative_loss else "green"
        self.last_loss[tag] = comparative_loss

        # total_params * 4 — presumably bytes for 4-byte (float32) weights; confirm.
        model_size = hf.format_size(self.model.total_params * 4)
        total_params = hf.format_number(self.model.total_params)

        loss_desc, loss_val = self.build_info_step_message(losses, "{:7.4f}")
        header_desc, header_val = self.build_duration_step_message(header)
        speed_desc, speed_val = self.build_info_step_message(speeds, "{:4.0f}")

        if is_training:
            line_style = {"color": "blue", "attrs": ["bold"]}
        else:
            line_style = {"color": "yellow", "attrs": ["underline"]}

        with utils.format_text(loss_color) as paint_loss:
            colored_losses = paint_loss(loss_val)
            parts = [
                f"[{tag}] {header_desc}: {header_val}\t",
                f"{speed_desc}: {speed_val} ({self.args.width},{self.args.height};{batch_size})\t",
                f"{loss_desc}: {colored_losses} ",
                f"| {model_size} {total_params}",
            ]
            msg = "".join(parts)

            with utils.format_text(**line_style) as paint_line:
                self.log.info(paint_line(msg))
示例#5
0
 def build_iters_from_batch_size(self, num_samples, batch_size):
     """Return how many full batches of ``batch_size`` fit in ``num_samples``.

     Logs a warning when ``num_samples`` is not a multiple of ``batch_size``,
     because the leftover samples are not covered by the returned count.

     Args:
         num_samples: Total number of samples.
         batch_size: Number of samples per batch.

     Returns:
         ``num_samples // batch_size``.
     """
     # Use the arguments (not self.dataset / self.args) so the result reflects
     # the values the caller actually passed.
     iters, num_ignored_samples = divmod(num_samples, batch_size)
     if num_ignored_samples > 0:
         with format_text("red", attrs=["bold"]) as fmt:
             msg = (
                 f"Number of samples cannot be divided by batch_size, "
                 f"so it ignores some data examples in evaluation: "
                 f"{num_samples} % {batch_size} = {num_ignored_samples}"
             )
             self.log.warning(fmt(msg))
     return iters
示例#6
0
文件: base.py 项目: xstarse/TC-ResNet
    def log_metrics(self, step):
        """
        Log every metric string produced by the evaluation metric aggregator.

        Args:
            step: Step whose metrics are being logged; must equal
                ``self.eval_metric_aggregator.step``.

        Raises:
            AssertionError: if ``step`` differs from the aggregator's step,
                i.e. ``evaluate`` was not called for this step first.
        """
        assert step == self.eval_metric_aggregator.step, \
            (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
             f"`evaluate` function should be called before calling this function")

        log_dicts = dict(self.eval_metric_aggregator.get_logs())

        with format_text("green", ["bold"]) as fmt:
            # Only the formatted strings are logged; the metric keys are unused.
            for log_str in log_dicts.values():
                self.log.info(fmt(log_str))
示例#7
0
    def count_samples(
        self,
        samples: List,
    ) -> int:
        """Log and return the size of ``samples``.

        Args:
            samples: Dataset entries to count (e.g. filenames, labels).

        Returns:
            ``len(samples)``.
        """
        n_items = len(samples)
        with utils.format_text("yellow", ["underline"]) as fmt:
            message = fmt(f"number of data: {n_items}")
            self.log.info(message)
        return n_items
示例#8
0
    def load(self, checkpoint_stempath):
        """Assign values from the checkpoint at ``checkpoint_stempath`` to variables.

        On the first call (``self.has_previous_info`` false), the variables in
        ``self.variables_to_restore`` are grouped by checkpoint entry name into
        ``self.grouped_vars``. One placeholder and one assign op are created per
        variable, and the assign ops are grouped into ``self.assign_op``.
        Later calls reuse those placeholders and that op and only build a new
        feed_dict.

        Args:
            checkpoint_stempath: Path handed to
                ``pywrap_tensorflow.NewCheckpointReader``.

        Raises:
            ValueError: if an entry is missing from the checkpoint and
                ``self.ignore_missing_vars`` is false, or if a non-sliced
                variable's shape differs from the stored tensor's shape.
        """
        def get_variable_full_name(var):
            # Sliced variables are looked up by the full name recorded in their
            # slice info; all others by their op name.
            if var._save_slice_info:
                return var._save_slice_info.full_name
            else:
                return var.op.name

        if not self.has_previous_info:
            if isinstance(self.variables_to_restore, (tuple, list)):
                # Several variables may resolve to the same checkpoint entry.
                for var in self.variables_to_restore:
                    ckpt_name = get_variable_full_name(var)
                    self.grouped_vars.setdefault(ckpt_name, []).append(var)

            else:
                # Mapping: checkpoint name -> a variable or a sequence of them.
                for ckpt_name, value in self.variables_to_restore.items():
                    if isinstance(value, (tuple, list)):
                        self.grouped_vars[ckpt_name] = value
                    else:
                        self.grouped_vars[ckpt_name] = [value]

        # Read each checkpoint entry. Create a placeholder variable and
        # add the (possibly sliced) data from the checkpoint to the feed_dict.
        reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_stempath)
        feed_dict = {}
        assign_ops = []
        for ckpt_name in self.grouped_vars:
            if not reader.has_tensor(ckpt_name):
                log_str = f"Checkpoint is missing variable [{ckpt_name}]"
                if self.ignore_missing_vars:
                    self.logger.warning(log_str)
                    continue
                else:
                    raise ValueError(log_str)
            ckpt_value = reader.get_tensor(ckpt_name)

            for var in self.grouped_vars[ckpt_name]:
                placeholder_name = f"placeholder/{var.op.name}"
                if self.has_previous_info:
                    # NOTE(review): an entry skipped as missing on the first call
                    # never got a placeholder, so this lookup raises KeyError if a
                    # later checkpoint does contain it.
                    placeholder_tensor = self.placeholders[placeholder_name]
                else:
                    placeholder_tensor = tf.placeholder(
                        dtype=var.dtype.base_dtype,
                        shape=var.get_shape(),
                        name=placeholder_name)
                    assign_ops.append(var.assign(placeholder_tensor))
                    self.placeholders[placeholder_name] = placeholder_tensor

                if not var._save_slice_info:
                    if var.get_shape() != ckpt_value.shape:
                        raise ValueError(
                            f"Total size of new array must be unchanged for {ckpt_name} "
                            f"lh_shape: [{str(ckpt_value.shape)}], rh_shape: [{str(var.get_shape())}]"
                        )

                    feed_dict[placeholder_tensor] = ckpt_value.reshape(
                        ckpt_value.shape)
                else:
                    # Cut this variable's slice out of the stored tensor using
                    # its recorded offset and shape.
                    slice_info = var._save_slice_info
                    # A tuple is required: NumPy rejects a list of slices as a
                    # multidimensional index.
                    slice_dims = tuple(
                        slice(start, start + size)
                        for start, size in zip(slice_info.var_offset,
                                               slice_info.var_shape))
                    slice_value = ckpt_value[slice_dims]
                    slice_value = slice_value.reshape(slice_info.var_shape)
                    feed_dict[placeholder_tensor] = slice_value

        if not self.has_previous_info:
            self.assign_op = control_flow_ops.group(*assign_ops)

        # NOTE(review): on later calls, self.assign_op still contains assigns for
        # every placeholder created on the first call; an entry skipped here as
        # missing leaves its placeholder unfed — presumably this run then fails.
        # Confirm.
        self.session.run(self.assign_op, feed_dict)

        if feed_dict:
            for key in feed_dict:
                self.logger.info(f"init from checkpoint > {key}")
        else:
            self.logger.info("No init from checkpoint")

        with format_text("cyan", attrs=["bold", "underline"]) as fmt:
            self.logger.info(fmt(f"Restore from {checkpoint_stempath}"))
        self.has_previous_info = True