Example #1
0
  def build_and_save_model(self):
    """Build the model and save a checkpoint and its GraphDef into self.logdir.

    A new graph and session are created for the build. When `self.ckpt_path`
    is set, weights are restored through `inference.restore_ckpt`; otherwise
    all global variables are initialized and the model outputs are evaluated
    once on random input. Afterwards the session is saved with
    `tf.train.Saver` under `<logdir>/<model_name>` and the serialized graph
    definition is written to `<logdir>/<model_name>_train.pb`.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      # Build the model on a float32 placeholder named 'input'.
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      outputs = self.build_model(inputs)

      if self.ckpt_path:
        # Load the true weights if available.
        inference.restore_ckpt(sess, self.ckpt_path,
                               self.model_config.moving_average_decay,
                               self.export_ckpt)
      else:
        sess.run(tf.global_variables_initializer())
        # Evaluate the outputs once on random data. The random input is only
        # needed on this path, so it is created here, already in the
        # placeholder's dtype.
        inputs_val = np.random.rand(*self.inputs_shape).astype(np.float32)
        sess.run(outputs, feed_dict={inputs: inputs_val})

      all_saver = tf.train.Saver(save_relative_paths=True)
      all_saver.save(sess, os.path.join(self.logdir, self.model_name))

      tf_graph = os.path.join(self.logdir, self.model_name + '_train.pb')
      with tf.io.gfile.GFile(tf_graph, 'wb') as f:
        f.write(sess.graph_def.SerializeToString())
    def freeze_model(self) -> Tuple[Text, Text]:
        """Freeze the model and export it as a frozen graph and a SavedModel.

        The model is built in a new graph and session. Weights come from
        `self.ckpt_path` when it is set; otherwise `build_and_save_model()`
        is called and the latest checkpoint found in `self.logdir` is
        restored. Two artifacts are then written:

        * `<logdir>/<model_name>_frozen.pb`: the graph with variables
          converted to constants.
        * `<logdir>/savedmodel`: a SavedModel tagged `serve` with a
          `serving_default` predict signature.

        Returns:
            The frozen graph definition produced by
            `tf.graph_util.convert_variables_to_constants`.
            NOTE(review): the `Tuple[Text, Text]` annotation does not match
            this return value; the signature is left untouched here --
            confirm with callers and correct it.
        """
        with tf.Graph().as_default(), tf.Session() as sess:
            inputs = tf.placeholder(tf.float32,
                                    name='input',
                                    shape=self.inputs_shape)
            outputs = self.build_model(inputs)

            if self.ckpt_path:
                # Load the true weights if available.
                inference.restore_ckpt(sess, self.ckpt_path,
                                       self.model_config.moving_average_decay,
                                       self.export_ckpt)
            else:
                # Load random weights if a checkpoint is not available.
                self.build_and_save_model()
                checkpoint = tf.train.latest_checkpoint(self.logdir)
                logging.info('Loading checkpoint: %s', checkpoint)
                saver = tf.train.Saver()
                saver.restore(sess, checkpoint)

            # Export the frozen graph.
            output_node_names = [node.op.name for node in outputs]
            graphdef = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names)

            tf_graph = os.path.join(self.logdir,
                                    self.model_name + '_frozen.pb')
            # Use a context manager so the file is flushed and closed
            # deterministically instead of whenever the handle is collected.
            with tf.io.gfile.GFile(tf_graph, 'wb') as f:
                f.write(graphdef.SerializeToString())

            # Export the saved model.
            # Assumes `outputs` holds 5 class outputs followed by 5 box
            # outputs -- TODO confirm against build_model.
            output_dict = {
                'class_predict_%d' % i: outputs[i]
                for i in range(5)
            }
            output_dict.update(
                {'box_predict_%d' % i: outputs[5 + i]
                 for i in range(5)})
            signature_def_map = {
                'serving_default':
                tf.saved_model.predict_signature_def(
                    {'input': inputs},
                    output_dict,
                )
            }
            output_dir = os.path.join(self.logdir, 'savedmodel')
            b = tf.saved_model.Builder(output_dir)
            b.add_meta_graph_and_variables(sess,
                                           tags=['serve'],
                                           signature_def_map=signature_def_map,
                                           assets_collection=tf.get_collection(
                                               tf.GraphKeys.ASSET_FILEPATHS),
                                           clear_devices=True)
            b.save()
            logging.info('Model saved at %s', output_dir)

        return graphdef
Example #3
0
 def eval_ckpt(self):
   """Build the model graph and restore weights from self.ckpt_path.

   The restore goes through `inference.restore_ckpt`, which also receives
   the configured moving-average decay and `self.export_ckpt`.
   NOTE(review): the visible code returns nothing and runs no evaluation
   step after restoring -- confirm this is the whole method.
   """
   with tf.Graph().as_default(), tf.Session() as sess:
     # Build the model on a float32 placeholder named 'input'.
     inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
     self.build_model(inputs)
     inference.restore_ckpt(sess, self.ckpt_path,
                            self.model_config.moving_average_decay,
                            self.export_ckpt)