Code example #1
def write_completion(data, output_file_data_src, output_file_data_tar,
                     output_file_kb):
    """This function write both kb and main data into the files."""
    f_data_src = gfile.Open(output_file_data_src, 'w')
    f_data_tar = gfile.Open(output_file_data_tar, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        bd1 = entry['boundaries1'].split(' ')
        bd2 = entry['boundaries2'].split(' ')
        start = bd1[0:len(bd1) // 2] + bd2[0:len(bd2) // 2]
        end = bd1[len(bd1) // 2:] + bd2[len(bd2) // 2:]
        # random_turn = random.randint(0, len(start) - 1)
        for random_turn in range(len(start)):
            f_kb.write(flatten_json(entry['kb']) + '\n')
            # print len(start),len(end),len(bd),random_turn
            turn_start = int(start[random_turn])
            turn_end = int(end[random_turn])
            dialogue_split = entry['dialogue'].split(' ')
            # print turn_start,turn_end
            dialogue_src = dialogue_split[0:turn_start + 1]
            dialogue_tar = dialogue_split[turn_start + 1:turn_end + 1]
            src_arr = [entry['intent'], ' '.join(dialogue_src)]
            f_data_src.write('|'.join(src_arr) + '\n')
            tar_arr = [entry['action'], ' '.join(dialogue_tar)]
            f_data_tar.write('|'.join(tar_arr) + '\n')

    f_data_src.close()
    f_data_tar.close()
    f_kb.close()
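
As a rough guide to the expected input, below is a hypothetical entry in the shape this function appears to consume: the first half of each space-separated 'boundariesN' string is read as turn-start indices and the second half as the matching turn-end indices into 'dialogue'. The field values and output file names are made up; flatten_json and gfile come from the surrounding module.

example_entry = {
    'intent': 'book a flight from SFO to JFK',
    'action': 'book flight 1002',
    'kb': {'flight': {'price': 200}},
    'boundaries1': '0 5',    # one start index, then one end index
    'boundaries2': '6 11',   # a second (start, end) pair
    'dialogue': 'hi <t1> hello how can i help <t2> i need a flight please',
}

write_completion([example_entry], 'train.src', 'train.tar', 'train.kb')
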
Code example #2
def write_vocabulary(output_file, output_all_vocab_file, word_frequency,
                     frequency_cutoff, keep_non_ascii):
    special_chars = [
        unk_token, start_of_turn1, start_of_turn2, end_of_dialogue
    ]
    new_word_frequency = set([])
    # If output_file is not None, we write to the file; otherwise we don't.
    if output_file:
        f = gfile.Open(output_file, 'w')
    else:
        f = None

    # The first entry should always be the unk token.
    for special_char in special_chars:
        if f:
            f.write(special_char + '\n')
    for key in word_frequency:
        # We write to the vocabulary only when the key is not empty.
        # Otherwise tensorflow will complain.
        if (word_frequency[key] >= frequency_cutoff
                and key not in special_chars
                and is_ascii(key, keep_non_ascii) and key):
            if f:
                f.write(key + '\n')
            new_word_frequency.add(key)

    if f:
        f.close()
    # all vocab
    if output_all_vocab_file:
        with gfile.Open(output_all_vocab_file, 'w') as f2:
            f2.write(str(word_frequency))
    return new_word_frequency
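
A minimal usage sketch; the toy corpus, cutoff, and file names are hypothetical, and unk_token, start_of_turn1, start_of_turn2, end_of_dialogue, and is_ascii are module-level names the function above already depends on.

import collections

# Count token frequencies over a toy corpus, then keep tokens seen at least
# twice, writing them to vocab.txt and dumping all raw counts to all_vocab.txt.
corpus = ['hello world', 'hello there']
word_frequency = collections.Counter(
    token for line in corpus for token in line.split())
kept = write_vocabulary('vocab.txt', 'all_vocab.txt', word_frequency,
                        frequency_cutoff=2, keep_non_ascii=False)
# kept == {'hello'}
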
Code example #3
def main(unused_argv):

    # Get one-hot encoding.
    mol_weights = pd.Series(_MOL_WEIGHTS)
    alphabet = [k for k in mol_weights.keys() if not k.startswith(_GROUP)]
    alphabet = sorted(alphabet)
    one_hot_encoding = pd.get_dummies(alphabet).astype(int).to_dict(
        orient='list')

    with gfile.Open(FLAGS.input_data) as inputf:
        input_data = pd.read_csv(inputf, sep=',')
    input_data.rename(
        columns={
            FLAGS.sequence_col: _MOD_SEQUENCE,
            FLAGS.charge_col: _CHARGE,
            FLAGS.fragmentation_col: _FRAGMENTATION,
            FLAGS.analyzer_col: _MASS_ANALYZER
        },
        inplace=True)

    metadata, _ = preprocess_peptides(input_data, FLAGS.clean_peptides)
    metadata = metadata.reset_index()

    check_inputs(metadata, alphabet)

    # length.
    json_inputs = generate_json_inputs(metadata, one_hot_encoding)
    with gfile.Open(os.path.join(FLAGS.output_data_dir, 'input.json'),
                    'w') as outf:
        for json_input in json_inputs:
            outf.write(json.dumps(json_input) + '\n')
    with gfile.Open(os.path.join(FLAGS.output_data_dir, 'metadata.tsv'),
                    'w') as outf:
        metadata.to_csv(outf, sep='\t')
Code example #4
File: test_ignite.py Project: ZJM-TECH/tensorflow_io
    def test_copy(self):
        """Test copy.

    """
        # Setup and check preconditions.
        src_file_name = "igfs:///test_copy/1"
        dst_file_name = "igfs:///test_copy/2"
        self.assertFalse(gfile.Exists(src_file_name))
        self.assertFalse(gfile.Exists(dst_file_name))
        with gfile.Open(src_file_name, mode="w") as w:
            w.write("42")
        self.assertTrue(gfile.Exists(src_file_name))
        self.assertFalse(gfile.Exists(dst_file_name))
        # Copy file.
        gfile.Copy(src_file_name, dst_file_name)
        # Check that files are identical.
        self.assertTrue(gfile.Exists(src_file_name))
        self.assertTrue(gfile.Exists(dst_file_name))
        with gfile.Open(dst_file_name, mode="r") as r:
            data_v = r.read()
        self.assertEqual("42", data_v)
        # Remove file.
        gfile.Remove(src_file_name)
        gfile.Remove(dst_file_name)
        # Check that file was removed.
        self.assertFalse(gfile.Exists(src_file_name))
        self.assertFalse(gfile.Exists(dst_file_name))
Code example #5
def main(unused_argv):
    tokenizer = FullTokenizer(FLAGS.tokenizer_vocabulary)

    print('Loading ' + str(FLAGS.dataset_name) + ' dataset from ' +
          FLAGS.input_filepath)

    # The debugging file saves all of the processed SQL queries.
    debugging_file = gfile.Open(
        os.path.join('/'.join(FLAGS.output_filepath.split('/')[:-1]),
                     FLAGS.dataset_name + '_'.join(FLAGS.splits) +
                     '_gold.txt'), 'w')

    # The output file will save a sequence of string-serialized JSON objects, one
    # line per object.
    output_file = gfile.Open(os.path.join(FLAGS.output_filepath), 'w')

    if FLAGS.dataset_name.lower() == 'spider':
        num_examples_created, num_examples_failed = process_spider(
            output_file, debugging_file, tokenizer)
    elif FLAGS.dataset_name.lower() == 'wikisql':
        num_examples_created, num_examples_failed = process_wikisql(
            output_file, debugging_file, tokenizer)
    else:
        num_examples_created, num_examples_failed = process_michigan_datasets(
            output_file, debugging_file, tokenizer)

    print('Wrote %s examples, could not annotate %s examples.' %
          (num_examples_created, num_examples_failed))
    debugging_file.write('Wrote %s examples, could not annotate %s examples.' %
                         (num_examples_created, num_examples_failed))
    debugging_file.close()
    output_file.close()
Code example #6
 def __init__(self, vocab_file, embedding_file, normalize_embeddings=True):
     with gfile.Open(embedding_file, 'rb') as f:
         self.embedding_mat = np.load(f)
     if normalize_embeddings:
         self.embedding_mat = self.embedding_mat / np.linalg.norm(
             self.embedding_mat, axis=1, keepdims=True)
     with gfile.Open(vocab_file, 'r') as f:
         tks = json.load(f)
     self.vocab = dict(zip(tks, range(len(tks))))
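
For context, a sketch of producing the two files this constructor expects: a NumPy matrix of shape [vocab_size, dim] saved with np.save, and a JSON list of tokens in the same row order. The file names, tokens, and embedding size are hypothetical.

import json

import numpy as np

tokens = ['<unk>', 'hello', 'world']
# One embedding row per token, in the same order as the vocabulary list.
np.save('/tmp/embeddings.npy', np.random.rand(len(tokens), 8).astype(np.float32))
with open('/tmp/vocab.json', 'w') as f:
    json.dump(tokens, f)
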
Code example #7
def write_self_play(data, output_file_data, output_file_kb):
    f_data = gfile.Open(output_file_data, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        f_kb.write(flatten_json(entry['kb']) + '\n')
        # Both the intent and the expected action are needed.
        new_arr = [entry['intent'], entry['expected_action']]
        f_data.write('|'.join(new_arr) + '\n')
    f_data.close()
    f_kb.close()
Code example #8
File: helpers.py Project: LONG-9621/Stackedcapsule
def load_pickle(pickle_file):
    """Load a pickle file (py2/3 compatible)."""
    try:
        with gfile.Open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError as e:
        with gfile.Open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load {}: {}'.format(pickle_file, e))
        raise
    return pickle_data
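
A quick round trip with a hypothetical path; the latin1 fallback above mainly matters when reading pickles written under Python 2.

import pickle

# Write a small object with the standard pickle module, then read it back
# through load_pickle above.
with open('/tmp/example.pkl', 'wb') as f:
    pickle.dump({'weights': [0.1, 0.2]}, f)
restored = load_pickle('/tmp/example.pkl')
assert restored == {'weights': [0.1, 0.2]}
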
Code example #9
def collect_programs():
  saved_programs = {}
  for i in range(FLAGS.id_start, FLAGS.id_end):
    with gfile.Open(
        os.path.join(get_experiment_dir(), 'program_shard_{}-{}.json'.format(
            i, FLAGS.n_epoch)), 'r') as f:
      program_shard = json.load(f)
      saved_programs.update(program_shard)
  saved_program_path = os.path.join(get_experiment_dir(), 'saved_programs.json')
  with gfile.Open(saved_program_path, 'w') as f:
    json.dump(saved_programs, f)
  print('saved programs are aggregated in {}'.format(saved_program_path))
Code example #10
def get_nl_sql_pairs(filepath, splits, with_dbs=False):
  """Gets pairs of natural language and corresponding gold SQL for Michigan."""
  with gfile.Open(filepath) as infile:
    data = json.load(infile)

  pairs = list()

  tag = '[' + filepath.split('/')[-1].split('.')[0] + ']'
  print('Getting examples with tag ' + tag)

  # The UMichigan data is split by anonymized queries, where values are
  # anonymized but table/column names are not. However, our experiments are
  # performed on the original splits of the data.
  for query in data:
    # Take the first SQL query only. From their Github documentation:
    # "Note - we only use the first query, but retain the variants for
    #  completeness"
    anonymized_sql = query['sql'][0]

    # It's also associated with a number of natural language examples, which
    # also contain anonymous tokens. Save the de-anonymized utterance and query.
    for example in query['sentences']:
      if example['question-split'] not in splits:
        continue

      nl = example['text']
      sql = anonymized_sql

      # Go through the anonymized values and replace them in both the natural
      # language and the SQL.
      #
      # It's very important to sort these in descending order. If one is a
      # substring of the other, it shouldn't be replaced first lest it ruin the
      # replacement of the superstring.
      for variable_name, value in sorted(
          example['variables'].items(), key=lambda x: len(x[0]), reverse=True):
        if not value:
          # TODO(alanesuhr) While the Michigan repo says to use a - here, the
          # thing that works is using a % and replacing = with LIKE.
          #
          # It's possible that I should remove such clauses from the SQL, as
          # long as they lead to the same table result. They don't align well
          # to the natural language at least.
          #
          # See: https://github.com/jkkummerfeld/text2sql-data/tree/master/data
          value = '%'

        nl = nl.replace(variable_name, value)
        sql = sql.replace(variable_name, value)

      # In the case that we replaced an empty anonymized value with %, make it
      # compilable by allowing equality with any string (= becomes LIKE).
      sql = sql.replace('= "%"', 'LIKE "%"')

      if with_dbs:
        pairs.append((nl, sql, example['table-id']))
      else:
        pairs.append((nl, sql))

  return pairs
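
A sketch of a call with a hypothetical file path and split name; each JSON entry is expected to provide the 'sql', 'sentences', 'variables', and 'question-split' fields accessed above.

pairs = get_nl_sql_pairs('data/atis.json', splits={'train'})
for nl, sql in pairs[:3]:
    print(nl, '->', sql)
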
Code example #11
def main(argv):
    del argv  # unused.
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as f:
            hparams = deep_q_networks.get_hparams(**json.load(f))
    else:
        hparams = deep_q_networks.get_hparams()

    environment = DockingRewardMolecule(
        discount_factor=hparams.discount_factor,
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode)

    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
Code example #12
 def _inner_iter(self, fpath):
     with gfile.Open(fpath, 'r') as fh:
         rest_buffer = []
         aware_headers = True
         read_finished = False
         while not read_finished:
             dict_reader, rest_buffer, read_finished = \
                     self._make_csv_dict_reader(fh, rest_buffer,
                                                aware_headers)
             aware_headers = False
             if self._headers is None:
                 self._headers = dict_reader.fieldnames
             elif self._headers != dict_reader.fieldnames:
                 logging.fatal("the schema of %s is %s, mismatch "\
                               "with previous %s", fpath,
                               self._headers, dict_reader.fieldnames)
                 traceback.print_stack()
                 os._exit(-1)  # pylint: disable=protected-access
             self._validator.check_csv_header(self._headers)
             # check invalid character for headers
             for raw in dict_reader:
                 if not self._validator.check_csv_record(
                         raw, len(self._headers)):
                     continue
                 yield CsvItem(raw)
Code example #13
def load_wikisql_tables(filepath):
    """Loads the WikiSQL tables from a path and reformats as the format."""
    dbs = dict()
    with gfile.Open(filepath) as infile:
        tables = [json.loads(line) for line in infile if line]

    for table in tables:
        db_dict = dict()
        if 'section_title' in table and table['section_title']:
            table_name = table['section_title']
        elif 'name' in table:
            table_name = table['name']
        else:
            table_name = table['page_title']

        table_name = normalize_entities(table_name)

        db_dict[table_name] = list()
        for column_name, column_type in zip(table['header'], table['types']):
            if column_type == 'real':
                column_type = 'number'
            assert column_type in {'text', 'number'}, column_type
            column_name = normalize_entities(column_name)

            db_dict[table_name].append({
                'field name': column_name,
                'is primary key': False,
                'is foreign key': False,
                'type': column_type
            })
        if table['id'] not in dbs:
            dbs[table['id']] = db_dict

    return dbs
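
A usage sketch assuming a WikiSQL-style tables file in which each line is a JSON object with 'id', 'header', 'types', and one of 'section_title', 'name', or 'page_title'; the path is hypothetical.

dbs = load_wikisql_tables('data/dev.tables.jsonl')
for table_id, db in list(dbs.items())[:1]:
    for table_name, columns in db.items():
        print(table_id, table_name, [col['field name'] for col in columns])
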
Code example #14
def get_optimized_mols(model_dir, ckpt=80000):
    """Get optimized Molecules.

  Args:
    model_dir: String. model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules
  """
    hparams_file = os.path.join(model_dir, 'config.json')
    with gfile.Open(hparams_file, 'r') as f:
        hp_dict = json.load(f)
        hparams = deep_q_networks.get_hparams(**hp_dict)

    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=0.0)

    tf.reset_default_graph()
    optimized_mol = []
    with tf.Session() as sess:
        dqn.build()
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
        for mol in all_mols:
            logging.info('Eval: %s', mol)
            environment = molecules_mdp.Molecule(
                atom_types=set(hparams.atom_types),
                init_mol=mol,
                allow_removal=hparams.allow_removal,
                allow_no_modification=hparams.allow_no_modification,
                allow_bonds_between_rings=hparams.allow_bonds_between_rings,
                allowed_ring_sizes=set(hparams.allowed_ring_sizes),
                max_steps=hparams.max_steps_per_episode,
                record_path=True)
            environment.initialize()
            if hparams.num_bootstrap_heads:
                head = np.random.randint(hparams.num_bootstrap_heads)
            else:
                head = 0
            for _ in range(hparams.max_steps_per_episode):
                steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
                valid_actions = list(environment.get_valid_actions())
                observations = np.vstack([
                    np.append(deep_q_networks.get_fingerprint(act, hparams),
                              steps_left) for act in valid_actions
                ])
                action = valid_actions[dqn.get_action(observations,
                                                      head=head,
                                                      update_epsilon=0.0)]
                environment.step(action)
            optimized_mol.append(environment.get_path())
    return optimized_mol
Code example #15
 def __init__(self, latent_factor_indices=None):
     DSprites.__init__(self, latent_factor_indices)
     self.data_shape = [64, 64, 3]
     with gfile.Open(SCREAM_PATH, "rb") as f:
         scream = PIL.Image.open(f)
         scream.thumbnail((350, 274, 3))
         self.scream = np.array(scream) * 1. / 255.
Code example #16
    def setUp(self) -> None:
        logging.getLogger().setLevel(logging.DEBUG)
        self._data_portal_name = 'test_data_portal_job_manager'

        self._kvstore = DBClient('etcd', True)
        self._portal_input_base_dir = './portal_input_dir'
        self._portal_output_base_dir = './portal_output_dir'
        self._raw_data_publish_dir = 'raw_data_publish_dir'
        if gfile.Exists(self._portal_input_base_dir):
            gfile.DeleteRecursively(self._portal_input_base_dir)
        gfile.MakeDirs(self._portal_input_base_dir)

        self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
        self._data_fnames_without_success = \
            ['1002/{}.data'.format(i) for i in range(100)]
        self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
        self._unused_fnames = ['{}.xx'.format(100)]
        self._all_fnames = self._data_fnames + \
                           self._data_fnames_without_success + \
                           self._csv_fnames + self._unused_fnames

        all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] +\
                                  self._all_fnames
        for fname in all_fnames_with_success:
            fpath = os.path.join(self._portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
Code example #17
File: test_ignite.py Project: zxshinxz/io
  def test_list_directory(self):
    """Test list directory.

    """
    # Setup and check preconditions.
    gfile.MkDir(self.prefix() + ":///test_list_directory")
    gfile.MkDir(self.prefix() + ":///test_list_directory/2")
    gfile.MkDir(self.prefix() + ":///test_list_directory/4")
    dir_name = self.prefix() + ":///test_list_directory"
    file_names = [
        self.prefix() + ":///test_list_directory/1",
        self.prefix() + ":///test_list_directory/2/3"
    ]
    ch_dir_names = [
        self.prefix() + ":///test_list_directory/4",
    ]
    for file_name in file_names:
      with gfile.Open(file_name, mode="w") as w:
        w.write("")
    for ch_dir_name in ch_dir_names:
      gfile.MkDir(ch_dir_name)
    ls_expected_result = file_names + ch_dir_names
    # Get list of files in directory.
    ls_result = gfile.ListDirectory(dir_name)
    # Check that list of files is correct.
    self.assertEqual(len(ls_expected_result), len(ls_result))
    for e in ["1", "2", "4"]:
      self.assertTrue(e in ls_result, msg="Result doesn't contain '%s'" % e)
Code example #18
def main(unused_argv):
    # Load the examples
    vocabulary = set()
    valid_filenames = [
        filename for filename in FLAGS.input_filenames if filename
    ]
    for filename in valid_filenames:
        with open(os.path.join(FLAGS.data_dir, filename)) as infile:
            for line in infile:
                if not line:
                    continue

                symbols = get_symbol(line)
                new_symbols = [
                    symbol for symbol in symbols if symbol not in vocabulary
                ]
                if new_symbols:
                    print(new_symbols)
                for symbol in symbols:
                    vocabulary.add(symbol)

    print("Writing vocabulary of size %d to %s" %
          (len(vocabulary), FLAGS.output_path))
    with gfile.Open(FLAGS.output_path, "w") as ofile:
        ofile.write("\n".join(list(vocabulary)))
Code example #19
def main(argv):
    del argv
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as f:
            hparams = deep_q_networks.get_hparams(**json.load(f))
    else:
        hparams = deep_q_networks.get_hparams()

    hparams.override_from_dict(
        {'max_steps_per_episode': FLAGS.max_steps_per_episode})

    environment = Molecule(atom_types=set(hparams.atom_types),
                           init_mol=FLAGS.start_molecule,
                           allow_removal=hparams.allow_removal,
                           allow_no_modification=hparams.allow_no_modification,
                           max_steps=hparams.max_steps_per_episode)

    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(
        hparams=hparams,
        environment=environment,
        dqn=dqn,
    )

    core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
Code example #20
File: dfs_client.py Project: piiswrong/fedlearner
 def get_prefix_kvs(self, prefix, ignore_prefix=False):
     kvs = []
     target_path = self._generate_path(prefix, with_meta=False)
     cur_paths = [target_path]
     children_paths = []
     while cur_paths:
         for path in cur_paths:
             filenames = []
             try:
                 if gfile.IsDirectory(path):
                     filenames = gfile.ListDirectory(path)
             except Exception as e:  # pylint: disable=broad-except
                 logging.warning("get prefix kvs %s failed, "
                                 " reason: %s", path, str(e))
                 break
             for filename in sorted(filenames):
                 file_path = "/".join([path, filename])
                 if gfile.IsDirectory(file_path):
                     children_paths.append(file_path)
                 else:
                     if ignore_prefix and path == target_path:
                         continue
                     nkey = self.normalize_output_key(
                         path, self._base_dir).encode()
                     with gfile.Open(file_path, 'rb') as file:
                         kvs.append((nkey, file.read()))
         cur_paths = children_paths
         children_paths = []
     return kvs
Code example #21
def main(unused_argv):

    metadata = pd.read_csv(gfile.Open(FLAGS.metadata_file), sep='\t')

    # Read DeepMass:Prism outputs and merge with metadata.
    outputs = []
    for filen in gfile.Glob(FLAGS.input_data_pattern):
        with gfile.Open(filen) as infile:
            if FLAGS.batch_prediction:
                out_df = pd.read_json(infile, lines=True)
            else:
                out_df = json.load(infile)
                out_df = pd.DataFrame(out_df['predictions'])
            out_df = out_df.merge(metadata,
                                  left_on='key',
                                  right_on='index',
                                  how='left')
            outputs.append(out_df)
    outputs = pd.concat(outputs)
    outputs = outputs.apply(reformat_outputs,
                            args=(int(FLAGS.label_dim), FLAGS.neutral_losses),
                            axis=1)

    # Read additional features.
    if FLAGS.add_feature_names is not None:
        outputs_drip = []
        for filen in gfile.Glob(FLAGS.add_input_data_pattern):
            with gfile.Open(filen) as infile:
                if FLAGS.batch_prediction:
                    out_df = pd.read_json(infile, lines=True)
                    out_df = pd.DataFrame(out_df['outputs'].tolist(),
                                          columns=FLAGS.add_feature_names,
                                          index=out_df['key'])
                else:
                    pass
                outputs_drip.append(out_df)
        outputs_drip = pd.concat(outputs_drip)
        outputs = outputs.merge(outputs_drip,
                                how='left',
                                left_on='key',
                                right_index=True)

    # Write to a file.
    with gfile.Open(os.path.join(FLAGS.output_data_dir, 'outputs.tsv'),
                    'w') as outf:
        outputs.to_csv(outf, sep='\t', index=False)
Code example #22
File: simulator_main.py Project: yyht/airdialogue
def main(FLAGS):
    if FLAGS.verbose:
        print("Number of samples to generate: ", FLAGS.num_samples)
        print("Output_kb: ", FLAGS.output_kb)
        print("Output_data: ", FLAGS.output_data)

    num_samples = FLAGS.num_samples
    cg = context_generator_lib.ContextGenerator(
        num_candidate_airports=FLAGS.num_candidate_airports,
        book_window=FLAGS.book_window,
        num_db_record=FLAGS.num_db_record,
        firstname_file=FLAGS.firstname_file,
        lastname_file=FLAGS.lastname_file,
        airportcode_file=FLAGS.airportcode_file)
    inter = interaction.Interaction(cg.fact_obj,
                                    skip_greeting=0,
                                    fix_response_candidate=True,
                                    first_ask_prob=0,
                                    random_respond_error=True)
    ct, stats = cg.generate_context(num_samples, output_object=True)
    if FLAGS.verbose:
        print(stats)
    with gfile.Open(FLAGS.output_data,
                    "w") as f_data, gfile.Open(FLAGS.output_kb, "w") as f_kb:
        for i in range(len(ct)):
            if FLAGS.verbose and i % 5000 == 0:
                print((i, "/", len(ct)))
            cus, kb, expected_action = ct[i]
            # The action has already been standardized in inter.
            utterance, action, _ = inter.generate_dialogue(cus, kb)
            standarlized_intent = utils.standardize_intent(cus.get_json())
            standarlized_action = utils.standardize_action(action)
            standarlized_expected_action = utils.standardize_action(
                expected_action)
            # Synthesized data is 100% correct. However, action contains at most
            # one flight; expected_action may contain more than one flight.
            f_data.write(
                json.dumps({
                    "intent": standarlized_intent,
                    "dialogue": utterance,
                    "action": standarlized_action,
                    "expected_action": standarlized_expected_action
                }) + "\n")
            dumped_kb = kb.get_json()
            f_kb.write(json.dumps(dumped_kb) + "\n")
Code example #23
File: core.py Project: ziyouzizai111/google-research
def write_hparams(hparams, filename):
  """Writes HParams to disk as JSON.

  Args:
    hparams: HParams.
    filename: String output filename.
  """
  with gfile.Open(filename, 'w') as f:
    f.write(hparams.to_json(indent=2, sort_keys=True, separators=(',', ': ')))
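
A minimal call sketch; any object exposing to_json(indent=..., sort_keys=..., separators=...) works, such as the HParams returned by deep_q_networks.get_hparams in the other examples. The specific overrides and output path here are hypothetical.

hparams = deep_q_networks.get_hparams(batch_size=64, gamma=0.9)
write_hparams(hparams, '/tmp/config.json')
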
Code example #24
def load_origins(segmentation_dir, corner):
    target_path = get_existing_subvolume_path(segmentation_dir, corner, False)
    if target_path is None:
        raise ValueError('Segmentation not found: %s, %s' %
                         (segmentation_dir, corner))

    with gfile.Open(target_path, 'rb') as f:
        data = np.load(f)
        return data['origins'].item()
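
The .item() call suggests 'origins' was stored as a zero-dimensional object array, for example a dict saved with np.savez; a small compatible write/read sketch with hypothetical contents follows.

import numpy as np

# np.savez wraps the dict in a 0-d object array, which is why load_origins()
# unwraps it with .item(). Newer NumPy needs allow_pickle=True to read it back.
np.savez('/tmp/segmentation.npz', origins={(0, 0, 0): (10, 20, 30)})
with open('/tmp/segmentation.npz', 'rb') as f:
    data = np.load(f, allow_pickle=True)
    origins = data['origins'].item()
assert origins[(0, 0, 0)] == (10, 20, 30)
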
Code example #25
def save_flags():
    gfile.MakeDirs(FLAGS.train_dir)
    with gfile.Open(os.path.join(FLAGS.train_dir, 'flags.%d' % time.time()),
                    'w') as f:
        for mod, flag_list in FLAGS.flags_by_module_dict().items():
            if (mod.startswith('google3.research.neuromancer.tensorflow')
                    or mod.startswith('/')):
                for flag in flag_list:
                    f.write('%s\n' % flag.serialize())
Code example #26
def main(unused_argv):
    # Load the examples
    vocabulary = set()
    for filename in FLAGS.input_filenames:
        if filename:
            with gfile.Open(os.path.join(FLAGS.data_dir, filename)) as infile:
                for line in infile:
                    if line:
                        gold_query = NLToSQLExample().from_json(
                            json.loads(line)).gold_sql_query

                        for token in gold_query.actions:
                            if token.symbol:
                                vocabulary.add(token.symbol)
    print('Writing vocabulary of size %d to %s' %
          (len(vocabulary), FLAGS.output_path))
    with gfile.Open(FLAGS.output_path, 'w') as ofile:
        ofile.write('\n'.join(list(vocabulary)))
Code example #27
def write_infer_json(data, kb, output_file_src, output_file_tgt,
                     output_file_kb):
  """This function writes both the kb and the main data to files."""
  f_src = gfile.Open(output_file_src, 'w')
  f_tgt = gfile.Open(output_file_tgt, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry, entry_kb in zip(data, kb):
    entire_dialogue = entry['dialogue'][:]

    # random_turn = random.randint(0, len(start) - 1)
    for target_turn in range(len(entire_dialogue))[1:]:
      f_kb.write(json_dump(entry_kb) + '\n')
      entry['dialogue'] = entire_dialogue[0:target_turn]
      f_src.write(json_dump(entry) + '\n')
      f_tgt.write(json_dump({'response': entire_dialogue[target_turn]}) + '\n')

  f_src.close()
  f_tgt.close()
  f_kb.close()
Code example #28
 def __init__(self):
   with gfile.Open(CUSTOMDATA_PATH, "rb") as f:
     # load data
     data = np.load(file=f, allow_pickle=True)
   self.images = data["imgs"]
   # The first dimension is the dataset size; stored as a list because
   # gaussian_encoder_model.py requires a list.
   self.data_shape = list(self.images.shape[1:])
   self.factor_sizes = data["factor_sizes"]
   self.latent_factor_indices = list(range(len(self.factor_sizes)))
   self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(self.factor_sizes)
   self.state_space = util.SplitDiscreteStateSpace(self.factor_sizes, self.latent_factor_indices)
Code example #29
def write_data(data, output_file_data, output_file_kb, alt_infer=False):
    """This function writes data into a text file."""
    f_data = gfile.Open(output_file_data, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        f_kb.write(flatten_json(entry['kb']) + '\n')
        new_arr = []
        if alt_infer:
            new_arr = [
                entry['intent'], entry['dialogue'].replace('<eod> ', '')
            ]
        else:
            new_arr = [
                entry['intent'], entry['action'], entry['dialogue'],
                entry['boundaries1']
            ]
            # Only boundaries1 is used; boundaries2 is not necessary here.
        f_data.write('|'.join(new_arr) + '\n')
    f_data.close()
    f_kb.close()
Code example #30
def _load_mesh(filename):
    """Parses a single source file and rescales contained images."""
    with gfile.Open(os.path.join(CARS3D_PATH, filename), "rb") as f:
        mesh = np.einsum("abcde->deabc", sio.loadmat(f)["im"])
    flattened_mesh = mesh.reshape((-1, ) + mesh.shape[2:])
    rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3))
    for i in range(flattened_mesh.shape[0]):
        pic = PIL.Image.fromarray(flattened_mesh[i, :, :, :])
        pic.thumbnail((64, 64, 3), PIL.Image.ANTIALIAS)
        rescaled_mesh[i, :, :, :] = np.array(pic)
    return rescaled_mesh * 1. / 255