def __init__(
    self,
    args,
    dataset_split_name: str,
    is_training: bool,
    name: str,
):
    self.name = name
    self.args = args
    self.dataset_split_name = dataset_split_name
    self.is_training = is_training
    self.shuffle = self.args.shuffle

    self.log = utils.get_logger(self.name)
    self.timer = utils.Timer(self.log)

    self.dataset_path = Path(self.args.dataset_path)
    self.dataset_path_with_split_name = self.dataset_path / self.dataset_split_name

    with utils.format_text("yellow", ["underline"]) as fmt:
        self.log.info(self.name)
        self.log.info(fmt(f"dataset_path_with_split_name: {self.dataset_path_with_split_name}"))
        self.log.info(fmt(f"dataset_split_name: {self.dataset_split_name}"))
def __init__(
    self,
    args,
    dataset_split_name: str,
    is_training: bool,
    name: str,
):
    self.name = name
    self.args = args
    self.dataset_split_name = dataset_split_name
    self.is_training = is_training
    # args.inference is False by default
    self.need_label = not self.args.inference
    self.shuffle = self.args.shuffle
    self.supported_extensions = [".jpg", ".JPEG", ".png"]

    self.log = utils.get_logger(self.name, None)
    self.timer = utils.Timer(self.log)

    self.dataset_path = Path(self.args.dataset_path)
    self.dataset_path_with_split_name = self.dataset_path / self.dataset_split_name

    with utils.format_text("yellow", ["underline"]) as fmt:
        self.log.info(self.name)
        self.log.info(fmt(f"dataset_path_with_split_name: {self.dataset_path_with_split_name}"))
        self.log.info(fmt(f"dataset_split_name: {self.dataset_split_name}"))
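A minimal construction sketch for the constructor above. The class name `ImageDataset` and the argparse field values are assumptions for illustration; only `shuffle`, `inference`, and `dataset_path` are actually read here.

# Hypothetical usage -- `ImageDataset` and all values below are illustrative.
from argparse import Namespace

args = Namespace(shuffle=False, inference=False, dataset_path="/path/to/dataset")
dataset = ImageDataset(
    args,
    dataset_split_name="validation",
    is_training=False,
    name="ImageDataset",
)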
def inference(self, checkpoint_path):
    assert not self.args.shuffle, "Current implementation of `inference` requires a non-shuffled dataset"

    _data_num = self.dataset.num_samples
    if _data_num % self.args.batch_size != 0:
        with format_text("red", attrs=["bold"]) as fmt:
            self.log.warning(fmt(
                f"Among {_data_num} data, the last {_data_num % self.args.batch_size} items will not"
                f" be processed during the inference procedure."))

    if self.args.inference_output not in self._available_inference_output:
        raise ValueError(
            f"Inappropriate inference_output type for "
            f"{self.__class__.__name__}: {self.args.inference_output}.\n"
            f"Available outputs are {self._available_inference_output}")

    self.log.info("Inference started")
    self.setup_dataset_iterator()
    self.ckpt_loader.load(checkpoint_path)

    step = self.build_evaluation_step(checkpoint_path)
    checkpoint_glob, checkpoint_path = self.build_checkpoint_paths(checkpoint_path)

    self.session.run(tf.local_variables_initializer())

    eval_dict = self.run_inference(step, is_training=False, do_eval=False)
    self.save_inference_result(eval_dict, checkpoint_path)
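To make the truncation warning above concrete: only full batches are processed, so any remainder is dropped. A self-contained sketch of the same arithmetic, with illustrative values:

num_samples = 1000
batch_size = 64
processed = (num_samples // batch_size) * batch_size  # 960 items run through inference
skipped = num_samples % batch_size                    # last 40 items are not processed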
def log_step_message(self, header, losses, speeds, comparative_loss, batch_size, is_training, tag=""):

    def get_loss_color(old_loss: float, new_loss: float):
        # Red when the loss went up since the last log call, green otherwise.
        if old_loss < new_loss:
            return "red"
        else:
            return "green"

    def get_log_color(is_training: bool):
        if is_training:
            return {"color": "blue", "attrs": ["bold"]}
        else:
            return {"color": "yellow", "attrs": ["underline"]}

    self.last_loss.setdefault(tag, comparative_loss)
    loss_color = get_loss_color(self.last_loss[tag], comparative_loss)
    self.last_loss[tag] = comparative_loss

    model_size = hf.format_size(self.model.total_params * 4)  # assumes 4 bytes (float32) per parameter
    total_params = hf.format_number(self.model.total_params)

    loss_desc, loss_val = self.build_info_step_message(losses, "{:7.4f}")
    header_desc, header_val = self.build_duration_step_message(header)
    speed_desc, speed_val = self.build_info_step_message(speeds, "{:4.0f}")

    with utils.format_text(loss_color) as fmt:
        loss_val_colored = fmt(loss_val)
        msg = (
            f"[{tag}] {header_desc}: {header_val}\t"
            f"{speed_desc}: {speed_val} ({self.args.width},{self.args.height};{batch_size})\t"
            f"{loss_desc}: {loss_val_colored} "
            f"| {model_size} {total_params}")

    with utils.format_text(**get_log_color(is_training)) as fmt:
        self.log.info(fmt(msg))
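The colour convention above encodes loss movement between consecutive log calls. A standalone sketch of the same rule, with hypothetical values:

def get_loss_color(old_loss: float, new_loss: float) -> str:
    # Mirrors the nested helper above: red when the loss rose, green otherwise.
    return "red" if old_loss < new_loss else "green"

assert get_loss_color(0.50, 0.45) == "green"  # loss improved
assert get_loss_color(0.45, 0.52) == "red"    # loss regressed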
def build_iters_from_batch_size(self, num_samples, batch_size):
    iters = num_samples // batch_size
    num_ignored_samples = num_samples % batch_size
    if num_ignored_samples > 0:
        with format_text("red", attrs=["bold"]) as fmt:
            msg = (
                f"Number of samples is not divisible by batch_size, "
                f"so some data examples are ignored in evaluation: "
                f"{num_samples} % {batch_size} = {num_ignored_samples}")
            self.log.warning(fmt(msg))
    return iters
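A worked instance of this computation, with illustrative sizes:

num_samples, batch_size = 10_000, 64
iters = num_samples // batch_size    # 156 evaluation iterations
ignored = num_samples % batch_size   # 16 samples silently skipped
assert iters * batch_size + ignored == num_samples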
def log_metrics(self, step):
    """Log metrics that have been evaluated."""
    assert step == self.eval_metric_aggregator.step, (
        f"step: {step} is different from the aggregator's step: {self.eval_metric_aggregator.step}. "
        f"`evaluate` should be called before calling this function.")

    log_dicts = dict()
    log_dicts.update(self.eval_metric_aggregator.get_logs())

    with format_text("green", ["bold"]) as fmt:
        for metric_key, log_str in log_dicts.items():
            self.log.info(fmt(log_str))
def count_samples(
    self,
    samples: List,
) -> int:
    """Count the number of samples in the dataset.

    Args:
        samples: List of samples (e.g. filenames, labels).

    Returns:
        Number of samples.
    """
    num_samples = len(samples)
    with utils.format_text("yellow", ["underline"]) as fmt:
        self.log.info(fmt(f"number of data: {num_samples}"))
    return num_samples
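A hypothetical call site; `dataset` and the filenames are illustrative only:

samples = ["img_001.jpg", "img_002.png", "img_003.JPEG"]
assert dataset.count_samples(samples) == 3  # also logs "number of data: 3"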
def load(self, checkpoint_stempath):

    def get_variable_full_name(var):
        # Partitioned variables record their full name in the save-slice info.
        if var._save_slice_info:
            return var._save_slice_info.full_name
        else:
            return var.op.name

    if not self.has_previous_info:
        if isinstance(self.variables_to_restore, (tuple, list)):
            for var in self.variables_to_restore:
                ckpt_name = get_variable_full_name(var)
                if ckpt_name not in self.grouped_vars:
                    self.grouped_vars[ckpt_name] = []
                self.grouped_vars[ckpt_name].append(var)
        else:
            for ckpt_name, value in self.variables_to_restore.items():
                if isinstance(value, (tuple, list)):
                    self.grouped_vars[ckpt_name] = value
                else:
                    self.grouped_vars[ckpt_name] = [value]

    # Read each checkpoint entry. Create a placeholder variable and
    # add the (possibly sliced) data from the checkpoint to the feed_dict.
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_stempath)
    feed_dict = {}
    assign_ops = []

    for ckpt_name in self.grouped_vars:
        if not reader.has_tensor(ckpt_name):
            log_str = f"Checkpoint is missing variable [{ckpt_name}]"
            if self.ignore_missing_vars:
                self.logger.warning(log_str)
                continue
            else:
                raise ValueError(log_str)

        ckpt_value = reader.get_tensor(ckpt_name)
        for var in self.grouped_vars[ckpt_name]:
            placeholder_name = f"placeholder/{var.op.name}"
            if self.has_previous_info:
                placeholder_tensor = self.placeholders[placeholder_name]
            else:
                # Only build new graph nodes on the first load; later loads
                # reuse the cached placeholders and assign op.
                placeholder_tensor = tf.placeholder(
                    dtype=var.dtype.base_dtype,
                    shape=var.get_shape(),
                    name=placeholder_name)
                assign_ops.append(var.assign(placeholder_tensor))
                self.placeholders[placeholder_name] = placeholder_tensor

            if not var._save_slice_info:
                if var.get_shape() != ckpt_value.shape:
                    raise ValueError(
                        f"Total size of new array must be unchanged for {ckpt_name} "
                        f"lh_shape: [{str(ckpt_value.shape)}], rh_shape: [{str(var.get_shape())}]")
                feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
            else:
                # Feed only the slice of the checkpoint tensor that this
                # partitioned variable owns.
                slice_dims = zip(var._save_slice_info.var_offset,
                                 var._save_slice_info.var_shape)
                slice_dims = [(start, start + size) for (start, size) in slice_dims]
                slice_dims = [slice(*x) for x in slice_dims]
                slice_value = ckpt_value[tuple(slice_dims)]
                slice_value = slice_value.reshape(var._save_slice_info.var_shape)
                feed_dict[placeholder_tensor] = slice_value

    if not self.has_previous_info:
        self.assign_op = control_flow_ops.group(*assign_ops)

    self.session.run(self.assign_op, feed_dict)
    if feed_dict:
        for key in feed_dict.keys():
            self.logger.info(f"init from checkpoint > {key}")
    else:
        self.logger.info("No init from checkpoint")
    with format_text("cyan", attrs=["bold", "underline"]) as fmt:
        self.logger.info(fmt(f"Restore from {checkpoint_stempath}"))

    self.has_previous_info = True
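A usage sketch, assuming the enclosing class is the object stored as `self.ckpt_loader` in `inference` above; the checkpoint stem paths are illustrative. The second call exercises the `has_previous_info` fast path, reusing the cached placeholders and grouped assign op instead of rebuilding graph nodes:

# Both paths are hypothetical checkpoint stems (no extension).
ckpt_loader.load("checkpoints/model.ckpt-10000")  # builds placeholders + assign op
ckpt_loader.load("checkpoints/model.ckpt-20000")  # reuses them via has_previous_info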