Example #1
def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
    """Loads pretrained embeddings from a file in GloVe format."""
    with gfile.GFile(embedding_source, "r") as f:
        embedding_lines = f.readlines()

    # First line contains embedding dim.
    _, embedding_dim = list(map(int, embedding_lines[0].split()))
    # Get the tokens as strings.
    tokens = [line.split()[0] for line in embedding_lines[1:]]
    # Get the actual embedding matrix.
    unsorted_emb = np.array([[float(x) for x in line.split()[1:]]
                             for line in embedding_lines[1:]])

    # Get the expected vocab order.
    with gfile.GFile(vocab_file, "r") as f:
        tokens_order = [l.strip() for l in f.readlines()]
    assert vocab_size == len(tokens_order)

    # Put the embeddings in the order.
    sorted_emb = np.zeros((vocab_size, embedding_dim))
    for i, token in enumerate(tokens_order):
        if token in tokens:
            sorted_emb[i, :] = unsorted_emb[tokens.index(token), :]
        else:  # If we don't have a pretrained embedding, initialize randomly.
            sorted_emb[i, :] = np.random.normal(loc=0.0,
                                                scale=GLOVE_STD,
                                                size=(GLOVE_DIM, ))

    return sorted_emb.astype(np.float32)
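A minimal usage sketch, assuming a TF1-style graph. The file paths, vocab size, and the GLOVE_DIM / GLOVE_STD values below are placeholders for constants the surrounding module is expected to define.

import tensorflow.compat.v1 as tf

# Assumed module-level constants referenced by the function above.
GLOVE_DIM = 300
GLOVE_STD = 0.4

pretrained = _get_embedding_initializer(
    vocab_file="vocab.txt",                  # placeholder path
    embedding_source="glove.6B.300d.txt",    # placeholder path
    vocab_size=20000)
# Seed a trainable embedding table with the pretrained matrix.
embedding_var = tf.get_variable("embedding", initializer=pretrained)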
Example #2
def get_accuracy_result(questions_path, golden_answers_path,
                        inferred_answers_path):
    """Collect accuracy results from input files."""
    questions = gfile.GFile(questions_path).readlines()
    golden_answers = gfile.GFile(golden_answers_path).readlines()
    inferred_answers = gfile.GFile(inferred_answers_path).readlines()

    result = AccuracyResult(total_lines=len(questions),
                            matches=[],
                            mismatches=[],
                            inferred_answers_path=inferred_answers_path)
    if len(set(
        (len(questions), len(golden_answers), len(inferred_answers)))) > 1:
        logging.fatal(
            'Not writing accuracy results: Input files have different '
            'lengths')
        logging.fatal(f'Questions: {len(questions)}, golden answers: '
                      f'{len(golden_answers)}, inferred answers: '
                      f'{len(inferred_answers)}')
        return None
    for question, golden, inferred in zip(questions, golden_answers,
                                          inferred_answers):
        if inferred == golden:
            result.matches.append((question, golden))
        else:
            result.mismatches.append((question, golden, inferred))
    return result
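The snippet relies on an AccuracyResult container that is not shown. A rough sketch of what it would need to provide, inferred only from the usage above and in Example #30 (the real definition may differ, e.g. it could be a NamedTuple):

import dataclasses
from typing import List, Tuple

@dataclasses.dataclass
class AccuracyResult:
    """Approximation of the container used above; fields inferred from usage."""
    total_lines: int
    matches: List[Tuple[str, str]]
    mismatches: List[Tuple[str, str, str]]
    inferred_answers_path: str

    def get_accuracy(self):
        # Used by write_accuracy_result in Example #30.
        return len(self.matches) / self.total_lines if self.total_lines else 0.0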
Example #3
def write_eval_results(checkpoint_dir,
                       all_gen_sentences,
                       checkpoint_name,
                       mean_train_prob,
                       mean_valid_prob,
                       mean_gen_prob,
                       fid,
                       eval_filename=None):
    """Write evaluation results to disk."""
    if eval_filename is None:
        eval_filename = EVAL_FILENAME
    to_write = ",".join(
        map(str, [
            checkpoint_name, mean_train_prob, mean_valid_prob, mean_gen_prob,
            fid
        ]))
    eval_filepath = os.path.join(checkpoint_dir, eval_filename)
    previous_eval_content = ""
    if gfile.exists(eval_filepath):
        with gfile.GFile(eval_filepath, "r") as f:
            previous_eval_content = f.read()
    with gfile.GFile(eval_filepath, "w") as f:
        f.write(previous_eval_content + to_write + "\n")

    if all_gen_sentences is not None:
        with gfile.GFile(
                os.path.join(checkpoint_dir,
                             checkpoint_name + "_sentences.txt"), "w") as f:
            f.write("\n".join(all_gen_sentences))
Example #4
    def construct_dataset(self, name, train_split):
        """Construct datasets for training and testing.

    Args:
      name: name of the learning task
      train_split: percentage of data for training
    """
        if name in ['membership', 'mixture']:
            self.train_data_membership, self.test_data_membership = \
                self.construct_membership(train_split)
        if name in ['intersection', 'union', 'mixture']:
            self.train_data_set_pair, self.test_data_set_pair = \
                self.construct_set_pairs(train_split)
        if name == 'follow':
            self.train_data_follow, self.test_data_follow = \
                self.construct_follow_facts(train_split)
        if name in ['set_follow', 'mixture']:
            self.train_data_follow, self.test_data_follow = \
                self.construct_set_follow_facts(train_split)
        if name in ['metaqa2', 'metaqa3']:
            self.train_data_metaqa, self.test_data_metaqa = \
                self.construct_metaqa_data(
                    train_file='train.json', test_file='test.json')
        if name in ['webqsp']:
            self.bert_tokenizer = util.BertTokenizer()
            self.train_data_webqsp, self.test_data_webqsp = self.construct_webqsp(
                train_file='train.json', test_file='test.json')
            self.all_entity_is_cvt = self.check_cvt_entity()
        if name.startswith('query2box'):
            self.q2b_id2entity = pickle.load(
                gfile.GFile(self.root_dir + 'ind2ent.pkl', 'rb'))
            self.q2b_id2relation = pickle.load(
                gfile.GFile(self.root_dir + 'ind2rel.pkl', 'rb'))
            self.train_data_query2box, self.test_data_query2box = \
                self.construct_query2box(name)
Example #5
def load_and_drop_stream(data_file,
                         kb_file,
                         drop_incorrect=True,
                         verbose=False):
  """ this function filter incorrect samples without standardization."""
  if verbose:
    print('loading stream')
  fin_data = gfile.GFile(data_file)
  if gfile.exists(kb_file):
    fin_kb = gfile.GFile(kb_file)
  else:
    fin_kb = None
  if verbose:
    print('gfile loaded: ', fin_data)
  for line1 in fin_data:
    if verbose:
      print(line1)
    if len(line1.strip()) < 10:
      continue
    line1 = delete_non_ascii(line1)
    data_obj = json.loads(line1)

    if fin_kb:
      line2 = fin_kb.readline()
      if len(line2.strip()) < 10:
        continue
      line2 = delete_non_ascii(line2)
      kb_obj = json.loads(line2)
    else:
      kb_obj = None
    if (not drop_incorrect) or (
        'correct_sample' not in data_obj) or data_obj['correct_sample']:
      yield data_obj, kb_obj
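Because this variant is a generator, samples can be consumed lazily. File names and the downstream process call are placeholders:

# Stream (data, kb) pairs without loading the whole file into memory.
for data_obj, kb_obj in load_and_drop_stream(
    "data.jsonl", "kb.jsonl", drop_incorrect=True, verbose=False):
  process(data_obj, kb_obj)  # hypothetical downstream step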
Example #6
def get_valid_data(data_path, dataset, truncate_vocab=20000):
  if dataset not in FILENAMES:
    raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
        dataset, list(FILENAMES.keys())))
  train_file, valid_file, test_file = FILENAMES[dataset]

  train_path = os.path.join(data_path, train_file)
  valid_path = os.path.join(data_path, valid_file)
  test_path = os.path.join(data_path, test_file)

  with gfile.GFile(train_path, "r") as infile:
    data_train = read_lines(infile)
  with gfile.GFile(valid_path, "r") as infile:
    data_valid = read_lines(infile)
  with gfile.GFile(test_path, "r") as infile:
    data_test = read_lines(infile)

  word_to_id = _build_vocab(data_train)
  logging.info("Full vocab length: %d", len(word_to_id))
  # Assume the vocab is sorted by frequency.
  word_to_id_truncated = {
      k: v for k, v in word_to_id.items() if v < truncate_vocab
  }
  logging.info("Truncated vocab length: %d", len(word_to_id_truncated))

  valid_data = _integerize(data_valid, word_to_id_truncated, dataset)
  test_data = _integerize(data_test, word_to_id_truncated, dataset)
  return valid_data, test_data, word_to_id_truncated
Example #7
def load_and_drop(data_file, kb_file, drop_incorrect=True, verbose=False):
  """ this function filter incorrect samples without standardization."""
  fin_data = gfile.GFile(data_file)
  fin_kb = gfile.GFile(kb_file)
  total_in_file = 0
  loaded_data = []
  loaded_kb = []
  for line1 in tqdm(fin_data, desc='loading data'):
    if len(line1.strip()) < 10:
      continue
    line2 = fin_kb.readline()
    if len(line2.strip()) < 10:
      continue
    line1 = delete_non_ascii(line1)
    line2 = delete_non_ascii(line2)

    data_obj = json.loads(line1)
    kb_obj = json.loads(line2)
    if (not drop_incorrect) or (
        'correct_sample' not in data_obj) or data_obj['correct_sample']:
      loaded_data.append(data_obj)
      loaded_kb.append(kb_obj)
    total_in_file += 1

  if verbose:
    print(('loaded: ', len(loaded_data), '/', total_in_file, '=',
           len(loaded_data) * 1.0 / total_in_file))
  return loaded_data, loaded_kb
Example #8
def get_raw_data(data_path, dataset, truncate_vocab=20000):
    """Load raw data from data directory "data_path".

  Reads text files, converts strings to integer ids,
  and performs mini-batching of the inputs.

  Args:
    data_path: string path to the directory where simple-examples.tgz has been
      extracted.
    dataset: one of ["emnlp2017"]
    truncate_vocab: int, number of words to keep in the vocabulary.

  Returns:
    tuple (train_data, valid_data, vocabulary) where each of the data
    objects can be passed to iterator.

  Raises:
    ValueError: dataset not in ["emnlp2017"].
  """
    if dataset not in FILENAMES:
        raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
            dataset, FILENAMES.keys()))
    train_file, valid_file, test_file = FILENAMES[dataset]

    train_path = os.path.join(data_path, train_file)
    valid_path = os.path.join(data_path, valid_file)
    test_path = os.path.join(data_path, test_file)

    with gfile.GFile(train_path, "r") as json_file:
        json_data_train = json.load(json_file)
    with gfile.GFile(valid_path, "r") as json_file:
        json_data_valid = json.load(json_file)
    with gfile.GFile(test_path, "r") as json_file:
        json_data_test = json.load(json_file)

    word_to_id = _build_vocab(json_data_train)
    logging.info("Full vocab length: %d", len(word_to_id))
    # Assume the vocab is sorted by frequency.
    word_to_id_truncated = {
        k: v
        for k, v in word_to_id.items() if v < truncate_vocab
    }
    logging.info("Truncated vocab length: %d", len(word_to_id_truncated))

    train_data = _integerize(json_data_train, word_to_id_truncated, dataset)
    valid_data = _integerize(json_data_valid, word_to_id_truncated, dataset)
    test_data = _integerize(json_data_test, word_to_id_truncated, dataset)

    with open('word_to_id_truncated.json', 'w') as f:
        json.dump(word_to_id_truncated, f)
    np.save('train_data_sequences.npy', train_data['sequences'])
    np.save('valid_data_sequences.npy', valid_data['sequences'])
    np.save('test_data_sequences.npy', test_data['sequences'])
    np.save('train_data_sequence_lengths.npy', train_data['sequence_lengths'])
    np.save('valid_data_sequence_lengths.npy', valid_data['sequence_lengths'])
    np.save('test_data_sequence_lengths.npy', test_data['sequence_lengths'])
    print('Done writing cached vocab and sequence files.')
    return train_data, valid_data, word_to_id_truncated
Example #9
def maybe_pick_models_to_evaluate(checkpoint_dir):
    """Pick a checkpoint to evaluate that has not been evaluated already."""
    logging.info("Picking checkpoint to evaluate from %s.", checkpoint_dir)

    filenames = gfile.listdir(checkpoint_dir)
    filenames = [f[:-5] for f in filenames if f[-5:] == ".meta"]
    logging.info("Found existing checkpoints: %s", filenames)

    evaluated_filenames = []
    if gfile.exists(os.path.join(checkpoint_dir, EVAL_FILENAME)):
        with gfile.GFile(os.path.join(checkpoint_dir, EVAL_FILENAME),
                         "r") as f:
            evaluated_filenames = [
                l.strip().split(",")[0] for l in f.readlines()
            ]
        logging.info("Found already evaluated checkpoints: %s",
                     evaluated_filenames)

    checkpoints_to_evaluate = [
        f for f in filenames if f not in evaluated_filenames
    ]
    logging.info("Remaining potential checkpoints: %s",
                 checkpoints_to_evaluate)

    if checkpoints_to_evaluate:
        return os.path.join(checkpoint_dir, checkpoints_to_evaluate[0])
    else:
        return None
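Together with write_eval_results from Example #3, this supports a simple evaluate-until-done loop. The directory is a placeholder and evaluate_checkpoint is a hypothetical stand-in:

import os

while True:
    checkpoint_path = maybe_pick_models_to_evaluate("/tmp/ckpts")
    if checkpoint_path is None:
        break  # every .meta checkpoint already has a row in EVAL_FILENAME
    # evaluate_checkpoint is a stand-in returning
    # (mean_train_prob, mean_valid_prob, mean_gen_prob, fid).
    metrics = evaluate_checkpoint(checkpoint_path)
    write_eval_results("/tmp/ckpts", None,
                       os.path.basename(checkpoint_path), *metrics)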
Example #10
def select_ppl_results(checkpoint_dir,
                       valid_ppl,
                       valid_tokens,
                       test_ppl,
                       test_tokens,
                       eval_filename=None):
    """Write evaluation results to disk."""
    if eval_filename is None:
        eval_filename = 'ppl.csv'
    eval_filepath = os.path.join(checkpoint_dir, eval_filename)
    previous_eval_content = ""
    if gfile.exists(eval_filepath):
        with gfile.GFile(eval_filepath, "r") as f:
            previous_eval_content = f.read()

    best_ppl = None
    best_checkpoint = None
    for line in previous_eval_content.strip().split('\n'):
        fields = line.strip().split(',')
        if len(fields) > 1:  # Skip empty or malformed lines.
            ppl = float(fields[1])
            if best_ppl is None or ppl < best_ppl:
                best_ppl = ppl
                best_checkpoint = fields[0]
    return best_checkpoint, best_ppl
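A hypothetical call; the function only reads the CSV written by write_ppl_results (Example #18), so the ppl/token arguments do not affect the selection:

best_checkpoint, best_ppl = select_ppl_results(
    checkpoint_dir="/tmp/ckpts",  # placeholder
    valid_ppl=None, valid_tokens=None, test_ppl=None, test_tokens=None)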
Example #11
def main(argv):
    del argv

    angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
    weights_to_sweep = np.stack(
        [np.sin(angles_to_sweep),
         np.cos(angles_to_sweep)], axis=-1)
    weights_to_sweep /= np.sum(np.maximum(weights_to_sweep, 0.0),
                               axis=-1,
                               keepdims=True)
    weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
    tf.logging.info(weights_to_sweep)

    all_returns = []
    for keyboard_path in FLAGS.keyboard_paths:
        returns = evaluate_keyboard(keyboard_path, weights_to_sweep)
        all_returns.append(returns)

    print("Results:")
    print(np.mean(all_returns, axis=-1).T)

    if FLAGS.output_path:
        with gfile.GFile(FLAGS.output_path, "w") as file:
            writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["angle", "return", "idx"])
            for idx, returns in enumerate(all_returns):
                for row in np.array(returns).T.tolist():
                    assert len(angles_to_sweep) == len(row)
                    for ang, val in zip(angles_to_sweep, row):
                        ang = "{:.4g}".format(ang)
                        val = "{:.4g}".format(val)
                        writer.writerow([ang, val, idx])
Example #12
    def setUp(self):
        super(EvalTest, self).setUp()
        self.task = 'ic'
        self.tmp_dir = self.create_tempdir().full_path + '/'

        # Construct sample KB.
        ent0, rel0, ent1, rel1, rel2 = 0, 0, 1, 1, 2
        answers = {((ent0, (rel0, )), (ent1, (rel1, )), rel2): {3, 4, 5}}
        hard_answers = {((ent0, (rel0, )), (ent1, (rel1, )), rel2): {5}}
        kb = ['ent0\trel0\tent1\n', 'ent2\trel1\tent3\n', 'ent4\trel2\tent5\n']
        ent2ind = {'ent%d' % i: i for i in range(6)}
        rel2ind = {'rel%d' % i: i for i in range(3)}
        ind2ent = {ind: ent for ent, ind in ent2ind.items()}
        ind2rel = {ind: rel for rel, ind in rel2ind.items()}

        # Dump to temp files.
        pickle.dump(
            answers,
            gfile.GFile(self.tmp_dir + 'test_ans_%s.pkl' % self.task, 'wb'))
        pickle.dump(
            hard_answers,
            gfile.GFile(self.tmp_dir + 'test_ans_%s_hard.pkl' % self.task,
                        'wb'))
        pickle.dump(ent2ind, gfile.GFile(self.tmp_dir + 'ent2ind.pkl', 'wb'))
        pickle.dump(rel2ind, gfile.GFile(self.tmp_dir + 'rel2ind.pkl', 'wb'))
        pickle.dump(ind2ent, gfile.GFile(self.tmp_dir + 'ind2ent.pkl', 'wb'))
        pickle.dump(ind2rel, gfile.GFile(self.tmp_dir + 'ind2rel.pkl', 'wb'))
        with gfile.GFile(self.tmp_dir + 'kb.txt', 'w') as f_kb:
            for line in kb:
                f_kb.write(line)
Example #13
def write_returns_to_file(path, returns):
  """Write returns to file."""

  with gfile.GFile(path, "w") as file:
    writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
    writer.writerow(["episode", "train"] +
                    [f"eval_{idx}" for idx in range(len(returns[0]["eval"]))])
    for row in returns:
      writer.writerow([row["episode"], row["train"]] + row["eval"])
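The expected shape of the returns argument, inferred from the header row the function writes; the values and path are made up:

returns = [
    {"episode": 0, "train": 1.5, "eval": [0.9, 1.1]},
    {"episode": 1, "train": 2.0, "eval": [1.3, 1.2]},
]
write_returns_to_file("/tmp/returns.csv", returns)  # path is a placeholder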
Example #14
def load_json(path, scrub=False):
  logging.info(f'Reading json from {path} into memory...')
  with gfile.GFile(path) as f:
    if scrub:
      data = json.loads(_scrub_json(f.read()))
    else:
      data = json.load(f)
  logging.info(f'Successfully loaded json data from {path} into memory.')
  return data
Example #15
def get_raw_data(data_path, dataset, truncate_vocab=20000):
  """Load raw data from data directory "data_path".

  Reads text files, converts strings to integer ids,
  and performs mini-batching of the inputs.

  Args:
    data_path: string path to the directory where simple-examples.tgz has been
      extracted.
    dataset: one of ["emnlp2017-prep", "coco-prep"]
    truncate_vocab: int, number of words to keep in the vocabulary.

  Returns:
    tuple (train_data, valid_data, vocabulary) where each of the data
    objects can be passed to iterator.

  Raises:
    ValueError: dataset not in ["emnlp2017-prep", "coco-prep"].
  """
  if dataset not in FILENAMES:
    raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
        dataset, list(FILENAMES.keys())))
  train_file, valid_file, _ = FILENAMES[dataset]

  train_path = os.path.join(data_path, train_file)
  valid_path = os.path.join(data_path, valid_file)

  with gfile.GFile(train_path, "r") as infile:
    data_train = read_lines(infile)
  with gfile.GFile(valid_path, "r") as infile:
    data_valid = read_lines(infile)

  word_to_id = _build_vocab(data_train)
  logging.info("Full vocab length: %d", len(word_to_id))
  # Assume the vocab is sorted by frequency.
  word_to_id_truncated = {
      k: v for k, v in word_to_id.items() if v < truncate_vocab
  }
  logging.info("Truncated vocab length: %d", len(word_to_id_truncated))

  train_data = _integerize(data_train, word_to_id_truncated, dataset)
  valid_data = _integerize(data_valid, word_to_id_truncated, dataset)
  return train_data, valid_data, word_to_id_truncated
Example #16
def write_dataset(dataset, save_path):
    """Saves the given dataset into the given location."""
    if not dataset:
        logging.info('No dataset to write.')
        return
    logging.info('Writing dataset to %s', save_path)
    for split_name, list_of_input_output_pairs in dataset.items():
        folder_name = os.path.join(save_path, split_name)
        if not os.path.exists(folder_name):
            os.makedirs(folder_name)
        encode_name = os.path.join(folder_name, '%s_encode.txt' % split_name)
        decode_name = os.path.join(folder_name, '%s_decode.txt' % split_name)
        with gfile.GFile(encode_name,
                         'w') as encode_f, gfile.GFile(decode_name,
                                                       'w') as decode_f:
            for pair in list_of_input_output_pairs:
                encode_f.write(pair[0] + '\n')
                decode_f.write(pair[1] + '\n')
    logging.info('Dataset written to %s', save_path)
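A hypothetical input for write_dataset: a mapping from split name to (input, output) string pairs. The paths and pairs below are placeholders:

dataset = {
    'train': [('who directed M0', 'SELECT ?x0 WHERE ...')],
    'dev': [('who edited M1', 'SELECT ?x0 WHERE ...')],
}
write_dataset(dataset, '/tmp/cfq_out')
# Produces /tmp/cfq_out/train/train_encode.txt, .../train_decode.txt, etc.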
Example #17
def generate_kb(dataset):
    """Recover KB from processed pkl file."""
    ind2rel = pickle.load(
        gfile.GFile(FLAGS.query2box_dir + '%s/ind2rel.pkl' % dataset, 'rb'))
    rel2ind = {v: k for k, v in ind2rel.items()}
    ind2ent = pickle.load(
        gfile.GFile(FLAGS.query2box_dir + '%s/ind2ent.pkl' % dataset, 'rb'))
    ent2ind = {v: k for k, v in ind2ent.items()}

    f_out = open(FLAGS.output_dir + '%s/kb.txt' % dataset, 'w')
    f_out_test = open(FLAGS.output_dir + '%s/kb_test.txt' % dataset, 'w')
    f_ind2rel_out = open(FLAGS.output_dir + '%s/ind2rel.pkl' % dataset, 'wb')
    f_rel2ind_out = open(FLAGS.output_dir + '%s/rel2ind.pkl' % dataset, 'wb')
    f_ind2ent_out = open(FLAGS.output_dir + '%s/ind2ent.pkl' % dataset, 'wb')
    f_ent2ind_out = open(FLAGS.output_dir + '%s/ent2ind.pkl' % dataset, 'wb')

    all_facts = set()
    all_facts_test = set()
    for split in ['train', 'valid', 'test']:
        data_1c_file = FLAGS.query2box_dir + '%s/%s_ans_1c.pkl' % (dataset,
                                                                   split)
        data_1c = pickle.load(open(data_1c_file, 'rb'))
        for query, answers in data_1c.items():
            subj, rels = query[0]
            rel = rels[0]
            for obj in answers:
                all_facts_test.add((ind2ent[subj], ind2rel[rel], ind2ent[obj]))
                if split != 'test':
                    all_facts.add((ind2ent[subj], ind2rel[rel], ind2ent[obj]))

    for subj, rel, obj in all_facts:
        if not rel.endswith('_reverse'):
            f_out.write('%s\t%s\t%s\n' % (subj, rel, obj))

    for subj, rel, obj in all_facts_test:
        if not rel.endswith('_reverse'):
            f_out_test.write('%s\t%s\t%s\n' % (subj, rel, obj))

    pickle.dump(ind2rel, f_ind2rel_out)
    pickle.dump(rel2ind, f_rel2ind_out)
    pickle.dump(ind2ent, f_ind2ent_out)
    pickle.dump(ent2ind, f_ent2ind_out)
Example #18
def write_ppl_results(checkpoint_dir,
                      checkpoint_name,
                      valid_ppl,
                      valid_tokens,
                      test_ppl,
                      test_tokens,
                      eval_filename=None):
    """Write evaluation results to disk."""
    if eval_filename is None:
        eval_filename = 'ppl.csv'
    to_write = ",".join(
        map(str,
            [checkpoint_name, valid_ppl, valid_tokens, test_ppl, test_tokens]))
    eval_filepath = os.path.join(checkpoint_dir, eval_filename)
    previous_eval_content = ""
    if gfile.exists(eval_filepath):
        with gfile.GFile(eval_filepath, "r") as f:
            previous_eval_content = f.read()
    with gfile.GFile(eval_filepath, "w") as f:
        f.write(previous_eval_content + to_write + "\n")
Example #19
def load_scan(path):
  """Read original scan task data and convert into CFQ-style json format."""
  logging.info(f'Reading SCAN tasks from {path}.')
  def parse(infile):
    for line in infile.read().split('\n'):
      if not line.startswith('IN: '):
        continue
      commands, actions = line[len('IN: '):].strip().split(' OUT: ', 1)
      yield {'questionPatternModEntities': commands,
             'sparqlPatternModEntities': actions}
  return list(parse(gfile.GFile(path)))
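For reference, a SCAN task line has the form "IN: <commands> OUT: <actions>". A sketch of the resulting records (the path is a placeholder):

# Input line:  "IN: jump twice OUT: I_JUMP I_JUMP"
# Yields:      {'questionPatternModEntities': 'jump twice',
#               'sparqlPatternModEntities': 'I_JUMP I_JUMP'}
examples = load_scan('tasks_train_simple.txt')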
Example #20
def dump_trajectory(output_dir, epoch, env_id, temperature, random_string,
                    trajs):
  """Write the trajectory to disk."""

  assert 1 == len(trajs)
  traj = trajs[0]

  trajectory_file_name = trajectory.TRAJECTORY_FILE_FORMAT.format(
      epoch=epoch, env_id=env_id, temperature=temperature, r=random_string)

  with gfile.GFile(os.path.join(output_dir, trajectory_file_name), 'w') as f:
    trajectory.get_pickle_module().dump(traj, f)
Example #21
    def construct_query2box(self, name):
        """Load query2box data into a list of examples."""
        task = name.split('_')[-1]
        # we don't train our model in the same way as query2box,
        # so we won't load their training data.
        test_data = pickle.load(
            gfile.GFile(self.root_dir + 'test_ans_%s.pkl' % task, 'rb'))

        converted_test_data = list()
        for query, unused_answers in test_data.items():
            converted_test_data.append(query)

        return None, converted_test_data
Example #22
    def __init__(self, task, root_dir, data_loader):
        self.task = task
        self.data_loader = data_loader
        # Load answers and hard_answers. Hard answers are defined as
        # answers that can only be inferred from the test kb that are
        # not exposed in train kb. Test kb is a superset of train kb.
        answer_file = root_dir + 'test_ans_%s.pkl' % task
        hard_answer_file = root_dir + 'test_ans_%s_hard.pkl' % task
        self.answers = pickle.load(gfile.GFile(answer_file, 'rb'))
        self.hard_answers = pickle.load(gfile.GFile(hard_answer_file, 'rb'))
        self.q2b_entity2id = pickle.load(
            gfile.GFile(root_dir + 'ent2ind.pkl', 'rb'))
        self.q2b_relation2id = pickle.load(
            gfile.GFile(root_dir + 'rel2ind.pkl', 'rb'))
        # Declare metrics.
        self.total = 0.0
        self.metrics = {
            'hits@1': 0.0,
            'hits@3': 0.0,
            'hits@10': 0.0,
            'mrr': 0.0
        }
Example #23
def write_token_vocab(words, save_path, problem='cfq'):
    """"Writes token vocabulary from @words to @save_path."""
    # Sort tokens by frequency and then lexically to break ties.
    words_with_counts = words.most_common()
    words_with_counts.sort(key=lambda x: (x[1], x[0]), reverse=True)
    vocab_path = os.path.join(save_path, 'vocab.%s.tokens' % problem)

    with gfile.GFile(vocab_path, 'w') as f:
        # Tensor2tensor needs these additional tokens.
        f.write('<pad>\n<EOS>\n<OOV>\n')
        for word, _ in words_with_counts:
            f.write(word + '\n')
    logging.info('Token vocabulary written to %s (%s distinct tokens).',
                 vocab_path, len(words))
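The words argument is expected to be a collections.Counter (it must provide most_common()). A small hypothetical example, with a placeholder output directory:

import collections

words = collections.Counter(['the', 'the', 'film', 'was', 'film', 'the'])
write_token_vocab(words, save_path='/tmp/vocab_dir', problem='cfq')
# /tmp/vocab_dir/vocab.cfq.tokens then contains, one token per line:
#   <pad>, <EOS>, <OOV>, the, film, was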
Example #24
  def _savefile(self, path, img, zoom, grid_height):

    data = make_grid(img, zoom=zoom, grid_height=grid_height)

    # Writing takes time, and opening a file for writing erases its contents,
    # so it's better to write to a temporary file and then copy the results.
    dirname, basename = osp.dirname(path), osp.basename(path)
    temp_file = osp.join(dirname, '.' + basename)
    with gfile.GFile(temp_file, 'wb') as f:
      img = Image.fromarray(data)
      img.save(f, format='PNG')

    gfile.copy(temp_file, path, overwrite=True)
    gfile.remove(temp_file)
Example #25
    def init_from_file(self, file_name):
        """Initializes this layer and its sublayers from a file.

    We assume that the file is a pickled dictionary that contains the fields
    'weights' and 'state' with structures corresponding to this layer's weights
    and state. Note that the pickled dictionary is allowed to contain other
    fields too, but these two are required to init.

    Args:
      file_name: the name of the file to initialize from.
    """
        with gfile.GFile(file_name, 'rb') as f:
            dictionary = pickle.load(f)
        self.weights = dictionary['weights']
        self.state = dictionary['state']
Example #26
def get_accuracy_result(questions_path, golden_answers_path,
                        inferred_answers_path):
    """Collect accuracy results from input files."""
    questions = gfile.GFile(questions_path).readlines()
    golden_answers = gfile.GFile(golden_answers_path).readlines()
    inferred_answers = gfile.GFile(inferred_answers_path).readlines()

    result = AccuracyResult(total_lines=len(questions),
                            matches=[],
                            mismatches=[],
                            inferred_answers_path=inferred_answers_path)
    if len(set(
        (len(questions), len(golden_answers), len(inferred_answers)))) > 1:
        raise ValueError(
            'Not writing accuracy results: Input files have different lengths\n'
            'Questions: %s, golden answers: %s, inferred answers: %s' %
            (len(questions), len(golden_answers), len(inferred_answers)))
    for question, golden, inferred in zip(questions, golden_answers,
                                          inferred_answers):
        if inferred == golden:
            result.matches.append((question, golden))
        else:
            result.mismatches.append((question, golden, inferred))
    return result
Example #27
    def print_metrics(self):
        """Print metrics to stdout."""
        print('task: ', self.task)

        # Print to stdout.
        for k, v in self.metrics.items():
            print(k, v / self.total)

        # Also write to a metrics file if one is specified.
        if FLAGS.metrics_file is not None:
            with gfile.GFile(FLAGS.metrics_file, 'w') as f_out:
                f_out.write('task: %s\n' % self.task)
                f_out.write('top_k: %d\n' % FLAGS.intermediate_top_k)
                f_out.write('cm_width: %d\n' % FLAGS.cm_width)
                for k, v in self.metrics.items():
                    f_out.write('%s: %f\n' % (k, v / self.total))
Example #28
    def load_vocab(self, vocab_file):
        """Load vocab from file.

    Args:
      vocab_file: vocab file

    Returns:
      a mapping from word to id
    """
        word2id = dict()
        with gfile.GFile(vocab_file, 'r') as f_in:
            for line in tqdm(f_in):
                word = json.loads(line)
                assert word not in word2id
                word2id[word] = len(word2id)
        return word2id
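The vocab file is expected to contain one JSON-encoded token per line. A sketch with a placeholder path, called from inside the enclosing (unshown) class so that self is available:

# Write a tiny vocab file (one JSON string per line) and load it back.
with gfile.GFile('/tmp/vocab.jsonl', 'w') as f_out:
    f_out.write('"the"\n"film"\n')
word2id = self.load_vocab('/tmp/vocab.jsonl')  # -> {'the': 0, 'film': 1}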
Example #29
def parse_param_file(param_file):
    """Parse parameter file for parameters."""
    with gfile.GFile(param_file, 'r') as fh:
        lines = fh.readlines()
    d = {}
    for l in lines:
        l = l.rstrip('\n')
        splits = l.split(':')
        key = splits[0]
        val_ = splits[1].strip()
        if not val_:
            val = ''
        else:
            try:
                val = ast.literal_eval(val_)
            except (ValueError, SyntaxError):
                val = str(val_)
        d[key] = val
    return d
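A sketch of the expected "key: value" file format and the parsed result; keys and values are illustrative:

# params.txt:
#   learning_rate: 0.001
#   num_layers: 2
#   model_name: transformer
#   note:
params = parse_param_file('params.txt')
# -> {'learning_rate': 0.001, 'num_layers': 2,
#     'model_name': 'transformer', 'note': ''}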
Example #30
def write_accuracy_result(result, output_path, print_output=False):
    """Writes the accuracy results to a text file."""
    if not result:
        return
    accuracy = result.get_accuracy()
    summary = f'Accuracy on {result.inferred_answers_path} is {accuracy}'
    with gfile.GFile(output_path, 'w') as f:
        f.write(f'{summary}\n')
        if result.mismatches:
            f.write('\n==========WRONG==========\n')
        for question, golden, inferred in result.mismatches:
            f.write(f'Q: {question}Gold: {golden}Inferred: {inferred}\n')
        if result.matches:
            f.write('\n==========CORRECT==========\n')
        for question, golden in result.matches:
            f.write(f'Q: {question}Gold/Inferred: {golden}\n')
    if print_output:
        print(f'Evaluation result written to {output_path}\n')
        print(summary)
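Tying the pieces together (Examples #26 and #30), a hypothetical end-to-end evaluation; all paths are placeholders:

result = get_accuracy_result('questions.txt', 'golden_answers.txt',
                             'inferred_answers.txt')
write_accuracy_result(result, 'accuracy_report.txt', print_output=True)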