def _restore(self, path):
    """Restores this estimator from given path.

    Note: will rebuild the graph and initialize all parameters,
    and will ignore provided model.

    Args:
        path: Path to checkpoints and other information.
    """
    # Currently Saver requires absolute path to work correctly.
    path = os.path.abspath(path)

    self._graph = tf.Graph()
    with self._graph.as_default():
        endpoints_filename = os.path.join(path, 'endpoints')
        if not os.path.exists(endpoints_filename):
            raise ValueError("Restore folder doesn't contain endpoints.")
        with gfile.Open(endpoints_filename) as foutputs:
            endpoints = foutputs.read().split('\n')

        graph_filename = os.path.join(path, 'graph.pbtxt')
        if not os.path.exists(graph_filename):
            raise ValueError("Restore folder doesn't contain graph definition.")
        with gfile.Open(graph_filename) as fgraph:
            graph_def = tf.GraphDef()
            text_format.Merge(fgraph.read(), graph_def)
            (self._inp, self._out,
             self._model_predictions,
             self._model_loss) = tf.import_graph_def(
                 graph_def, name='', return_elements=endpoints)

        saver_filename = os.path.join(path, 'saver.pbtxt')
        if not os.path.exists(saver_filename):
            raise ValueError("Restore folder doesn't contain saver definition.")
        with gfile.Open(saver_filename) as fsaver:
            saver_def = tf.train.SaverDef()
            text_format.Merge(fsaver.read(), saver_def)
            self._saver = tf.train.Saver(saver_def=saver_def)

        # Restore trainer.
        self._global_step = self._graph.get_tensor_by_name('global_step:0')
        self._train = self._graph.get_operation_by_name('train')

        # Restore summaries.
        self._summaries = self._graph.get_operation_by_name(
            'MergeSummary/MergeSummary')

        # Restore session.
        if not isinstance(self._config, RunConfig):
            self._config = RunConfig(verbose=self.verbose)
        self._session = tf.Session(
            self._config.tf_master,
            config=self._config.tf_config)

        checkpoint_path = tf.train.latest_checkpoint(path)
        if checkpoint_path is None:
            raise ValueError(
                "Missing checkpoint files in %s. Please make sure the "
                "checkpoint file that describes the latest checkpoints and "
                "the checkpoints themselves are there. If you have moved "
                "the folder, you need to manually update the paths in the "
                "checkpoint file." % path)
        self._saver.restore(self._session, checkpoint_path)

    # Set to be initialized.
    self._initialized = True
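# The reads above imply a fixed on-disk layout for a saved estimator:
# '<path>/endpoints' (newline-separated tensor/op names handed to
# tf.import_graph_def via return_elements), '<path>/graph.pbtxt' (text-format
# GraphDef), '<path>/saver.pbtxt' (text-format SaverDef), plus the checkpoint
# files written by tf.train.Saver. A minimal pre-flight sketch of that layout
# check; _check_restore_folder is a hypothetical helper, not part of the
# estimator's API.
def _check_restore_folder(path):
    import os
    for name in ('endpoints', 'graph.pbtxt', 'saver.pbtxt'):
        if not os.path.exists(os.path.join(path, name)):
            raise ValueError("Restore folder doesn't contain %s." % name)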
def save(self, filename):
    """Saves vocabulary processor into given file.

    Args:
        filename: Path to output file.
    """
    with gfile.Open(filename, 'wb') as f:
        f.write(pickle.dumps(self))
@classmethod
def restore(cls, filename):
    """Restores vocabulary processor from given file.

    Args:
        filename: Path to file to load from.

    Returns:
        VocabularyProcessor object.
    """
    with gfile.Open(filename, 'rb') as f:
        return pickle.loads(f.read())
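# A minimal round-trip sketch for the save/restore pair above. It assumes a
# VocabularyProcessor exposing these methods is in scope and that its
# constructor takes max_document_length (as in skflow's preprocessing
# module); the path is illustrative.
def _vocabulary_processor_roundtrip_example():
    processor = VocabularyProcessor(max_document_length=10)
    list(processor.fit_transform(['some text to index']))  # builds the vocabulary
    processor.save('/tmp/vocab.pkl')
    # Later, possibly in another process:
    restored = VocabularyProcessor.restore('/tmp/vocab.pkl')
    assert isinstance(restored, VocabularyProcessor)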
def load_csv(filename, target_dtype):
    """Loads a CSV file into a Dataset.

    The first row is expected to hold the number of samples, the number of
    features, and optional target names; the target is the last column.

    Args:
        filename: Path to the CSV file.
        target_dtype: dtype of the target column.

    Returns:
        Dataset with data and target attributes.
    """
    with gfile.Open(filename) as csv_file:
        data_file = csv.reader(csv_file)
        header = next(data_file)
        n_samples = int(header[0])
        n_features = int(header[1])
        target_names = np.array(header[2:])
        data = np.empty((n_samples, n_features))
        # Allocate the target with the caller-requested dtype.
        target = np.empty((n_samples,), dtype=target_dtype)
        for i, ir in enumerate(data_file):
            data[i] = np.asarray(ir[:-1], dtype=np.float64)
            target[i] = np.asarray(ir[-1], dtype=target_dtype)
    return Dataset(data=data, target=target)
def extract_labels(filename, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    print('Extracting', filename)
    with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels, num_classes)
        return labels
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    print('Extracting', filename)
    with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
    return data
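# Both extractors rely on a _read32 helper that is not shown in this section.
# A plausible sketch, assuming the standard MNIST/IDX convention of
# big-endian unsigned 32-bit header fields:
def _read32(bytestream):
    # Read four bytes and decode them as one big-endian uint32.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]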
@classmethod
def restore(cls, path, config=None):
    """Restores model from given path.

    Args:
        path: Path to the checkpoints and other model information.
        config: RunConfig object that controls the configurations of the
            session, e.g. num_cores, gpu_memory_fraction, etc. This is
            allowed to be reconfigured.

    Returns:
        Estimator, object of the subclass of TensorFlowEstimator.
    """
    model_def_filename = os.path.join(path, 'model.def')
    if not os.path.exists(model_def_filename):
        raise ValueError("Restore folder doesn't contain model definition.")
    # List of parameters that are allowed to be reconfigured.
    reconfigurable_params = ['_config']
    _config = config
    with gfile.Open(model_def_filename) as fmodel:
        model_def = json.loads(fmodel.read())
        # TensorFlow binding requires parameters to be strings not unicode.
        # Only an issue in Python 2.
        for key, value in model_def.items():
            if isinstance(value, string_types) and not isinstance(value, str):
                model_def[key] = str(value)
            if key in reconfigurable_params:
                new_value = locals()[key]
                if new_value is not None:
                    model_def[key] = new_value
    class_name = model_def.pop('class_name')
    if class_name == 'TensorFlowEstimator':
        custom_estimator = TensorFlowEstimator(model_fn=None, **model_def)
        custom_estimator._restore(path)
        return custom_estimator

    # To avoid cyclical dependencies, import inside the function instead of
    # the beginning of the file.
    from tensorflow.contrib.skflow.python.skflow import estimators
    # Estimator must be one of the defined estimators in the __init__ file.
    estimator = getattr(estimators, class_name)(**model_def)
    estimator._restore(path)
    return estimator
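# A usage sketch for the classmethod above: restore an estimator saved
# earlier, optionally overriding the session configuration. The path is
# illustrative, and the RunConfig keyword follows the num_cores example
# named in the docstring.
def _restore_estimator_example():
    config = RunConfig(num_cores=2)
    estimator = TensorFlowEstimator.restore('/tmp/skflow_model', config=config)
    return estimator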
def load_csv(filename, target_dtype, target_column=-1, has_header=True):
    """Loads a CSV file into a Dataset, with a configurable target column.

    Args:
        filename: Path to the CSV file.
        target_dtype: dtype of the target column.
        target_column: Index of the target column; defaults to the last one.
        has_header: Whether the first row holds the number of samples, the
            number of features, and optional target names.

    Returns:
        Dataset with data and target attributes.
    """
    with gfile.Open(filename) as csv_file:
        data_file = csv.reader(csv_file)
        if has_header:
            header = next(data_file)
            n_samples = int(header[0])
            n_features = int(header[1])
            target_names = np.array(header[2:])
            data = np.empty((n_samples, n_features))
            target = np.empty((n_samples,), dtype=target_dtype)
            for i, ir in enumerate(data_file):
                target[i] = np.asarray(ir.pop(target_column), dtype=target_dtype)
                data[i] = np.asarray(ir, dtype=np.float64)
        else:
            # Without a header the sizes are unknown up front, so collect
            # rows into plain Python lists.
            data, target = [], []
            for ir in data_file:
                target.append(ir.pop(target_column))
                data.append(ir)
    return Dataset(data=data, target=target)
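# A usage sketch for the header-based branch of load_csv above, with an
# illustrative path and an iris-style file whose header row reads
# '150,4,setosa,versicolor,virginica':
def _load_csv_example():
    dataset = load_csv('/tmp/iris.csv', target_dtype=np.int, target_column=-1)
    print(dataset.data.shape)    # (150, 4) for the illustrative file
    print(dataset.target.shape)  # (150,)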
def _write_with_backup(filename, content):
    # Keep the previous version as '<filename>.old' instead of clobbering it.
    if gfile.Exists(filename):
        gfile.Rename(filename, filename + '.old', overwrite=True)
    with gfile.Open(filename, 'w') as f:
        f.write(content)
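# A behavior sketch for _write_with_backup, with illustrative paths: the
# second call moves the first file aside rather than overwriting it.
def _write_with_backup_example():
    _write_with_backup('/tmp/model.def', 'first version')
    _write_with_backup('/tmp/model.def', 'second version')
    # '/tmp/model.def' now holds 'second version';
    # '/tmp/model.def.old' holds 'first version'.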
def testOpen(self):
    with gfile.Open(self.tmp + "test_open", "wb") as f:
        f.write(b"foo")
    with gfile.Open(self.tmp + "test_open") as f:
        result = f.readlines()
    self.assertEqual(["foo"], result)