def prepare(self, session: tf.Session, file_writer: tf.summary.FileWriter, restore_ckpt_path=None):
    """Bind the model to ``session`` and build its train / validation / prediction callables.

    Args:
        session: Session the callables are created from; also kept as
            ``self.session``.
        file_writer: Summary writer, kept as ``self.file_writer`` (it is not
            used inside this method itself).
        restore_ckpt_path: When ``None``, variables are initialised by running
            ``self.init_op``. Otherwise the value is passed through
            ``get_model_path`` and the result is restored into ``session``
            with ``self.saver``.

    Attributes set on ``self``:
        call_train: fetches ``[train_op, loss, summary_logger.summary_op]``;
            positional args are the ``x_phs`` values followed by ``y_ph``,
            ``lr`` and ``training_ph`` values.
        call_val: fetches ``[y_pred, loss]``; args are the ``x_phs`` values
            followed by ``y_ph`` and ``training_ph`` values.
        call_pred: fetches ``y_pred``; args are the ``x_phs`` values followed
            by the ``training_ph`` value.
    """
    self.session, self.file_writer = session, file_writer

    # make_callable maps positional call arguments onto feed_list entries in
    # order, so every feed list starts with the model inputs (x_phs).
    # NOTE(review): `lr` is presumably a learning-rate placeholder — confirm.
    train_fetches = [self.train_op, self.loss, self.summary_logger.summary_op]
    train_feeds = [*self.x_phs, self.y_ph, self.lr, self.training_ph]
    self.call_train = session.make_callable(train_fetches, train_feeds)

    val_fetches = [self.y_pred, self.loss]
    val_feeds = [*self.x_phs, self.y_ph, self.training_ph]
    self.call_val = session.make_callable(val_fetches, val_feeds)

    pred_feeds = [*self.x_phs, self.training_ph]
    self.call_pred = session.make_callable(self.y_pred, pred_feeds)

    if restore_ckpt_path is not None:
        # Resume from a saved checkpoint instead of fresh initialisation.
        self.saver.restore(session, get_model_path(restore_ckpt_path))
    else:
        session.run(self.init_op)
def prepare(self, session: tf.Session, restore_ckpt_path):
    """Build the prediction callable on ``session`` and load weights from a checkpoint.

    Args:
        session: Session the callable is created from and restored into.
        restore_ckpt_path: Required; passed through ``get_model_path`` and the
            result is restored with ``self.saver``.

    Sets ``self.call_pred``, which fetches ``self.y_pred``. Its positional
    arguments are the values for ``self.x_phs``, in order. Unlike the training
    variant, no training-flag placeholder is fed here.
    """
    prediction = self.y_pred
    input_feeds = self.x_phs
    self.call_pred = session.make_callable(prediction, input_feeds)

    # Weights always come from the checkpoint; no init_op is run here.
    ckpt_file = get_model_path(restore_ckpt_path)
    self.saver.restore(session, ckpt_file)