Example #1
    def dump(self, model_dir):
        """Dumps the options to a file in the model directory.

    Args:
      model_dir: Path to the model directory. The options will be
      dumped into a file in this directory.
    """
        gfile.MakeDirs(model_dir)
        options_dict = {
            "model_class": self.model_class,
            "model_params": self.model_params,
        }

        with gfile.GFile(TrainOptions.path(model_dir), "w") as file:
            file.write(json.dumps(options_dict).encode("utf-8"))
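Example #1 does not show how the dumped options are read back. A minimal counterpart sketch, assuming the same TrainOptions.path() helper and the json/gfile imports used above (an illustration, not the original library's loader):

def load_options(model_dir):
    """Loads the options dict written by dump() above (hypothetical helper)."""
    with gfile.GFile(TrainOptions.path(model_dir), "r") as options_file:
        return json.loads(options_file.read())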
Example #2
def _get_unk_mapping(filename):
    """Reads a file that specifies a mapping from source to target tokens.
  The file must contain lines of the form <source>\t<target>.

  Args:
    filename: path to the mapping file

  Returns:
    A dictionary that maps from source -> target tokens.
  """
    with gfile.GFile(filename, "r") as mapping_file:
        lines = mapping_file.readlines()
        mapping = dict([_.split("\t")[0:2] for _ in lines])
        mapping = {k.strip(): v.strip() for k, v in mapping.items()}
    return mapping
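A quick usage sketch for _get_unk_mapping; the file path and contents below are hypothetical, and TensorFlow's gfile is assumed to be importable:

# Write a two-line tab-separated mapping file and read it back.
with gfile.GFile("/tmp/unk_mapping.tsv", "w") as f:
    f.write("haus\thouse\n")
    f.write("katze\tcat\n")

mapping = _get_unk_mapping("/tmp/unk_mapping.tsv")
assert mapping == {"haus": "house", "katze": "cat"}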
Example #3
def download_mldata(dataset, save_dir):
    # Use scikit-learn to grab datasets and save them to save_dir.
    filename = os.path.join(save_dir, dataset[1] + '.pkl')

    if not gfile.Exists(save_dir):
        gfile.MkDir(save_dir)
    if not gfile.Exists(filename):
        if dataset[0][-3:] == 'csv':
            data = get_csv_data(dataset[0])
        elif dataset[0] == 'breast_cancer':
            data = load_breast_cancer()
        elif dataset[0] == 'iris':
            data = load_iris()
        elif dataset[0] == 'newsgroup':
            # Removing header information to make sure that no newsgroup identifying
            # information is included in data
            data = fetch_20newsgroups_vectorized(
                subset='all', remove=('headers',))
            tfidf = TfidfTransformer(norm='l2')
            X = tfidf.fit_transform(data.data)
            data.data = X
        elif dataset[0] == 'rcv1':
            sklearn.datasets.rcv1.URL = (
                'http://www.ai.mit.edu/projects/jmlr/papers/'
                'volume5/lewis04a/a13-vector-files/lyrl2004_vectors')
            sklearn.datasets.rcv1.URL_topics = (
                'http://www.ai.mit.edu/projects/jmlr/papers/'
                'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz')
            data = sklearn.datasets.fetch_rcv1(data_home='/tmp')
        elif dataset[0] == 'wikipedia_attack':
            data = get_wikipedia_talk_data()
        elif dataset[0] == 'cifar10':
            data = get_cifar10()
        elif 'keras' in dataset[0]:
            data = get_keras_data(dataset[0])
        else:
            try:
                data = fetch_mldata(dataset[0])
            except Exception:
                raise Exception('ERROR: failed to fetch data from mldata.org')
        X = data.data
        y = data.target
        if X.shape[0] != y.shape[0]:
            X = np.transpose(X)
        assert X.shape[0] == y.shape[0]

        data = {'data': X, 'target': y}
        with gfile.GFile(filename, 'wb') as f:
            pickle.dump(data, f)
Example #4
 def _load_img_feature_pickle(self):
     for filepath in self._all_img_feature_filepaths:
         logging.info("loading %s" % filepath)
         with gfile.GFile(filepath, 'rb') as f:
             filenames, features = pickle.load(f)
             self._img_feature_filenames += filenames
             self._img_feature_data.append(features)
     self._img_feature_data = np.vstack(self._img_feature_data)
     origin_shape = self._img_feature_data.shape
     self._img_feature_data = np.reshape(self._img_feature_data,
                                         (origin_shape[0], origin_shape[3]))
     self._img_feature_filenames = np.asarray(self._img_feature_filenames)
     print(self._img_feature_data.shape)
     print(self._img_feature_filenames.shape)
     if not self._deterministic:
         self._random_shuffle()
Example #5
    def load(model_dir):
        """ Loads model configurations.

        Args:
            model_dir: A string, the directory.

        Returns: A dict.
        """
        model_config_filename = os.path.join(
            model_dir, Constants.MODEL_CONFIG_YAML_FILENAME)
        if not gfile.Exists(model_config_filename):
            raise OSError("Failed to find model config file: %s" %
                          model_config_filename)
        with gfile.GFile(model_config_filename, "r") as file:
            model_configs = yaml.load(file)
        return model_configs
Example #6
 def _load_img_feature_pickle(self):
     """Loads img feature data from pickle."""
     for filepath in self._all_img_feature_filepaths:
         logging.info("loading %s" % filepath)
         with gfile.GFile(filepath, 'rb') as f:
             filenames, features = pickle.load(f)
             self._img_feature_filenames += filenames
             self._img_feature_data.append(features)
     # [(1000, 1, 1, 2048), (1000, 1, 1, 2048)] -> (2000, 1, 1, 2048)
     self._img_feature_data = np.vstack(self._img_feature_data)
     origin_shape = self._img_feature_data.shape
     self._img_feature_data = np.reshape(self._img_feature_data,
                                         (origin_shape[0], origin_shape[3]))
     self._img_feature_filenames = np.asarray(self._img_feature_filenames)
     print(self._img_feature_data.shape)
     print(self._img_feature_filenames.shape)
Example #7
    def load(self, filename):
        """Loads the vocabulary from the file.

        Args:
            filename (str): Path to the vocabulary file.

        Returns:
            A tuple of TF and python mapping tables between word string and
            index, (:attr:`id_to_token_map`, :attr:`token_to_id_map`,
            :attr:`id_to_token_map_py`, :attr:`token_to_id_map_py`), where
            :attr:`id_to_token_map` and :attr:`token_to_id_map` are
            TF :tf_main:`HashTable <contrib/lookup/HashTable>` instances,
            and :attr:`id_to_token_map_py` and
            :attr:`token_to_id_map_py` are python `defaultdict` instances.
        """
        with gfile.GFile(filename) as vocab_file:
            # Converts to 'unicode' (Python 2) or 'str' (Python 3)
            vocab = list(tf.compat.as_text(line.strip()) for line in vocab_file)


        # Index used for unknown tokens; this assumes the unknown token
        # already appears in the vocabulary file read above.
        unk_token_idx = vocab.index(self._unk_token)
        vocab_size = len(vocab)
        vocab_idx = np.arange(vocab_size)

        # Creates TF maps
        id_to_token_map = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                vocab_idx, vocab, key_dtype=tf.int64, value_dtype=tf.string),
            self._unk_token)

        token_to_id_map = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                vocab, vocab_idx, key_dtype=tf.string, value_dtype=tf.int64),
            unk_token_idx)

        # Creates python maps to interface with python code
        id_to_token_map_py = _make_defaultdict(
            vocab_idx, vocab, self._unk_token)
        token_to_id_map_py = _make_defaultdict(
            vocab, vocab_idx, unk_token_idx)

        logger.info("vocab size: {}/{}".format(
            len(token_to_id_map_py), len(id_to_token_map_py)))

        return id_to_token_map, token_to_id_map, \
               id_to_token_map_py, token_to_id_map_py
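The helper _make_defaultdict is called above but not shown. A plausible implementation, inferred from how it is used here (an assumption, not necessarily the original):

import collections

def _make_defaultdict(keys, values, default_value):
    """Maps keys[i] -> values[i] and returns default_value for missing keys."""
    dict_ = collections.defaultdict(lambda: default_value)
    for k, v in zip(keys, values):
        dict_[k] = v
    return dict_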
Example #8
def main(_argv):
  """The entrypoint for the script"""

  # Parse YAML FLAGS
  FLAGS.hooks = _maybe_load_yaml(FLAGS.hooks)
  FLAGS.metrics = _maybe_load_yaml(FLAGS.metrics)
  FLAGS.model_params = _maybe_load_yaml(FLAGS.model_params)
  FLAGS.input_pipeline_train = _maybe_load_yaml(FLAGS.input_pipeline_train)
  FLAGS.input_pipeline_dev = _maybe_load_yaml(FLAGS.input_pipeline_dev)

  # Load flags from config file
  final_config = {}
  if FLAGS.config_paths:
    for config_path in FLAGS.config_paths.split(","):
      config_path = config_path.strip()
      if not config_path:
        continue
      config_path = os.path.abspath(config_path)
      tf.logging.info("Loading config from %s", config_path)
      with gfile.GFile(config_path) as config_file:
        config_flags = yaml.load(config_file)
        final_config = _deep_merge_dict(final_config, config_flags)

  tf.logging.info("Final Config:\n%s", yaml.dump(final_config))

  # Merge flags with config values
  for flag_key, flag_value in final_config.items():
    if hasattr(FLAGS, flag_key) and isinstance(getattr(FLAGS, flag_key), dict):
      merged_value = _deep_merge_dict(flag_value, getattr(FLAGS, flag_key))
      setattr(FLAGS, flag_key, merged_value)
    else:
      setattr(FLAGS, flag_key, flag_value)

  if FLAGS.save_checkpoints_secs is None \
    and FLAGS.save_checkpoints_steps is None:
    FLAGS.save_checkpoints_secs = 600
    tf.logging.info("Setting save_checkpoints_secs to %d",
                    FLAGS.save_checkpoints_secs)

  if not FLAGS.output_dir:
    FLAGS.output_dir = tempfile.mkdtemp()

  learn_runner.run(
      experiment_fn=create_experiment,
      output_dir=FLAGS.output_dir,
      schedule=FLAGS.schedule)
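_maybe_load_yaml and _deep_merge_dict are referenced above but not defined in the snippet. Minimal sketches consistent with how they are called (assumptions, not the original implementations):

def _maybe_load_yaml(item):
    """Parses `item` as YAML if it is a string; otherwise returns it unchanged."""
    if isinstance(item, str):
        return yaml.load(item)
    return item

def _deep_merge_dict(dict_x, dict_y):
    """Recursively merges dict_y into dict_x; values from dict_y take precedence."""
    merged = dict(dict_x)
    for key, value in dict_y.items():
        if (key in merged and isinstance(merged[key], dict)
                and isinstance(value, dict)):
            merged[key] = _deep_merge_dict(merged[key], value)
        else:
            merged[key] = value
    return merged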
Example #9
def get_output_shape_and_type(model_file, output_name):
    with gfile.GFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    tf.graph_util.import_graph_def(graph_def)

    # The 'import/' name prefix is added by import_graph_def.
    output = tf.get_default_session().graph.get_tensor_by_name('import/' +
                                                               output_name +
                                                               ":0")
    shape = list(output.get_shape())

    # TODO i think this can just be inferred via the custom op
    shape[0] = batch_size

    return shape, output.dtype
Example #10
 def load_frozen_graph(self, model_path):
     '''
     Function to load the frozen protobuf file from the disk and parse it
     to retrieve the unserialized graph_def
     Arguments -
         model_path      : A string having the path of the TensorFlow model (.pb) file.
     Returns -
         detection_graph : The unserialized graph_def that holds the network architecture.
     '''
     detection_graph = Graph()
     with detection_graph.as_default():
         od_graph_def = GraphDef()
         with gfile.GFile(model_path, 'rb') as fid:
             serialized_graph = fid.read()
             od_graph_def.ParseFromString(serialized_graph)
             import_graph_def(od_graph_def, name='')
     return detection_graph
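A usage sketch for the loaded graph; `loader` stands for whatever object exposes load_frozen_graph, and the model path and tensor names below are hypothetical (they follow the TensorFlow object-detection convention and may differ for other models):

import numpy as np
import tensorflow as tf

detection_graph = loader.load_frozen_graph('/path/to/frozen_inference_graph.pb')
with tf.Session(graph=detection_graph) as sess:
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    out = sess.run(boxes,
                   feed_dict={image_tensor: np.zeros((1, 300, 300, 3), np.uint8)})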
Example #11
def _load_image(image_path, height, width, ann_dir):
    try:
        with tf_reader.GFile(image_path, 'rb') as fl:
            image_bytes = fl.read()

            image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)

            if ann_dir != "":
                image = _get_ann_images(image_path, image, ann_dir)
            image = cv2.resize(image, (width, height),
                               interpolation=cv2.INTER_AREA)
            im = np.array(img_to_array(image) / 255.)
            return im

    except Exception as e:
        print("Error Processing Image: %s\n %s" % (image_path, str(e)))
        return
Example #12
def parse_token_file(token_file):
    '''
    Builds a mapping from each image to its descriptions: {image: [description1, description2, ...]}
    :param token_file: path to the token file
    :return: the mapping dictionary
    '''
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()

    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)

    return img_name_to_tokens
Example #13
def read_dataframe_from_hdf5(path, key='data'):
    """Read a DataFrame from the given HDF5 file.

  Args:
    path: string path where the DataFrame is saved.
    key: optional string name for the DataFrame in the HDF5 file.

  Returns:
    pandas.DataFrame loaded from the HDF5 file.
  """
    with gfile.GFile(path, 'rb') as f:
        with pandas.HDFStore('in_memory',
                             mode='r',
                             driver='H5FD_CORE',
                             driver_core_backing_store=0,
                             driver_core_image=f.read()) as store:
            return store[key]
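The HDFStore call above uses the H5FD_CORE driver with the backing store disabled, so PyTables operates entirely on the in-memory image passed via driver_core_image; that is what lets the file bytes come from any gfile-supported filesystem rather than a local path. A trivial usage sketch (the path below is hypothetical):

df = read_dataframe_from_hdf5('/tmp/example_frame.h5')
print(df.head())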
Example #14
 def _read_dict(self, filename):
     with gfile.GFile(filename, 'r') as f:
         lines = f.readlines()
     for line in lines:
         word, occurrence = line.strip('\r\n').split('\t')
         occurrence = int(occurrence)
         if occurrence < self._word_num_threshold:
             continue
         idx = len(self._id_to_word)
         if word == '<UNK>':
             self._unk = idx
         elif word == '.':
             self._eos = idx
         if word in self._word_to_id or idx in self._id_to_word:
             raise Exception("duplicate words in vocab.")
         self._word_to_id[word] = idx
         self._id_to_word[idx] = word
Example #15
 def process(self, element):
     labels = {
         constants.SUBDIR_POSITIVE: constants.POSITIVE_SENTIMENT_LABEL,
         constants.SUBDIR_NEGATIVE: constants.NEGATIVE_SENTIMENT_LABEL
     }
     found_labels = [labels[l] for l in labels if l in element]
     if len(found_labels) > 1:
         raise ValueError('Incompatible path: `{}`.'.format(element))
     if found_labels:
         with gfile.GFile(element, 'r') as single_file:
             for line in single_file:
                 yield {
                     constants.LABELS: found_labels[0],
                     constants.REVIEW: line
                 }
     else:
         logging.debug('Label not found for file: `%s`.', element)
Example #16
def _parse_lines(path):
    """Parses lines from IWSLT17 dataset."""
    lines = []
    with gfile.GFile(path) as fp:
        for line in fp:
            line = line.strip()
            # Skip lines that are tags entirely.
            if _WHOLE_TAG_REGEX.match(line):
                continue
            # Try to parse as content between an opening and closing tags.
            match = _FLAT_HTML_REGEX.match(line)
            # Always append text not contained between the tags.
            if match is None:
                lines.append(line)
            elif (match.group(1) == match.group(3)
                  and match.group(1).lower() in _ALLOWED_TAGS):
                lines.append(match.group(2).strip())
    return lines
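_WHOLE_TAG_REGEX, _FLAT_HTML_REGEX and _ALLOWED_TAGS are used above but not shown. Plausible definitions consistent with the parsing logic (assumptions, not the originals):

import re

# A line that is nothing but a single tag, e.g. "<url>" or "</doc>".
_WHOLE_TAG_REGEX = re.compile(r"^<[^<>]+>$")

# A line of the form <tag>content</tag>; group(1)/group(3) are the tag
# names and group(2) is the content between them.
_FLAT_HTML_REGEX = re.compile(r"^<([^<>\s]+)>(.*)</([^<>\s]+)>$")

_ALLOWED_TAGS = frozenset({"title", "description", "seg"})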
Example #17
    def make_feeding_data(self):
        """ Processes the data files and return an iterable
              instance for loop.

        Returns: An iterable instance.
        """
        if self._features_file is None or self._labels_file is None:
            raise ValueError("Both _features_file and _labels_file should be provided.")
        if not hasattr(self, "_parallel_data"):
            line_count = 0
            with gfile.GFile(self._features_file) as fp:
                for _ in fp:
                    line_count += 1
            if line_count > self._cache_size or self._batch_tokens_size is not None:
                setattr(self, "_parallel_data", self._BigParallelData(self))
            else:
                setattr(self, "_parallel_data", self._SmallParallelData())
        return self._parallel_data
Example #18
    def _build_embedding_matrix(self):
        """Builds the embedding matrix for the model.

    Returns:
      words: a list of strings representing the words in the vocabulary.
      embeddings: a float32 array of shape [vocab_size, embeddings_dim].
    """
        logging.info('Loading Glove embeddings.')
        words = []
        embeddings = []
        with gfile.GFile(FLAGS.glove_path) as f:
            for line in f:
                values = line.split()
                words.append(values[0])
                embeddings.append(np.asarray(values[1:], dtype='float32'))

        logging.info('Found %s word vectors.', len(embeddings))
        return words, np.array(embeddings)
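A brief usage sketch turning the returned pair into a word-to-vector lookup; `model` is a hypothetical instance of the class that defines _build_embedding_matrix:

words, embeddings = model._build_embedding_matrix()
word_to_vector = dict(zip(words, embeddings))
print(word_to_vector['the'].shape)  # e.g. (300,) for 300-d GloVe vectors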
Example #19
    def test_begin(self):
        model_dir = tempfile.mkdtemp()
        outfile = tempfile.NamedTemporaryFile()
        tf.get_variable("weigths", [128, 128])
        hook = hooks.PrintModelAnalysisHook(
            params={},
            model_dir=model_dir,
            run_config=tf.contrib.learn.RunConfig())
        hook.begin()

        with gfile.GFile(os.path.join(model_dir,
                                      "model_analysis.txt")) as file:
            file_contents = file.read().strip()

        self.assertEqual(
            tf.compat.as_text(file_contents), "_TFProfRoot (--/16.38k params)\n"
            "  weights (128x128, 16.38k/16.38k params)")
        outfile.close()
Example #20
def create_checkpoint():
    for sub_dir in ['work_dir', 'golden_chunks']:
        ensure_dir_exists(os.path.join(FLAGS.checkpoint_dir, sub_dir))

    # List all the training checkpoints.
    pattern = os.path.join(FLAGS.base_dir, 'work_dir', 'model.ckpt-*.index')
    model_paths = glob.glob(pattern)

    # Sort the checkpoints by step number.
    def extract_step(path):
        name = os.path.splitext(os.path.basename(path))[0]
        return int(re.match(r'model.ckpt-(\d+)', name).group(1))

    model_paths.sort(key=extract_step)

    # Get the name of the latest checkpoint.
    step = extract_step(model_paths[-1])
    name = 'model.ckpt-{}'.format(step)

    # Copy the model to the checkpoint directory.
    for ext in ['.data-00000-of-00001', '.index', '.meta']:
        basename = name + ext
        src_path = os.path.join(FLAGS.base_dir, 'work_dir', basename)
        dst_path = os.path.join(FLAGS.checkpoint_dir, 'work_dir', basename)
        print('Copying {} {}'.format(src_path, dst_path))
        shutil.copy(src_path, dst_path)

    # Write the checkpoint state proto.
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'work_dir',
                                   'checkpoint')
    print('Writing {}'.format(checkpoint_path))
    with gfile.GFile(checkpoint_path, 'w') as f:
        f.write('model_checkpoint_path: "{}"\n'.format(name))
        f.write('all_model_checkpoint_paths: "{}"\n'.format(name))

    # Copy the most recent golden chunks.
    pattern = os.path.join(FLAGS.base_dir, 'data', 'golden_chunks',
                           '*.tfrecord.zz')
    src_paths = sorted(glob.glob(pattern))[-FLAGS.window_size:]
    for i, src_path in enumerate(src_paths):
        dst_path = os.path.join(FLAGS.checkpoint_dir, 'golden_chunks',
                                '000000-{:06}.tfrecord.zz'.format(i))
        print('Copying {} {}'.format(src_path, dst_path))
        shutil.copy(src_path, dst_path)
Example #21
    def _make_feeding_data(self,
                           features_file,
                           labels_file,
                           maximum_features_length=None,
                           maximum_labels_length=None,
                           maximum_encoded_features_length=None,
                           maximum_encoded_labels_length=None):
        """ Processes the data files and return an iterable
              instance for loop.

        Args:
            features_file: The path of features file.
            labels_file: The path of labels file.
            maximum_features_length: The maximum sequence length of "features" field.
              If provided, sentences exceeding this value will be ignore.
            maximum_labels_length: The maximum sequence length of "labels" field.
              If provided, sentences exceeding this value will be ignore.
            maximum_encoded_features_length: The maximum length of feature symbols (especially
              after BPE is applied) . If provided, the number of symbols of one sentence
              exceeding this value will be ignore.
            maximum_encoded_labels_length: The maximum length of label symbols (especially
              after BPE is applied) . If provided, the number of symbols of one sentence
              exceeding this value will be ignore.

        Returns: An iterable instance.
        """
        if features_file is None or labels_file is None:
            raise ValueError(
                "Both features_file and labels_file should be provided.")
        line_count = 0
        with gfile.GFile(features_file) as fp:
            for _ in fp:
                line_count += 1
        if line_count > self._cache_size or self._batch_tokens_size is not None:
            return self._BigParallelData(self, features_file, labels_file,
                                         maximum_features_length,
                                         maximum_labels_length,
                                         maximum_encoded_features_length,
                                         maximum_encoded_labels_length)
        return self._SmallParallelData(features_file, labels_file,
                                       maximum_features_length,
                                       maximum_labels_length,
                                       maximum_encoded_features_length,
                                       maximum_encoded_labels_length)
Example #22
def create_vocabulary_mapping(filename):
    """Creates a mapping for a vocabulary file.

    Args:
      filename: Path to a vocabulary file containing one word per line.
        Each word is mapped to its line number.

    Returns:
      A tuple (vocab_to_id_mapping, id_to_vocab_mapping,
      word_to_count_mapping, vocab_size). The vocab size does not include
      the UNK token.
    """
    if not gfile.Exists(filename):
        raise ValueError("File does not exist: {}".format(filename))

    # Load vocabulary into memory
    with gfile.GFile(filename) as file:
        vocab = list(line.strip("\n") for line in file)
    vocab_size = len(vocab)

    has_counts = len(vocab[0].split("\t")) == 2
    if has_counts:
        vocab, counts = zip(*[_.split("\t") for _ in vocab])
        counts = [float(_) for _ in counts]
        vocab = list(vocab)
    else:
        counts = [-1. for _ in vocab]

    print("Creating vocabulary lookup table of size %d", vocab_size)

    vocab_idx = range(vocab_size)

    vocab_to_id_mapping = OrderedDict()
    id_to_vocab_mapping = OrderedDict()
    word_to_count_mapping = OrderedDict()
    for word, count, idx in zip(vocab, counts, vocab_idx):
        vocab_to_id_mapping[word] = idx
        id_to_vocab_mapping[idx] = word
        word_to_count_mapping[word] = count

    return vocab_to_id_mapping, id_to_vocab_mapping, word_to_count_mapping, vocab_size
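A quick usage sketch; the vocabulary path below is hypothetical and is expected to contain one word per line, optionally followed by a tab-separated count:

vocab_to_id, id_to_vocab, word_to_count, vocab_size = \
    create_vocabulary_mapping("/tmp/vocab.txt")
print(vocab_size)
print(vocab_to_id.get("the"))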
Example #23
def create_vocabulary(
    vocabulary_file, raw_data_dir, max_vocabulary_size, Isch=True, normalize_digits=True
):
    texts, textssz = get_ch_lable(raw_data_dir, Isch, normalize_digits)
    all_words = []
    for label in texts:
        all_words += [word for word in label]
    training_label, count, dictionary, reverse_dictionary = build_dataset(
        all_words, max_vocabulary_size
    )
    if not gfile.Exists(vocabulary_file):
        with gfile.GFile(vocabulary_file, mode='w') as vocab_file:
            # reverse_dictionary maps index -> word; write at most
            # max_vocabulary_size words, one per line.
            for idx in sorted(reverse_dictionary)[:max_vocabulary_size]:
                vocab_file.write(reverse_dictionary[idx] + '\n')
    else:
        print('Vocabulary file already exists: %s' % vocabulary_file)
    return training_label, count, dictionary, reverse_dictionary, textssz
Example #24
def _load_yaml(blueoil_config_filename):
    """load blueoil config yaml

    Args:
        blueoil_config_filename(str): File path of blueoil config yaml file.

    Returns:
        blueoil_config(dict): dict of blueoil config.
    """
    if not gfile.Exists(blueoil_config_filename):
        raise FileNotFoundError("File not found: {}".format(blueoil_config_filename))

    with gfile.GFile(blueoil_config_filename, "r") as f:
        blueoil_config = yaml.load(f, Loader=yaml.SafeLoader)

    model_name, _ = os.path.splitext(os.path.basename(blueoil_config_filename))

    blueoil_config["model_name"] = model_name

    return blueoil_config
Example #25
  def test_begin(self):
    model_dir = tempfile.mkdtemp()
    outfile = tempfile.NamedTemporaryFile()
    tf.get_variable("weights", [128, 128])
    hook = hooks.PrintModelAnalysisHook(
        params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
      file_contents = file.read().strip()

    lines = tf.compat.as_text(file_contents).split("\n")
    if len(lines) == 3:
      # TensorFlow v1.2 includes an extra header line
      self.assertEqual(lines[0], "node name | # parameters")

    self.assertEqual(lines[-2], "_TFProfRoot (--/16.38k params)")
    self.assertEqual(lines[-1], "  weights (128x128, 16.38k/16.38k params)")

    outfile.close()
Example #26
def ablation_visualization(x1,
                           x2,
                           gen,
                           z_dim,
                           basedir,
                           global_step,
                           figsize=(20, 20),
                           show=False):
    images = generate(x1, x2, gen, z_dim, 3, 12)
    plt.figure(figsize=figsize)
    plt.imshow(grid(images, 12, 1), cmap='Greys_r', interpolation=None)
    plt.axis('off')
    if show:
        plt.show()

    filename = os.path.join(basedir, 'ablation_{:09d}.png'.format(global_step))
    with gfile.GFile(filename, mode='wb') as f:
        plt.savefig(f, dpi=100, bbox_inches='tight')

    plt.close()
Example #27
def dump_model_analysis(model_dir):
    """ Dumps detailed model size.

    Args:
        model_dir: The directory name to save to.
    """
    # Dump to file on the chief worker
    filename = os.path.join(model_dir, Constants.MODEL_ANALYSIS_FILENAME)
    profile_opt_builder = tf.profiler.ProfileOptionBuilder
    opts = profile_opt_builder.trainable_variables_parameter()
    opts["output"] = "file:outfile={}".format(filename)
    param_stats = tf.profiler.profile(tf.get_default_graph(), options=opts)
    # following APIs are deprecated
    # opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
    # opts['dump_to_file'] = os.path.abspath(filename)
    # tf.contrib.tfprof.model_analyzer.print_model_analysis(
    #     tf.get_default_graph(), tfprof_options=opts)
    # Print the model analysis
    with gfile.GFile(filename) as file:
        tf.logging.info(file.read())
Example #28
def get_prediction_input(files):
    """Reads and concatenates text files in input directory.

  Args:
    files: List of `str`, containing absolute path to files to read.

  Returns:
    List of `str` containing independent text reviews.

  Raises:
    ValueError: If input files are empty.
  """

    instances = []
    for path in files:
        with gfile.GFile(path, 'r') as lines:
            instances += lines
    if not instances:
        raise ValueError('No review found in input files.')
    return instances
Example #29
def main(_argv):
    """The entrypoint for the script"""
    misc_utils.clean(FLAGS)
    misc_utils.make_path(FLAGS)
    log_file = os.path.join(FLAGS.model_dir, FLAGS.log_file)
    logger = misc_utils.get_logger(log_file)

    # Parse YAML FLAGS
    FLAGS.model_params = _maybe_load_yaml(FLAGS.model_params)

    # Load flags from config file
    final_config = {}
    if FLAGS.config_paths:
        for config_path in FLAGS.config_paths.split(","):
            config_path = config_path.strip()
            if not config_path:
                continue
            config_path = os.path.abspath(config_path)
            logger.info("Loading config from %s", config_path)
            with gfile.GFile(config_path.strip()) as config_file:
                config_flags = yaml.load(config_file)
                final_config = _deep_merge_dict(final_config, config_flags)

    logger.info("Final Config:\n%s", yaml.dump(final_config))

    # Merge flags with config values
    for flag_key, flag_value in final_config.items():
        if hasattr(FLAGS, flag_key) and isinstance(getattr(FLAGS, flag_key),
                                                   dict):
            merged_value = _deep_merge_dict(flag_value,
                                            getattr(FLAGS, flag_key))
            setattr(FLAGS, flag_key, merged_value)
        elif hasattr(FLAGS, flag_key):
            setattr(FLAGS, flag_key, flag_value)
        else:
            logger.warning("Ignoring config flag: %s", flag_key)

    if not FLAGS.model_dir:
        FLAGS.output_dir = tempfile.mkdtemp()

    train(FLAGS, logger)
Example #30
def parse_cap(cap_path):
    '''
    Parses a caption file and returns the mapping from image to captions.
    For flickr30k, each image has 5 captions, so the mapping is {img: [cap0, ..., cap4]}.
    :param cap_path: path to the caption file
    :return: the mapping dictionary
    '''
    with gfile.GFile(cap_path, 'r') as fd:
        text = fd.readlines()

    img2cap = dict()

    for line in text:
        img_name, cap = line.strip().split('\t')
        img_name = img_name.split('#')[0]
        cap = cap.strip()

        img2cap.setdefault(img_name, list())
        img2cap[img_name].append(cap)

    return img2cap
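A usage sketch for parse_cap; the caption file path below is hypothetical:

img2cap = parse_cap('/data/flickr30k/captions.token')
first_img = next(iter(img2cap))
print(first_img, len(img2cap[first_img]))  # typically 5 captions per image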