Exemple #1
0
    def prepare(self,
                session: tf.Session,
                file_writer: tf.summary.FileWriter,
                restore_ckpt_path=None):
        """Attach this model to ``session`` and get its variables ready.

        Three session callables are compiled with ``session.make_callable``.
        Each one takes its feed values positionally, in the order of its feed
        list:

        * ``self.call_train(*xs, y, lr, training)`` fetches
          ``[train_op, loss, summary_op]``.
        * ``self.call_val(*xs, y, training)`` fetches ``[y_pred, loss]``.
        * ``self.call_pred(*xs, training)`` fetches ``y_pred``.

        Variables are then initialised from scratch or restored from a
        checkpoint.

        Args:
            session: Session used to compile the callables and to
                initialise or restore variables. It is kept as
                ``self.session``.
            file_writer: Kept as ``self.file_writer``. This method does not
                use it.
            restore_ckpt_path: When ``None``, ``self.init_op`` is run.
                Otherwise the value is resolved with ``get_model_path`` and
                passed to ``self.saver.restore``.
        """
        self.session = session
        self.file_writer = file_writer

        # NOTE(review): presumably ``lr`` is the learning-rate placeholder
        # and ``training_ph`` a train/inference mode flag. Confirm against
        # the graph-building code.
        train_fetches = [self.train_op, self.loss, self.summary_logger.summary_op]
        train_feeds = [*self.x_phs, self.y_ph, self.lr, self.training_ph]
        self.call_train = session.make_callable(train_fetches, train_feeds)

        val_fetches = [self.y_pred, self.loss]
        val_feeds = [*self.x_phs, self.y_ph, self.training_ph]
        self.call_val = session.make_callable(val_fetches, val_feeds)

        pred_feeds = [*self.x_phs, self.training_ph]
        self.call_pred = session.make_callable(self.y_pred, pred_feeds)

        if restore_ckpt_path is not None:
            # init_op is skipped on this path, so the checkpoint must supply
            # every variable the callables read. NOTE(review): confirm the
            # saver covers all variables.
            self.saver.restore(session, get_model_path(restore_ckpt_path))
        else:
            session.run(self.init_op)
Exemple #2
0
 def prepare(self, session: tf.Session, restore_ckpt_path):
     """Compile the prediction callable and load weights from a checkpoint.

     ``self.call_pred`` is built with ``session.make_callable``. It fetches
     ``self.y_pred`` and takes one positional feed value per placeholder in
     ``self.x_phs``, in the same order. Variables are then restored through
     ``self.saver`` from the path that ``get_model_path`` returns for
     ``restore_ckpt_path``. There is no fresh-initialisation path here, so a
     checkpoint is always required.

     Args:
         session: Session used to compile the callable and restore variables.
             It is not stored on the instance.
         restore_ckpt_path: Checkpoint identifier, resolved with
             ``get_model_path``.
     """
     predict_fn = session.make_callable(self.y_pred, self.x_phs)
     self.call_pred = predict_fn
     checkpoint_path = get_model_path(restore_ckpt_path)
     self.saver.restore(session, checkpoint_path)