def _restore(self, path):
  """Restores this estimator from given path.

  Note: will rebuild the graph and initialize all parameters,
  and will ignore provided model.

  Args:
    path: Path to checkpoints and other information.
  """
  # Currently Saver requires absolute path to work correctly.
  path = os.path.abspath(path)

  self._graph = ops.Graph()
  with self._graph.as_default():
    endpoints_filename = os.path.join(path, 'endpoints')
    if not os.path.exists(endpoints_filename):
      raise ValueError("Restore folder doesn't contain endpoints.")
    with gfile.Open(endpoints_filename) as foutputs:
      endpoints = foutputs.read().split('\n')

    graph_filename = os.path.join(path, 'graph.pbtxt')
    if not os.path.exists(graph_filename):
      raise ValueError("Restore folder doesn't contain graph definition.")
    with gfile.Open(graph_filename) as fgraph:
      graph_def = graph_pb2.GraphDef()
      text_format.Merge(fgraph.read(), graph_def)
      (self._inp, self._out,
       self._model_predictions, self._model_loss) = importer.import_graph_def(
           graph_def, name='', return_elements=endpoints)

    saver_filename = os.path.join(path, 'saver.pbtxt')
    if not os.path.exists(saver_filename):
      raise ValueError("Restore folder doesn't contain saver definition.")
    with gfile.Open(saver_filename) as fsaver:
      saver_def = train.SaverDef()
      text_format.Merge(fsaver.read(), saver_def)
      self._saver = train.Saver(saver_def=saver_def)

    # Restore trainer.
    self._global_step = self._graph.get_tensor_by_name('global_step:0')
    self._train = self._graph.get_operation_by_name('OptimizeLoss/train')

    # Restore summaries.
    self._summaries = self._graph.get_operation_by_name(
        'MergeSummary/MergeSummary')

    # Restore session.
    if not isinstance(self._config, RunConfig):
      self._config = RunConfig(verbose=self.verbose)
    self._session = session.Session(self._config.master,
                                    config=self._config.tf_config)

    checkpoint_path = train.latest_checkpoint(path)
    if checkpoint_path is None:
      raise ValueError(
          'Missing checkpoint files in %s. Please make sure the folder '
          'contains a checkpoint file that lists the latest checkpoints '
          'and that those checkpoints are present. If you have moved the '
          'folder, you need to manually update the paths in the '
          'checkpoint file.' % path)
    self._saver.restore(self._session, checkpoint_path)
  # Set to be initialized.
  self._initialized = True
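# A minimal sketch of the checkpoint save/restore round trip that _restore
# relies on, written with core TF 1.x APIs. The variable, the sketch function
# itself and the /tmp/estimator_sketch directory are illustrative only, not
# part of the estimator above.
def _restore_sketch():
  import os
  import tensorflow as tf

  checkpoint_dir = '/tmp/estimator_sketch'  # illustrative path
  if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

  # Build a trivial graph and write a checkpoint.
  with tf.Graph().as_default():
    v = tf.get_variable('v', shape=[1], initializer=tf.zeros_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))

  # Rebuild the graph, locate the latest checkpoint and restore it into a
  # fresh session, mirroring what _restore does after importing graph.pbtxt
  # and saver.pbtxt.
  with tf.Graph().as_default():
    v = tf.get_variable('v', shape=[1])
    saver = tf.train.Saver()
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    with tf.Session() as sess:
      saver.restore(sess, checkpoint_path)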
def _setup_training(self):
  """Sets up graph, model and trainer."""
  # Create config if not given.
  if self._config is None:
    self._config = RunConfig(verbose=self.verbose)
  # Create new graph.
  self._graph = ops.Graph()
  self._graph.add_to_collection("IS_TRAINING", True)
  with self._graph.as_default():
    random_seed.set_random_seed(self._config.tf_random_seed)
    self._global_step = variables.Variable(0, name="global_step",
                                           trainable=False)

    # Setting up inputs and outputs.
    self._inp, self._out = self._data_feeder.input_builder()

    # If class weights are provided, add them to the graph.
    # Different loss functions can use this tensor by name.
    if self.class_weight:
      self._class_weight_node = constant_op.constant(
          self.class_weight, name='class_weight')

    # Add histograms for X and y if they are floats.
    if self._data_feeder.input_dtype in (np.float32, np.float64):
      logging_ops.histogram_summary("X", self._inp)
    if self._data_feeder.output_dtype in (np.float32, np.float64):
      logging_ops.histogram_summary("y", self._out)

    # Create model's graph.
    self._model_predictions, self._model_loss = self.model_fn(
        self._inp, self._out)

    # Set up a single operator to merge all the summaries.
    self._summaries = logging_ops.merge_all_summaries()

    # Create trainer and augment graph with gradients and optimizer.
    # Additionally creates initialization ops.
    learning_rate = self.learning_rate
    optimizer = self.optimizer
    if callable(learning_rate):
      learning_rate = learning_rate(self._global_step)
    if callable(optimizer):
      optimizer = optimizer(learning_rate)
    self._train = optimizers.optimize_loss(
        self._model_loss, self._global_step, learning_rate=learning_rate,
        optimizer=optimizer, clip_gradients=self.clip_gradients)

    # Update ops during training, e.g. batch_norm_ops.
    self._train = control_flow_ops.group(
        self._train, *ops.get_collection('update_ops'))

    # Get all initializers for all trainable variables.
    self._initializers = variables.initialize_all_variables()

    # Create model's saver capturing all the nodes created up until now.
    self._saver = train.Saver(
        max_to_keep=self._config.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=(
            self._config.keep_checkpoint_every_n_hours))

    # Enable monitor to create validation data dict with appropriate
    # tf placeholders.
    self._monitor.create_val_feed_dict(self._inp, self._out)

    # Create session to run model with.
    self._session = session.Session(self._config.tf_master,
                                    config=self._config.tf_config)

    # Run parameter initializers.
    self._session.run(self._initializers)
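# A minimal sketch of the same setup pattern (graph, global step, model, loss,
# train op, summaries, initializers, session) using only core TF 1.x ops and
# TF 1.x summary names. The linear model below is hypothetical and stands in
# for the estimator's model_fn and data feeder; it is not the implementation
# above.
def _setup_training_sketch():
  import numpy as np
  import tensorflow as tf

  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(42)
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Placeholders stand in for the data feeder's input_builder().
    x = tf.placeholder(tf.float32, [None, 3], name='input')
    y = tf.placeholder(tf.float32, [None, 1], name='output')

    # Model graph and loss (hypothetical linear regression).
    w = tf.get_variable('w', [3, 1])
    predictions = tf.matmul(x, w)
    loss = tf.reduce_mean(tf.square(predictions - y))
    tf.summary.scalar('loss', loss)

    # Trainer, merged summaries and initializers.
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)
    summaries = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

  # Run parameter initializers, then a single training step.
  with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    sess.run(train_op, feed_dict={x: np.ones((4, 3), np.float32),
                                  y: np.zeros((4, 1), np.float32)})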
        enqueue_many=True,
        allow_smaller_final_batch=True
    )
    return features, goal


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input arguments.
    parser.add_argument(
        '--job-dir',
        help='GCS location to write checkpoints and export models',
        required=True
    )
    args = parser.parse_args()
    config = RunConfig(model_dir=args.job_dir)

    sys.stdout.write("build_estimator...")
    sys.stdout.flush()
    build_estimator(config)
    sys.stdout.write("done\n")
    sys.stdout.flush()

    sys.stdout.write("generate_input_fn...")
    sys.stdout.flush()
    generate_input_fn()
    sys.stdout.write("done\n")
    sys.stdout.flush()
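# A minimal sketch of how the batching call whose tail ends the input function
# above is typically written in full. The feature tensors, batch size and the
# sketch function name are illustrative, not the trainer's real pipeline.
def generate_input_fn_sketch(batch_size=2):
    import tensorflow as tf

    # Pre-parsed tensors standing in for the real input pipeline; with
    # enqueue_many=True the first dimension is treated as the example axis.
    tensors = {
        'feature_a': tf.constant([[1.0], [2.0], [3.0]]),
        'goal': tf.constant([[0.0], [1.0], [0.0]]),
    }
    batched = tf.train.batch(
        tensors,
        batch_size=batch_size,
        enqueue_many=True,
        allow_smaller_final_batch=True)
    # Note: dequeuing requires queue runners started inside a session,
    # e.g. via tf.train.start_queue_runners().
    goal = batched.pop('goal')
    return batched, goal

# Typical invocation of this script from a shell (script name and bucket
# path are illustrative):
#   python trainer.py --job-dir gs://my-bucket/my-model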