def write_completion(data, output_file_data_src, output_file_data_tar,
                     output_file_kb):
  """Writes both the kb and the main data to the output files."""
  f_data_src = gfile.Open(output_file_data_src, 'w')
  f_data_tar = gfile.Open(output_file_data_tar, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    bd1 = entry['boundaries1'].split(' ')
    bd2 = entry['boundaries2'].split(' ')
    start = bd1[0:len(bd1) // 2] + bd2[0:len(bd2) // 2]
    end = bd1[len(bd1) // 2:] + bd2[len(bd2) // 2:]
    for random_turn in range(len(start)):
      f_kb.write(flatten_json(entry['kb']) + '\n')
      turn_start = int(start[random_turn])
      turn_end = int(end[random_turn])
      dialogue_split = entry['dialogue'].split(' ')
      dialogue_src = dialogue_split[0:turn_start + 1]
      dialogue_tar = dialogue_split[turn_start + 1:turn_end + 1]
      src_arr = [entry['intent'], ' '.join(dialogue_src)]
      f_data_src.write('|'.join(src_arr) + '\n')
      tar_arr = [entry['action'], ' '.join(dialogue_tar)]
      f_data_tar.write('|'.join(tar_arr) + '\n')
  f_data_src.close()
  f_data_tar.close()
  f_kb.close()
def write_vocabulary(output_file, output_all_vocab_file, word_frequency,
                     frequency_cutoff, keep_non_ascii):
  special_chars = [unk_token, start_of_turn1, start_of_turn2, end_of_dialogue]
  new_word_frequency = set([])
  # If output_file is not None we write to the file, otherwise we don't.
  if output_file:
    f = gfile.Open(output_file, 'w')
  else:
    f = None
  # The first entry should always be unk_token.
  for special_char in special_chars:
    if f:
      f.write(special_char + '\n')
  for key in word_frequency:
    # We write to the vocabulary only when the key is not empty.
    # Otherwise tensorflow will complain.
    if word_frequency[key] >= frequency_cutoff and (
        key not in special_chars) and is_ascii(key, keep_non_ascii) and key:
      if f:
        f.write(key + '\n')
      new_word_frequency.add(key)
  if f:
    f.close()
  # All vocab.
  if output_all_vocab_file:
    with gfile.Open(output_all_vocab_file, 'w') as f2:
      f2.write(str(word_frequency))
  return new_word_frequency
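
# Hedged usage sketch for write_vocabulary (toy counts, not real data): passing
# None for both file arguments skips writing and only returns the kept words.
def _example_write_vocabulary_usage():
  toy_counts = {'flight': 12, 'to': 30, 'boston': 2, '': 5}
  kept = write_vocabulary(None, None, toy_counts,
                          frequency_cutoff=5, keep_non_ascii=True)
  # With a cutoff of 5, 'boston' is dropped and the empty key is skipped,
  # leaving {'flight', 'to'}.
  print(kept)
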
def main(unused_argv):
  # Get one-hot encoding.
  mol_weights = pd.Series(_MOL_WEIGHTS)
  alphabet = [k for k in mol_weights.keys() if not k.startswith(_GROUP)]
  alphabet = sorted(alphabet)
  one_hot_encoding = pd.get_dummies(alphabet).astype(int).to_dict(orient='list')

  with gfile.Open(FLAGS.input_data) as inputf:
    input_data = pd.read_csv(inputf, sep=',')
  input_data.rename(
      columns={
          FLAGS.sequence_col: _MOD_SEQUENCE,
          FLAGS.charge_col: _CHARGE,
          FLAGS.fragmentation_col: _FRAGMENTATION,
          FLAGS.analyzer_col: _MASS_ANALYZER
      },
      inplace=True)

  metadata, _ = preprocess_peptides(input_data, FLAGS.clean_peptides)
  metadata = metadata.reset_index()
  check_inputs(metadata, alphabet)

  json_inputs = generate_json_inputs(metadata, one_hot_encoding)
  with gfile.Open(os.path.join(FLAGS.output_data_dir, 'input.json'),
                  'w') as outf:
    for json_input in json_inputs:
      outf.write(json.dumps(json_input) + '\n')
  with gfile.Open(os.path.join(FLAGS.output_data_dir, 'metadata.tsv'),
                  'w') as outf:
    metadata.to_csv(outf, sep='\t')
def test_copy(self):
  """Test copy."""
  # Setup and check preconditions.
  src_file_name = "igfs:///test_copy/1"
  dst_file_name = "igfs:///test_copy/2"
  self.assertFalse(gfile.Exists(src_file_name))
  self.assertFalse(gfile.Exists(dst_file_name))
  with gfile.Open(src_file_name, mode="w") as w:
    w.write("42")
  self.assertTrue(gfile.Exists(src_file_name))
  self.assertFalse(gfile.Exists(dst_file_name))
  # Copy file.
  gfile.Copy(src_file_name, dst_file_name)
  # Check that files are identical.
  self.assertTrue(gfile.Exists(src_file_name))
  self.assertTrue(gfile.Exists(dst_file_name))
  with gfile.Open(dst_file_name, mode="r") as r:
    data_v = r.read()
  self.assertEqual("42", data_v)
  # Remove files.
  gfile.Remove(src_file_name)
  gfile.Remove(dst_file_name)
  # Check that files were removed.
  self.assertFalse(gfile.Exists(src_file_name))
  self.assertFalse(gfile.Exists(dst_file_name))
def main(unused_argv):
  tokenizer = FullTokenizer(FLAGS.tokenizer_vocabulary)

  print('Loading ' + str(FLAGS.dataset_name) + ' dataset from ' +
        FLAGS.input_filepath)

  # The debugging file saves all of the processed SQL queries.
  debugging_file = gfile.Open(
      os.path.join('/'.join(FLAGS.output_filepath.split('/')[:-1]),
                   FLAGS.dataset_name + '_'.join(FLAGS.splits) + '_gold.txt'),
      'w')

  # The output file will save a sequence of string-serialized JSON objects,
  # one line per object.
  output_file = gfile.Open(os.path.join(FLAGS.output_filepath), 'w')

  if FLAGS.dataset_name.lower() == 'spider':
    num_examples_created, num_examples_failed = process_spider(
        output_file, debugging_file, tokenizer)
  elif FLAGS.dataset_name.lower() == 'wikisql':
    num_examples_created, num_examples_failed = process_wikisql(
        output_file, debugging_file, tokenizer)
  else:
    num_examples_created, num_examples_failed = process_michigan_datasets(
        output_file, debugging_file, tokenizer)

  print('Wrote %s examples, could not annotate %s examples.' %
        (num_examples_created, num_examples_failed))
  debugging_file.write('Wrote %s examples, could not annotate %s examples.' %
                       (num_examples_created, num_examples_failed))
  debugging_file.close()
  output_file.close()
def __init__(self, vocab_file, embedding_file, normalize_embeddings=True):
  with gfile.Open(embedding_file, 'rb') as f:
    self.embedding_mat = np.load(f)
  if normalize_embeddings:
    self.embedding_mat = self.embedding_mat / np.linalg.norm(
        self.embedding_mat, axis=1, keepdims=True)
  with gfile.Open(vocab_file, 'r') as f:
    tks = json.load(f)
  self.vocab = dict(zip(tks, range(len(tks))))
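
# Hedged illustration (the enclosing class and its name are not shown above,
# so `embeddings` below stands for an instance of it): after __init__,
# self.vocab maps each token to its row index in self.embedding_mat, so a
# token vector is a simple row lookup.
def _example_embedding_lookup(embeddings, token):
  return embeddings.embedding_mat[embeddings.vocab[token]]
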
def write_self_play(data, output_file_data, output_file_kb):
  f_data = gfile.Open(output_file_data, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    f_kb.write(flatten_json(entry['kb']) + '\n')
    # Intent and expected action are both needed.
    new_arr = [entry['intent'], entry['expected_action']]
    f_data.write('|'.join(new_arr) + '\n')
  f_data.close()
  f_kb.close()
def load_pickle(pickle_file):
  """Load a pickle file (py2/3 compatible)."""
  try:
    with gfile.Open(pickle_file, 'rb') as f:
      pickle_data = pickle.load(f)
  except UnicodeDecodeError:
    # Pickles written by Python 2 may need an explicit encoding under Python 3.
    with gfile.Open(pickle_file, 'rb') as f:
      pickle_data = pickle.load(f, encoding='latin1')
  except Exception as e:  # pylint: disable=broad-except
    print('Unable to load {}: {}'.format(pickle_file, e))
    raise
  return pickle_data
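
# Hedged usage sketch (not part of the original module): load_pickle's latin-1
# fallback lets Python 3 read pickles written by Python 2. The path below is
# hypothetical.
def _example_load_pickle_usage():
  stats = load_pickle('/tmp/legacy_py2_stats.pkl')  # hypothetical path
  print(type(stats))
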
def collect_programs():
  saved_programs = {}
  for i in range(FLAGS.id_start, FLAGS.id_end):
    with gfile.Open(
        os.path.join(
            get_experiment_dir(),
            'program_shard_{}-{}.json'.format(i, FLAGS.n_epoch)), 'r') as f:
      program_shard = json.load(f)
    saved_programs.update(program_shard)
  saved_program_path = os.path.join(get_experiment_dir(),
                                    'saved_programs.json')
  with gfile.Open(saved_program_path, 'w') as f:
    json.dump(saved_programs, f)
  print('saved programs are aggregated in {}'.format(saved_program_path))
def get_nl_sql_pairs(filepath, splits, with_dbs=False):
  """Gets pairs of natural language and corresponding gold SQL for Michigan."""
  with gfile.Open(filepath) as infile:
    data = json.load(infile)

  pairs = list()

  tag = '[' + filepath.split('/')[-1].split('.')[0] + ']'
  print('Getting examples with tag ' + tag)

  # The UMichigan data is split by anonymized queries, where values are
  # anonymized but table/column names are not. However, our experiments are
  # performed on the original splits of the data.
  for query in data:
    # Take the first SQL query only. From their Github documentation:
    # "Note - we only use the first query, but retain the variants for
    # completeness"
    anonymized_sql = query['sql'][0]

    # It's also associated with a number of natural language examples, which
    # also contain anonymous tokens. Save the de-anonymized utterance and
    # query.
    for example in query['sentences']:
      if example['question-split'] not in splits:
        continue

      nl = example['text']
      sql = anonymized_sql

      # Go through the anonymized values and replace them in both the natural
      # language and the SQL.
      #
      # It's very important to sort these in descending order. If one is a
      # substring of the other, it shouldn't be replaced first lest it ruin
      # the replacement of the superstring.
      for variable_name, value in sorted(
          example['variables'].items(), key=lambda x: len(x[0]), reverse=True):
        if not value:
          # TODO(alanesuhr): While the Michigan repo says to use a - here, the
          # thing that works is using a % and replacing = with LIKE.
          #
          # It's possible that I should remove such clauses from the SQL, as
          # long as they lead to the same table result. They don't align well
          # to the natural language at least.
          #
          # See: https://github.com/jkkummerfeld/text2sql-data/tree/master/data
          value = '%'
        nl = nl.replace(variable_name, value)
        sql = sql.replace(variable_name, value)

      # In the case that we replaced an empty anonymized value with %, make it
      # compilable by allowing equality with any string.
      sql = sql.replace('= "%"', 'LIKE "%"')

      if with_dbs:
        pairs.append((nl, sql, example['table-id']))
      else:
        pairs.append((nl, sql))

  return pairs
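
# Hedged usage sketch for get_nl_sql_pairs (the file path and split name are
# assumptions for illustration): the function returns de-anonymized
# (natural language, SQL) pairs restricted to the requested question splits.
def _example_get_nl_sql_pairs_usage():
  pairs = get_nl_sql_pairs('/path/to/michigan_dataset.json', splits={'test'})
  for nl, sql in pairs[:3]:
    print(nl, '->', sql)
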
def main(argv):
  del argv  # unused.
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks.get_hparams()

  environment = DockingRewardMolecule(
      discount_factor=hparams.discount_factor,
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=hparams.allow_bonds_between_rings,
      allowed_ring_sizes=set(hparams.allowed_ring_sizes),
      max_steps=hparams.max_steps_per_episode)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(deep_q_networks.multi_layer_model,
                             hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
def _inner_iter(self, fpath):
  with gfile.Open(fpath, 'r') as fh:
    rest_buffer = []
    aware_headers = True
    read_finished = False
    while not read_finished:
      dict_reader, rest_buffer, read_finished = \
          self._make_csv_dict_reader(fh, rest_buffer, aware_headers)
      aware_headers = False
      if self._headers is None:
        self._headers = dict_reader.fieldnames
      elif self._headers != dict_reader.fieldnames:
        logging.fatal("the schema of %s is %s, mismatch "
                      "with previous %s", fpath, self._headers,
                      dict_reader.fieldnames)
        traceback.print_stack()
        os._exit(-1)  # pylint: disable=protected-access
      # Check for invalid characters in the headers.
      self._validator.check_csv_header(self._headers)
      for raw in dict_reader:
        if not self._validator.check_csv_record(raw, len(self._headers)):
          continue
        yield CsvItem(raw)
def load_wikisql_tables(filepath):
  """Loads the WikiSQL tables from a path and reformats them as the expected format."""
  dbs = dict()
  with gfile.Open(filepath) as infile:
    tables = [json.loads(line) for line in infile if line]

  for table in tables:
    db_dict = dict()

    # Prefer the section title as the table name; fall back to the name and
    # then the page title.
    if 'section_title' in table and table['section_title']:
      table_name = table['section_title']
    elif 'name' in table:
      table_name = table['name']
    else:
      table_name = table['page_title']
    table_name = normalize_entities(table_name)

    db_dict[table_name] = list()
    for column_name, column_type in zip(table['header'], table['types']):
      if column_type == 'real':
        column_type = 'number'
      assert column_type in {'text', 'number'}, column_type

      column_name = normalize_entities(column_name)

      db_dict[table_name].append({
          'field name': column_name,
          'is primary key': False,
          'is foreign key': False,
          'type': column_type
      })
    if table['id'] not in dbs:
      dbs[table['id']] = db_dict
  return dbs
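
# Hedged usage sketch for load_wikisql_tables (the path is hypothetical): the
# returned dict is keyed by WikiSQL table id, and each value maps a normalized
# table name to its list of column dicts.
def _example_load_wikisql_tables_usage():
  dbs = load_wikisql_tables('/path/to/wikisql_tables.jsonl')
  for table_id, db_dict in list(dbs.items())[:1]:
    for table_name, columns in db_dict.items():
      print(table_id, table_name, [col['field name'] for col in columns])
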
def get_optimized_mols(model_dir, ckpt=80000):
  """Get optimized Molecules.

  Args:
    model_dir: String. Model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules.
  """
  hparams_file = os.path.join(model_dir, 'config.json')
  with gfile.Open(hparams_file, 'r') as f:
    hp_dict = json.load(f)
    hparams = deep_q_networks.get_hparams(**hp_dict)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(deep_q_networks.multi_layer_model,
                             hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)

  tf.reset_default_graph()
  optimized_mol = []
  with tf.Session() as sess:
    dqn.build()
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    for mol in all_mols:
      logging.info('Eval: %s', mol)
      environment = molecules_mdp.Molecule(
          atom_types=set(hparams.atom_types),
          init_mol=mol,
          allow_removal=hparams.allow_removal,
          allow_no_modification=hparams.allow_no_modification,
          allow_bonds_between_rings=hparams.allow_bonds_between_rings,
          allowed_ring_sizes=set(hparams.allowed_ring_sizes),
          max_steps=hparams.max_steps_per_episode,
          record_path=True)
      environment.initialize()
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      for _ in range(hparams.max_steps_per_episode):
        steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
        valid_actions = list(environment.get_valid_actions())
        observations = np.vstack([
            np.append(deep_q_networks.get_fingerprint(act, hparams),
                      steps_left) for act in valid_actions
        ])
        action = valid_actions[dqn.get_action(
            observations, head=head, update_epsilon=0.0)]
        environment.step(action)
      optimized_mol.append(environment.get_path())
  return optimized_mol
def __init__(self, latent_factor_indices=None):
  DSprites.__init__(self, latent_factor_indices)
  self.data_shape = [64, 64, 3]
  with gfile.Open(SCREAM_PATH, "rb") as f:
    scream = PIL.Image.open(f)
    # PIL's thumbnail expects a (width, height) size tuple.
    scream.thumbnail((350, 274))
    self.scream = np.array(scream) * 1. / 255.
def setUp(self) -> None:
  logging.getLogger().setLevel(logging.DEBUG)
  self._data_portal_name = 'test_data_portal_job_manager'
  self._kvstore = DBClient('etcd', True)
  self._portal_input_base_dir = './portal_input_dir'
  self._portal_output_base_dir = './portal_output_dir'
  self._raw_data_publish_dir = 'raw_data_publish_dir'
  if gfile.Exists(self._portal_input_base_dir):
    gfile.DeleteRecursively(self._portal_input_base_dir)
  gfile.MakeDirs(self._portal_input_base_dir)

  self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
  self._data_fnames_without_success = \
      ['1002/{}.data'.format(i) for i in range(100)]
  self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
  self._unused_fnames = ['{}.xx'.format(100)]
  self._all_fnames = self._data_fnames + \
      self._data_fnames_without_success + \
      self._csv_fnames + self._unused_fnames

  all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] + \
      self._all_fnames
  for fname in all_fnames_with_success:
    fpath = os.path.join(self._portal_input_base_dir, fname)
    gfile.MakeDirs(os.path.dirname(fpath))
    with gfile.Open(fpath, "w") as f:
      f.write('xxx')
def test_list_directory(self):
  """Test list directory."""
  # Setup and check preconditions.
  gfile.MkDir(self.prefix() + ":///test_list_directory")
  gfile.MkDir(self.prefix() + ":///test_list_directory/2")
  gfile.MkDir(self.prefix() + ":///test_list_directory/4")
  dir_name = self.prefix() + ":///test_list_directory"
  file_names = [
      self.prefix() + ":///test_list_directory/1",
      self.prefix() + ":///test_list_directory/2/3"
  ]
  ch_dir_names = [
      self.prefix() + ":///test_list_directory/4",
  ]
  for file_name in file_names:
    with gfile.Open(file_name, mode="w") as w:
      w.write("")
  for ch_dir_name in ch_dir_names:
    gfile.MkDir(ch_dir_name)
  ls_expected_result = file_names + ch_dir_names
  # Get list of files in directory.
  ls_result = gfile.ListDirectory(dir_name)
  # Check that list of files is correct.
  self.assertEqual(len(ls_expected_result), len(ls_result))
  for e in ["1", "2", "4"]:
    self.assertTrue(e in ls_result, msg="Result doesn't contain '%s'" % e)
def main(unused_argv):
  # Load the examples.
  vocabulary = set()
  valid_filenames = [
      filename for filename in FLAGS.input_filenames if filename
  ]
  for filename in valid_filenames:
    with open(os.path.join(FLAGS.data_dir, filename)) as infile:
      for line in infile:
        if not line:
          continue
        symbols = get_symbol(line)
        new_symbols = [
            symbol for symbol in symbols if symbol not in vocabulary
        ]
        if new_symbols:
          print(new_symbols)
        for symbol in symbols:
          vocabulary.add(symbol)

  print("Writing vocabulary of size %d to %s" %
        (len(vocabulary), FLAGS.output_path))
  with gfile.Open(FLAGS.output_path, "w") as ofile:
    ofile.write("\n".join(list(vocabulary)))
def main(argv):
  del argv
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks.get_hparams()
  hparams.override_from_dict(
      {'max_steps_per_episode': FLAGS.max_steps_per_episode})

  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      max_steps=hparams.max_steps_per_episode)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length),
      q_fn=functools.partial(deep_q_networks.multi_layer_model,
                             hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
def get_prefix_kvs(self, prefix, ignore_prefix=False):
  kvs = []
  target_path = self._generate_path(prefix, with_meta=False)
  cur_paths = [target_path]
  children_paths = []
  while cur_paths:
    for path in cur_paths:
      filenames = []
      try:
        if gfile.IsDirectory(path):
          filenames = gfile.ListDirectory(path)
      except Exception as e:  # pylint: disable=broad-except
        logging.warning("get prefix kvs %s failed, reason: %s", path, str(e))
        break
      for filename in sorted(filenames):
        file_path = "/".join([path, filename])
        if gfile.IsDirectory(file_path):
          children_paths.append(file_path)
        else:
          if ignore_prefix and path == target_path:
            continue
          nkey = self.normalize_output_key(path, self._base_dir).encode()
          with gfile.Open(file_path, 'rb') as file:
            kvs.append((nkey, file.read()))
    cur_paths = children_paths
    children_paths = []
  return kvs
def main(unused_argv):
  metadata = pd.read_csv(gfile.Open(FLAGS.metadata_file), sep='\t')

  # Read DeepMass:Prism outputs and merge with metadata.
  outputs = []
  for filen in gfile.Glob(FLAGS.input_data_pattern):
    with gfile.Open(filen) as infile:
      if FLAGS.batch_prediction:
        out_df = pd.read_json(infile, lines=True)
      else:
        out_df = json.load(infile)
        out_df = pd.DataFrame(out_df['predictions'])
    out_df = out_df.merge(metadata, left_on='key', right_on='index',
                          how='left')
    outputs.append(out_df)
  outputs = pd.concat(outputs)
  outputs = outputs.apply(
      reformat_outputs,
      args=(int(FLAGS.label_dim), FLAGS.neutral_losses),
      axis=1)

  # Read additional features.
  if FLAGS.add_feature_names is not None:
    outputs_drip = []
    for filen in gfile.Glob(FLAGS.add_input_data_pattern):
      with gfile.Open(filen) as infile:
        if FLAGS.batch_prediction:
          out_df = pd.read_json(infile, lines=True)
          out_df = pd.DataFrame(
              out_df['outputs'].tolist(),
              columns=FLAGS.add_feature_names,
              index=out_df['key'])
        else:
          pass
      outputs_drip.append(out_df)
    outputs_drip = pd.concat(outputs_drip)
    outputs = outputs.merge(
        outputs_drip, how='left', left_on='key', right_index=True)

  # Write to a file.
  with gfile.Open(os.path.join(FLAGS.output_data_dir, 'outputs.tsv'),
                  'w') as outf:
    outputs.to_csv(outf, sep='\t', index=False)
def main(FLAGS):
  if FLAGS.verbose:
    print("Number of samples to generate: ", FLAGS.num_samples)
    print("Output_kb: ", FLAGS.output_kb)
    print("Output_data: ", FLAGS.output_data)
  num_samples = FLAGS.num_samples
  cg = context_generator_lib.ContextGenerator(
      num_candidate_airports=FLAGS.num_candidate_airports,
      book_window=FLAGS.book_window,
      num_db_record=FLAGS.num_db_record,
      firstname_file=FLAGS.firstname_file,
      lastname_file=FLAGS.lastname_file,
      airportcode_file=FLAGS.airportcode_file)
  inter = interaction.Interaction(
      cg.fact_obj,
      skip_greeting=0,
      fix_response_candidate=True,
      first_ask_prob=0,
      random_respond_error=True)
  ct, stats = cg.generate_context(num_samples, output_object=True)
  if FLAGS.verbose:
    print(stats)
  with gfile.Open(FLAGS.output_data, "w") as f_data, gfile.Open(
      FLAGS.output_kb, "w") as f_kb:
    for i in range(len(ct)):
      if FLAGS.verbose and i % 5000 == 0:
        print((i, "/", len(ct)))
      cus, kb, expected_action = ct[i]
      # The action has already been standardized inside the interaction object.
      utterance, action, _ = inter.generate_dialogue(cus, kb)
      standarlized_intent = utils.standardize_intent(cus.get_json())
      standarlized_action = utils.standardize_action(action)
      standarlized_expected_action = utils.standardize_action(expected_action)
      # Synthesized data is 100% correct. However, `action` contains at most
      # one flight, while `expected_action` may contain more than one flight.
      f_data.write(
          json.dumps({
              "intent": standarlized_intent,
              "dialogue": utterance,
              "action": standarlized_action,
              "expected_action": standarlized_expected_action
          }) + "\n")
      dumped_kb = kb.get_json()
      f_kb.write(json.dumps(dumped_kb) + "\n")
def write_hparams(hparams, filename):
  """Writes HParams to disk as JSON.

  Args:
    hparams: HParams.
    filename: String output filename.
  """
  with gfile.Open(filename, 'w') as f:
    f.write(hparams.to_json(indent=2, sort_keys=True, separators=(',', ': ')))
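
# Hedged usage sketch for write_hparams (not from the original module): it only
# requires an object with a to_json method, such as tf.contrib.training.HParams
# in TF 1.x; the values and output path below are made up for illustration.
def _example_write_hparams_usage():
  hparams = tf.contrib.training.HParams(learning_rate=1e-3, batch_size=32)
  write_hparams(hparams, '/tmp/config.json')  # hypothetical path
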
def load_origins(segmentation_dir, corner):
  target_path = get_existing_subvolume_path(segmentation_dir, corner, False)
  if target_path is None:
    raise ValueError('Segmentation not found: %s, %s' %
                     (segmentation_dir, corner))
  with gfile.Open(target_path, 'rb') as f:
    data = np.load(f)
    return data['origins'].item()
def save_flags():
  gfile.MakeDirs(FLAGS.train_dir)
  with gfile.Open(os.path.join(FLAGS.train_dir, 'flags.%d' % time.time()),
                  'w') as f:
    for mod, flag_list in FLAGS.flags_by_module_dict().items():
      if (mod.startswith('google3.research.neuromancer.tensorflow') or
          mod.startswith('/')):
        for flag in flag_list:
          f.write('%s\n' % flag.serialize())
def main(unused_argv):
  # Load the examples.
  vocabulary = set()
  for filename in FLAGS.input_filenames:
    if filename:
      with gfile.Open(os.path.join(FLAGS.data_dir, filename)) as infile:
        for line in infile:
          if line:
            gold_query = NLToSQLExample().from_json(
                json.loads(line)).gold_sql_query
            for token in gold_query.actions:
              if token.symbol:
                vocabulary.add(token.symbol)

  print('Writing vocabulary of size %d to %s' %
        (len(vocabulary), FLAGS.output_path))
  with gfile.Open(FLAGS.output_path, 'w') as ofile:
    ofile.write('\n'.join(list(vocabulary)))
def write_infer_json(data, kb, output_file_src, output_file_tgt,
                     output_file_kb):
  """Writes both the kb and the main data to the inference files."""
  f_src = gfile.Open(output_file_src, 'w')
  f_tgt = gfile.Open(output_file_tgt, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry, entry_kb in zip(data, kb):
    entire_dialogue = entry['dialogue'][:]
    for target_turn in range(1, len(entire_dialogue)):
      f_kb.write(json_dump(entry_kb) + '\n')
      entry['dialogue'] = entire_dialogue[0:target_turn]
      f_src.write(json_dump(entry) + '\n')
      f_tgt.write(json_dump({'response': entire_dialogue[target_turn]}) + '\n')
  f_src.close()
  f_tgt.close()
  f_kb.close()
def __init__(self):
  with gfile.Open(CUSTOMDATA_PATH, "rb") as f:
    # Load the data.
    data = np.load(file=f, allow_pickle=True)
    self.images = data["imgs"]
    # The first dimension is the dataset size; data_shape must be a list
    # because gaussian_encoder_model.py requires it.
    self.data_shape = list(self.images.shape[1:])
    self.factor_sizes = data["factor_sizes"]
    self.latent_factor_indices = list(range(len(self.factor_sizes)))
    self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
        self.factor_sizes)
    self.state_space = util.SplitDiscreteStateSpace(self.factor_sizes,
                                                    self.latent_factor_indices)
def write_data(data, output_file_data, output_file_kb, alt_infer=False):
  """This function writes the data into a text file."""
  f_data = gfile.Open(output_file_data, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    f_kb.write(flatten_json(entry['kb']) + '\n')
    new_arr = []
    if alt_infer:
      new_arr = [entry['intent'], entry['dialogue'].replace('<eod> ', '')]
    else:
      # Only boundaries1 is used; boundaries2 is not necessary here.
      new_arr = [
          entry['intent'], entry['action'], entry['dialogue'],
          entry['boundaries1']
      ]
    f_data.write('|'.join(new_arr) + '\n')
  f_data.close()
  f_kb.close()
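
# Hedged illustration of the record layout produced by write_data above (the
# field names come from the function; the ordering is what '|'.join produces):
#
#   default mode:   <intent>|<action>|<dialogue>|<boundaries1>
#   alt_infer mode: <intent>|<dialogue with '<eod> ' removed>
#
# Every entry also writes one flattened kb line to the kb file.
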
def _load_mesh(filename):
  """Parses a single source file and rescales the contained images."""
  with gfile.Open(os.path.join(CARS3D_PATH, filename), "rb") as f:
    mesh = np.einsum("abcde->deabc", sio.loadmat(f)["im"])
  flattened_mesh = mesh.reshape((-1,) + mesh.shape[2:])
  rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3))
  for i in range(flattened_mesh.shape[0]):
    pic = PIL.Image.fromarray(flattened_mesh[i, :, :, :])
    # PIL's thumbnail expects a (width, height) size tuple.
    pic.thumbnail((64, 64), PIL.Image.ANTIALIAS)
    rescaled_mesh[i, :, :, :] = np.array(pic)
  return rescaled_mesh * 1. / 255