def get_scaffold(self, mode, global_step=None, iter_initializer=None):
  """Get training scaffold."""
  init_op = tf.global_variables_initializer()
  if iter_initializer is None:
    local_init_op = tf.tables_initializer()
  else:
    local_init_op = tf.group(tf.tables_initializer(), iter_initializer)
  saver = self.get_saver(global_step)
  scaffold = tf.train.Scaffold(
      saver=saver, init_op=init_op, local_init_op=local_init_op)
  return scaffold

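# A minimal usage sketch (not part of the original solver) of how the scaffold
# built above is typically consumed in TF 1.x: tf.train.MonitoredTrainingSession
# runs init_op / local_init_op and handles checkpoint restore/save. The names
# `solver`, `checkpoint_dir`, `train_op` and `train_iterator` are hypothetical
# placeholders; `tf` and `utils` are assumed to come from this module's imports.
def _example_scaffold_usage(solver, checkpoint_dir, train_op, train_iterator):
  scaffold = solver.get_scaffold(
      utils.TRAIN, iter_initializer=train_iterator.initializer)
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=checkpoint_dir, scaffold=scaffold) as sess:
    while not sess.should_stop():
      sess.run(train_op)  # one optimization step per loop iteration
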
def infer(self, **kwargs):  # pylint: disable=arguments-differ, unused-argument
  """Make an inference."""
  inputs = self.build_inputs(utils.INFER)
  self.build()
  self.session.run(tf.global_variables_initializer())
  self.session.run(tf.tables_initializer())
  self.session.run(inputs.iterator.initializer)
  infer_data_size = self.config['data']['infer_data_size']
  batch_size = self.config['data']['task']['batch_size']
  steps = int(math.ceil(infer_data_size / batch_size))
  weights_ckpt_dir = tf.train.latest_checkpoint(self.checkpoint_dir)
  self.model.load_weights(weights_ckpt_dir)
  logits = self.model.predict(inputs.input_x_dict, steps=steps)
  preds = np.argmax(logits, axis=-1)
  save_infer_res(self.config, logits, preds)

def export_model(self):
  """Export a model to tensorflow SavedModel."""
  inputs = self.build_inputs(utils.INFER)
  self.build()
  logits = self.model(inputs.input_x_dict)
  score = tf.nn.softmax(logits)
  self.session.run(tf.global_variables_initializer())
  self.session.run(tf.tables_initializer())
  self.session.run(inputs.iterator.initializer)
  weights_ckpt_dir = tf.train.latest_checkpoint(self.checkpoint_dir)
  self.model.load_weights(weights_ckpt_dir)
  output_dict = {"score": score}
  to_saved_model(self.config, self.session, inputs.input_x_dict, output_dict)

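# A hedged sketch (an assumption, not part of the original code) of reading the
# SavedModel written by to_saved_model above back with the stock TF 1.x loader.
# `export_dir` is a placeholder for the export path, and the SERVING tag is
# assumed to be the tag used at export time.
def _example_load_saved_model(export_dir):
  graph = tf.Graph()
  with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    # Input and output tensors (e.g. the "score" output registered in
    # output_dict above) can now be looked up in `graph` by name and run
    # through `sess`.
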
def eval(self):
  """Evaluate the model."""
  inputs = self.build_inputs(utils.EVAL)
  self.build()
  self.session.run(tf.global_variables_initializer())
  self.session.run(tf.tables_initializer())
  self.session.run(inputs.iterator.initializer)
  eval_data_size = self.config['data']['eval_data_size']
  batch_size = self.config['data']['task']['batch_size']
  steps = int(math.ceil(eval_data_size / batch_size))
  weights_ckpt_dir = tf.train.latest_checkpoint(self.checkpoint_dir)
  self.model.load_weights(weights_ckpt_dir)
  results = self.model.evaluate(
      inputs.input_x_dict, inputs.input_y_dict["input_y"], steps=steps)
  for metric, res in zip(self.model.metrics_names, results):
    print("{}: {}".format(metric, res))

def train_core(self, train_inputs, eval_inputs=None):
  """Core part of training."""
  self.build()
  self.session.run(tf.global_variables_initializer())
  self.session.run(tf.tables_initializer())
  self.session.run(train_inputs.iterator.initializer)

  if eval_inputs is not None:
    self.session.run(eval_inputs.iterator.initializer)
    validation_data = (eval_inputs.input_x_dict,
                       eval_inputs.input_y_dict["input_y"])
    eval_data_size = self.config['data']['eval_data_size']
    batch_size = self.config['data']['task']['batch_size']
    validation_steps = int(eval_data_size / batch_size)
  else:
    validation_data = None
    validation_steps = None

  train_data_size = self.config['data']['train_data_size']
  num_epochs = self.config['solver']['optimizer']['epochs']
  batch_size = self.config['data']['task']['batch_size']
  num_batch_per_epoch = int(math.ceil(train_data_size / batch_size))

  callbacks = [
      tf.keras.callbacks.TensorBoard(
          os.path.join(self.model_path, "logs"),
          histogram_freq=0,
          write_graph=True,
          write_grads=True,
          write_images=True),
      tf.keras.callbacks.ModelCheckpoint(
          os.path.join(self.checkpoint_dir, "weights.{epoch:02d}"),
          save_weights_only=True,
          save_best_only=True)
  ]

  self.model.fit(
      train_inputs.input_x_dict,
      train_inputs.input_y_dict["input_y"],
      callbacks=callbacks,
      epochs=num_epochs,
      steps_per_epoch=num_batch_per_epoch,
      validation_data=validation_data,
      validation_steps=validation_steps)

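# The solver methods above and below read only a handful of config keys. The
# dictionary sketches those keys with illustrative values; it shows the assumed
# shape of the config only and is not the project's shipped configuration.
_EXAMPLE_CONFIG = {
    "data": {
        "train_data_size": 10000,
        "eval_data_size": 1000,
        "infer_data_size": 1000,
        "task": {
            "batch_size": 32
        },
    },
    "solver": {
        "optimizer": {
            "epochs": 5
        },
    },
}
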
def eval_or_infer_core(self, model, mode):  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
  """The core part of evaluation."""
  self.do_eval = bool(mode == utils.EVAL or not self.infer_no_label)
  self.is_multi_output = bool(isinstance(model.preds, (tuple, list)))
  if self.is_multi_output:
    self.output_num = len(model.preds)
  model_path = self.get_model_path(mode)
  if model_path is None:
    logging.warning("model_path is None!")
    return

  with model.sess.graph.as_default():
    model.saver.restore(model.sess, save_path=model_path)
    if self.first_eval:
      model.sess.run(tf.tables_initializer())
      self.first_eval = False
    model.sess.run(model.iterator.initializer)

    # Evaluating loop.
    data_size = self.config["data"]['{}_data_size'.format(mode)]
    num_batch_every_epoch = int(math.ceil(data_size / self.batch_size))

    all_fetch_vals = []

    logging.info("Total eval data size: {}, "
                 "batch num per epoch: {}".format(data_size,
                                                  num_batch_every_epoch))
    for i in range(num_batch_every_epoch):
      if self.do_eval:
        if self.is_multi_output:
          fetch_ops = model.loss + list(model.logits) + list(
              model.preds) + list(model.y_ground_truth)
        else:
          fetch_ops = [
              model.loss, model.logits, model.preds, model.y_ground_truth
          ]
      else:
        fetch_ops = [model.logits, model.preds]
      logging.debug("fetch_ops: {}".format(fetch_ops))
      fetch_vals = model.sess.run(fetch_ops)

      end_id = (i + 1) * self.batch_size
      if data_size < end_id:
        # Last batch: keep only the rows that belong to this epoch.
        logging.debug("data_size: {}, end_id: {}".format(data_size, end_id))
        act_end_id = self.batch_size - end_id + data_size
        new_fetch_vals = []
        for fetch_val in fetch_vals:
          if np.isscalar(fetch_val):
            new_fetch_vals.append(fetch_val)
          else:
            new_fetch_vals.append(fetch_val[:act_end_id])
      else:
        new_fetch_vals = fetch_vals

      all_fetch_vals.append(new_fetch_vals)

      if i % self.print_every == 0 or i == num_batch_every_epoch - 1:
        logging.info("Evaluation progress: [ {:.2%} ]".format(
            i / max(num_batch_every_epoch - 1, 1)))

    all_fetch_nps = []
    for one_fetch_vals in zip(*all_fetch_vals):
      if len(np.shape(one_fetch_vals[0])) <= 0:  # pylint: disable=len-as-condition
        one_fetch_np = one_fetch_vals
      else:
        one_fetch_np = np.concatenate(one_fetch_vals, axis=0)
      all_fetch_nps.append(one_fetch_np)

    # Reshape for multi-output: regroup the flat fetch list into
    # output_num-sized chunks (losses, logits, preds, labels).
    if self.is_multi_output:
      logging.debug("all_fetch_nps before reshape: {}".format(
          len(all_fetch_nps)))
      new_all_fetch_nps = []
      sub_fetch_nps = []
      for one_fetch_np in all_fetch_nps:
        sub_fetch_nps.append(one_fetch_np)
        if len(sub_fetch_nps) == self.output_num:
          new_all_fetch_nps.append(sub_fetch_nps)
          sub_fetch_nps = []
      logging.debug("new_all_fetch_nps after reshape: {}".format(
          len(new_all_fetch_nps)))
    else:
      new_all_fetch_nps = all_fetch_nps

    if self.do_eval:
      _, _, preds_val, y_ground_truth_val = new_all_fetch_nps
      run_metrics(self.config, preds_val, y_ground_truth_val, mode)

    if mode == utils.INFER:
      if self.do_eval:
        _, logits_val, preds_val, _ = new_all_fetch_nps
      else:
        logits_val, preds_val = new_all_fetch_nps
      postproc_fn = self.postproc_fn()
      logging.info(postproc_fn)
      if isinstance(postproc_fn, list):
        for i, one_postproc_fn in enumerate(postproc_fn):
          predictions = {
              "logits": logits_val[i],
              "preds": preds_val[i],
              "output_index": i
          }
          one_postproc_fn(predictions, log_verbose=False)
      else:
        predictions = {
            "logits": logits_val,
            "preds": preds_val,
            "output_index": None
        }
        postproc_fn(predictions, log_verbose=False)

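# Worked example (illustration only, not original code) of the last-batch
# trimming arithmetic used above and in the raw-solver variant below: with
# data_size=105 and batch_size=10, the final fetched batch may still hold
# batch_size rows, but only 5 of them belong to the epoch.
def _example_tail_trim(data_size=105, batch_size=10):
  i = data_size // batch_size                   # index of the final, partial batch (10)
  end_id = (i + 1) * batch_size                 # 110, which exceeds data_size
  act_end_id = batch_size - end_id + data_size  # 10 - 110 + 105 = 5 -> keep fetch_val[:5]
  return act_end_id
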
def eval_or_infer_core(self, model, mode):  # pylint: disable=too-many-locals, too-many-branches
  """The core part of evaluation."""
  model_path = self.get_model_path(mode)
  if model_path is None:
    logging.warning("model_path is None!")
    return

  with model.sess.graph.as_default():
    model.saver.restore(model.sess, save_path=model_path)
    if self.first_eval:
      model.sess.run(tf.tables_initializer())
      self.first_eval = False
    model.sess.run(model.iterator.initializer)

    # Evaluating loop.
    total_loss = 0.0
    data_size = self.config["data"]['{}_data_size'.format(mode)]
    num_batch_every_epoch = int(math.ceil(data_size / self.batch_size))
    y_ground_truth = []
    y_preds = []

    for i in range(num_batch_every_epoch):
      if mode == utils.EVAL:
        loss_val, \
        batch_preds, \
        batch_y_ground_truth = model.sess.run(
            [model.loss, model.preds, model.y_ground_truth])
      elif not self.infer_no_label:
        batch_preds, \
        batch_y_ground_truth = model.sess.run(
            [model.preds, model.y_ground_truth])
      else:
        batch_preds = model.sess.run([model.preds])
        batch_preds = batch_preds[0]

      if mode == utils.EVAL:
        total_loss += loss_val
        y_preds.append([preds for preds in batch_preds])
      else:
        end_id = (i + 1) * self.batch_size
        if data_size < end_id:
          # Last batch: keep only the rows that belong to this epoch.
          act_end_id = self.batch_size - end_id + data_size
          batch_preds = batch_preds[:act_end_id]
          if not self.infer_no_label:
            batch_y_ground_truth = batch_y_ground_truth[:act_end_id]
        y_preds.extend([preds for preds in batch_preds])

      if not self.infer_no_label:
        y_ground_truth.extend(
            [ground_truth for ground_truth in batch_y_ground_truth])

      if i % 10 == 0 or i == num_batch_every_epoch - 1:
        logging.info("Evaluation progress: [ {:.2%} ]".format(
            i / max(num_batch_every_epoch - 1, 1)))

    if mode == utils.EVAL:
      logging.info("Evaluation Average Loss: {:.6}".format(total_loss /
                                                           len(y_preds)))
    else:
      predictions = {"preds": y_preds}
      self.postproc_fn()(predictions, log_verbose=False)

    if not self.infer_no_label:
      metcs = metrics.get_metrics(
          config=self.config, y_pred=y_preds, y_true=y_ground_truth)
      logging.info("Evaluation on %s:" % mode)
      # Sort the metric names so the logging order is deterministic.
      for key in sorted(metcs.keys()):
        logging.info(key + ":" + str(metcs[key]))