Пример #1
0
  def _restore(self, path):
    """Restores this estimator from given path.

    Note: will rebuild the graph and initialize all parameters,
    and will ignore provided model.

    The folder is expected to hold three text files -- `endpoints`,
    `graph.pbtxt` and `saver.pbtxt` -- together with the checkpoint files
    located through `train.latest_checkpoint`.

    Args:
      path: Path to checkpoints and other information.

    Raises:
      ValueError: If the endpoints, graph definition or saver definition
        file is missing from `path`, or if no checkpoint can be located
        there.
    """
    # Currently Saver requires absolute path to work correctly.
    path = os.path.abspath(path)

    self._graph = ops.Graph()
    with self._graph.as_default():
      endpoints_filename = os.path.join(path, 'endpoints')
      if not os.path.exists(endpoints_filename):
        raise ValueError("Restore folder doesn't contain endpoints.")
      with gfile.Open(endpoints_filename) as foutputs:
        # One endpoint name per line. Blank entries (e.g. produced by a
        # trailing newline) are dropped so they neither break the 4-way
        # unpacking below nor get passed on as an empty element name.
        endpoints = [name for name in foutputs.read().split('\n') if name]
      graph_filename = os.path.join(path, 'graph.pbtxt')
      if not os.path.exists(graph_filename):
        raise ValueError("Restore folder doesn't contain graph definition.")
      with gfile.Open(graph_filename) as fgraph:
        graph_def = graph_pb2.GraphDef()
        text_format.Merge(fgraph.read(), graph_def)
        # The endpoints are unpacked in the order: input, output,
        # predictions, loss.
        (self._inp, self._out, self._model_predictions,
         self._model_loss) = importer.import_graph_def(
             graph_def, name='', return_elements=endpoints)
      saver_filename = os.path.join(path, 'saver.pbtxt')
      if not os.path.exists(saver_filename):
        raise ValueError("Restore folder doesn't contain saver definition.")
      with gfile.Open(saver_filename) as fsaver:
        saver_def = train.SaverDef()
        text_format.Merge(fsaver.read(), saver_def)
        self._saver = train.Saver(saver_def=saver_def)

      # Restore trainer
      self._global_step = self._graph.get_tensor_by_name('global_step:0')
      self._train = self._graph.get_operation_by_name('OptimizeLoss/train')

      # Restore summaries.
      self._summaries = self._graph.get_operation_by_name(
          'MergeSummary/MergeSummary')

      # Locate the checkpoint before opening a session, so that a missing
      # checkpoint does not leave a freshly created session behind.
      checkpoint_path = train.latest_checkpoint(path)
      if checkpoint_path is None:
        raise ValueError(
            'Missing checkpoint files in the %s. Please '
            'make sure you have a checkpoint file that describes the '
            'latest checkpoints and that the appropriate checkpoints are '
            'there. If you have moved the folder, you need to manually '
            'update the paths in the checkpoint file.' % path)

      # Restore session.
      if not isinstance(self._config, RunConfig):
        self._config = RunConfig(verbose=self.verbose)
      self._session = session.Session(self._config.master,
                                      config=self._config.tf_config)
      self._saver.restore(self._session, checkpoint_path)
    # Set to be initialized.
    self._initialized = True
Пример #2
0
    def _setup_training(self):
        """Builds a fresh training graph and opens an initialized session.

        Populates the new graph with the input/output nodes, the model,
        the merged summary op, the train op, the variable initializer and
        a saver, then creates the session and runs the initializer.
        """
        # Fall back to a default run configuration when none was supplied.
        if self._config is None:
            self._config = RunConfig(verbose=self.verbose)

        # Start from an empty graph flagged through its "IS_TRAINING"
        # collection.
        self._graph = ops.Graph()
        self._graph.add_to_collection("IS_TRAINING", True)

        with self._graph.as_default():
            random_seed.set_random_seed(self._config.tf_random_seed)
            step_var = variables.Variable(
                0, name="global_step", trainable=False)
            self._global_step = step_var

            # Input and output nodes come from the data feeder.
            inputs, targets = self._data_feeder.input_builder()
            self._inp, self._out = inputs, targets

            # Expose class weights (when given) as a named constant so
            # loss functions can look the tensor up by name.
            if self.class_weight:
                self._class_weight_node = constant_op.constant(
                    self.class_weight, name='class_weight')

            # Histogram summaries are only added for floating-point X / y.
            float_dtypes = (np.float32, np.float64)
            if self._data_feeder.input_dtype in float_dtypes:
                logging_ops.histogram_summary("X", inputs)
            if self._data_feeder.output_dtype in float_dtypes:
                logging_ops.histogram_summary("y", targets)

            # Let the user-supplied model function build predictions and loss.
            predictions, loss = self.model_fn(inputs, targets)
            self._model_predictions, self._model_loss = predictions, loss

            # One op that merges every summary registered so far.
            self._summaries = logging_ops.merge_all_summaries()

            # Learning rate and optimizer may each be given as a callable;
            # resolve them to concrete values before building the train op.
            step_size = self.learning_rate
            opt = self.optimizer
            step_size = (step_size(step_var)
                         if callable(step_size) else step_size)
            opt = opt(step_size) if callable(opt) else opt
            train_op = optimizers.optimize_loss(
                loss,
                step_var,
                learning_rate=step_size,
                optimizer=opt,
                clip_gradients=self.clip_gradients)

            # Bundle the train op with the 'update_ops' collection
            # (e.g. batch_norm_ops) so they run on every training step.
            extra_updates = ops.get_collection('update_ops')
            self._train = control_flow_ops.group(train_op, *extra_updates)

            # Initializer for all variables created above.
            self._initializers = variables.initialize_all_variables()

            # Saver capturing all the nodes created up until now.
            self._saver = train.Saver(
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours))

            # Hand the placeholders to the monitor so it can build its
            # validation feed dict.
            self._monitor.create_val_feed_dict(inputs, targets)

            # Open the session and run the parameter initializers.
            self._session = session.Session(
                self._config.tf_master, config=self._config.tf_config)
            self._session.run(self._initializers)
Пример #3
0
            enqueue_many=True,
            allow_smaller_final_batch=True
        )

    return features, goal


if __name__ == '__main__':
    # Script entry point: parse the job directory, then run the estimator
    # and input-function builders once each, reporting progress on stdout.
    parser = argparse.ArgumentParser()
    # Input Arguments
    parser.add_argument(
        '--job-dir',
        help='GCS location to write checkpoints and export models',
        required=True
    )
    args = parser.parse_args()

    # argparse exposes the `--job-dir` option as the attribute `job_dir`.
    config = RunConfig(model_dir=args.job_dir)

    # Each progress label is written without a newline and flushed
    # immediately so it is visible while the (possibly slow) step runs.
    sys.stdout.write("build_estimator...")
    sys.stdout.flush()
    build_estimator(config)
    sys.stdout.write("done\n")
    sys.stdout.flush()

    sys.stdout.write("generate_input_fn...")
    sys.stdout.flush()
    generate_input_fn()
    sys.stdout.write("done\n")
    sys.stdout.flush()