def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.json_prediction_files_pattern:
    raise ValueError(
        "The flag --json_prediction_files_pattern must be specified.")

  if not FLAGS.csv_output_file:
    raise ValueError("The flag --csv_output_file must be specified.")

  logging.info(
      "Looking for prediction files with pattern: %s",
      FLAGS.json_prediction_files_pattern,
  )

  file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
  logging.info("Found files: %s", file_paths)

  logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
  with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
    output_file.write(get_csv_header())

    for file_path in file_paths:
      logging.info("processing file: %s", file_path)

      with gfile.Open(file_path) as input_file:
        for line in input_file:
          json_data = json.loads(line)
          output_file.write(to_csv_row(json_data))

    output_file.flush()

  logging.info("done")
def set_latest_checkpoint(dirname: str, chkpt: str):
  """Set the latest checkpoint in the checkpoint file.

  Args:
    dirname: Directory in which the checkpoint is located.
    chkpt: Checkpoint prefix.
  """
  chkpt_file = os.path.join(dirname, 'checkpoint')
  lines = []
  if gfile.Exists(chkpt_file):
    logging.info('Loading preexisting checkpoint file "%s"', chkpt_file)
    with gfile.Open(chkpt_file) as f:
      # Keep only the existing 'all_model_checkpoint_paths:' entries.
      lines = [
          l.strip()
          for l in f.readlines()
          if l.startswith('all_model_checkpoint_paths:')
      ]
  else:
    logging.info('No preexisting checkpoint file "%s"', chkpt_file)

  with gfile.Open(chkpt_file, 'w') as f:
    lines = [
        '%s\n' % l.strip() for l in ([
            'model_checkpoint_path: "%s"' % chkpt,
            'all_model_checkpoint_paths: "%s"' % chkpt
        ] + lines)
    ]
    f.writelines(lines)
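# A minimal usage sketch for set_latest_checkpoint() above; the directory and
# checkpoint prefix are hypothetical. After the call, the TensorFlow
# 'checkpoint' file in that directory points at the given prefix, e.g.:
#   model_checkpoint_path: "model.ckpt-1234"
#   all_model_checkpoint_paths: "model.ckpt-1234"
# so tf.train.latest_checkpoint('/tmp/eval_ckpt') resolves to it.
set_latest_checkpoint('/tmp/eval_ckpt', 'model.ckpt-1234')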
def main(unused_argv):
  tokenizer = FullTokenizer(FLAGS.tokenizer_vocabulary)

  print('Loading ' + str(FLAGS.dataset_name) + ' dataset from ' +
        FLAGS.input_filepath)

  # The debugging file saves all of the processed SQL queries.
  debugging_file = gfile.Open(
      os.path.join('/'.join(FLAGS.output_filepath.split('/')[:-1]),
                   FLAGS.dataset_name + '_'.join(FLAGS.splits) + '_gold.txt'),
      'w')

  # The output file will save a sequence of string-serialized JSON objects, one
  # line per object.
  output_file = gfile.Open(os.path.join(FLAGS.output_filepath), 'w')

  if FLAGS.dataset_name.lower() == 'spider':
    num_examples_created, num_examples_failed = process_spider(
        output_file, debugging_file, tokenizer)
  elif FLAGS.dataset_name.lower() == 'wikisql':
    num_examples_created, num_examples_failed = process_wikisql(
        output_file, debugging_file, tokenizer)
  else:
    num_examples_created, num_examples_failed = process_michigan_datasets(
        output_file, debugging_file, tokenizer)

  print('Wrote %s examples, could not annotate %s examples.' %
        (num_examples_created, num_examples_failed))
  debugging_file.write('Wrote %s examples, could not annotate %s examples.' %
                       (num_examples_created, num_examples_failed))
  debugging_file.close()
  output_file.close()
def main(unused_argv):
  # Get one-hot encoding.
  mol_weights = pd.Series(_MOL_WEIGHTS)
  alphabet = [k for k in mol_weights.keys() if not k.startswith(_GROUP)]
  alphabet = sorted(alphabet)
  one_hot_encoding = pd.get_dummies(alphabet).astype(int).to_dict(
      orient='list')

  with gfile.Open(FLAGS.input_data) as inputf:
    input_data = pd.read_csv(inputf, sep=',')
  input_data.rename(
      columns={
          FLAGS.sequence_col: _MOD_SEQUENCE,
          FLAGS.charge_col: _CHARGE,
          FLAGS.fragmentation_col: _FRAGMENTATION,
          FLAGS.analyzer_col: _MASS_ANALYZER
      },
      inplace=True)

  metadata, _ = preprocess_peptides(input_data, FLAGS.clean_peptides)
  metadata = metadata.reset_index()
  check_inputs(metadata, alphabet)

  json_inputs = generate_json_inputs(metadata, one_hot_encoding)
  with gfile.Open(os.path.join(FLAGS.output_data_dir, 'input.json'),
                  'w') as outf:
    for json_input in json_inputs:
      outf.write(json.dumps(json_input) + '\n')
  with gfile.Open(os.path.join(FLAGS.output_data_dir, 'metadata.tsv'),
                  'w') as outf:
    metadata.to_csv(outf, sep='\t')
def write_vocabulary(output_file, output_all_vocab_file, word_frequency,
                     frequency_cutoff, keep_non_ascii):
  special_chars = [unk_token, start_of_turn1, start_of_turn2, end_of_dialogue]
  new_word_frequency = set([])
  # If output_file is not None, we write to file; otherwise we don't.
  if output_file:
    f = gfile.Open(output_file, 'w')
  else:
    f = None
  # The special tokens come first; unk_token should always be the first entry.
  for special_char in special_chars:
    if f:
      f.write(special_char + '\n')
  for key in word_frequency:
    # We write to the vocabulary only when the key is not empty.
    # Otherwise tensorflow will complain.
    if word_frequency[key] >= frequency_cutoff and (
        key not in special_chars) and is_ascii(key, keep_non_ascii) and key:
      if f:
        f.write(key + '\n')
      new_word_frequency.add(key)
  if f:
    f.close()
  # Optionally dump the full vocabulary with frequencies.
  if output_all_vocab_file:
    with gfile.Open(output_all_vocab_file, 'w') as f2:
      f2.write(str(word_frequency))
  return new_word_frequency
def write_completion(data, output_file_data_src, output_file_data_tar,
                     output_file_kb):
  """This function writes both the kb and the main data into files."""
  f_data_src = gfile.Open(output_file_data_src, 'w')
  f_data_tar = gfile.Open(output_file_data_tar, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    bd1 = entry['boundaries1'].split(' ')
    bd2 = entry['boundaries2'].split(' ')
    start = bd1[0:len(bd1) // 2] + bd2[0:len(bd2) // 2]
    end = bd1[len(bd1) // 2:] + bd2[len(bd2) // 2:]
    # Instead of sampling a single random turn
    # (random_turn = random.randint(0, len(start) - 1)), iterate over all turns.
    for random_turn in range(len(start)):
      f_kb.write(flatten_json(entry['kb']) + '\n')
      turn_start = int(start[random_turn])
      turn_end = int(end[random_turn])
      dialogue_split = entry['dialogue'].split(' ')
      dialogue_src = dialogue_split[0:turn_start + 1]
      dialogue_tar = dialogue_split[turn_start + 1:turn_end + 1]
      src_arr = [entry['intent'], ' '.join(dialogue_src)]
      f_data_src.write('|'.join(src_arr) + '\n')
      tar_arr = [entry['action'], ' '.join(dialogue_tar)]
      f_data_tar.write('|'.join(tar_arr) + '\n')
  f_data_src.close()
  f_data_tar.close()
  f_kb.close()
def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
  with tf.Session() as sess, gfile.Open(out_file_location, "w+") as out_file, \
      gfile.Open(out_file_location + "2", "w+") as out_file2:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" %
                      train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(
        set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")
    out_file2.write("VideoId,Preds\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={
                input_tensor: video_batch_val,
                num_frames_tensor: num_frames_batch_val
            })
        now = time.time()
        num_examples_processed += len(video_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
          # Mirror the formatted predictions to the secondary output file.
          out_file2.write(line)
        out_file.flush()
        out_file2.flush()
    except tf.errors.OutOfRangeError:
      logging.info('Done with inference. The output file was written to ' +
                   out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
    sess.close()
def __init__(self, vocab_file, embedding_file, normalize_embeddings=True):
  with gfile.Open(embedding_file, 'rb') as f:
    self.embedding_mat = np.load(f)
  if normalize_embeddings:
    self.embedding_mat = self.embedding_mat / np.linalg.norm(
        self.embedding_mat, axis=1, keepdims=True)
  with gfile.Open(vocab_file, 'r') as f:
    tks = json.load(f)
  self.vocab = dict(zip(tks, range(len(tks))))
def write_self_play(data, output_file_data, output_file_kb):
  f_data = gfile.Open(output_file_data, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    f_kb.write(flatten_json(entry['kb']) + '\n')
    # Intent and expected action are both needed.
    new_arr = [entry['intent'], entry['expected_action']]
    f_data.write('|'.join(new_arr) + '\n')
  f_data.close()
  f_kb.close()
def load_pickle(pickle_file):
  """Load a pickle file (py2/3 compatible)."""
  try:
    with gfile.Open(pickle_file, 'rb') as f:
      pickle_data = pickle.load(f)
  except UnicodeDecodeError:
    # Pickles written by Python 2 may need a latin1 encoding to load under
    # Python 3.
    with gfile.Open(pickle_file, 'rb') as f:
      pickle_data = pickle.load(f, encoding='latin1')
  except Exception as e:
    print('Unable to load {}: {}'.format(pickle_file, e))
    raise
  return pickle_data
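# Hypothetical usage sketch for load_pickle() above; the path is made up.
# Pickles written by Python 2 are handled by the latin1 fallback branch.
train_stats = load_pickle('/tmp/train_stats.pkl')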
def collect_programs():
  saved_programs = {}
  for i in xrange(FLAGS.id_start, FLAGS.id_end):
    with gfile.Open(
        os.path.join(
            get_experiment_dir(),
            'program_shard_{}-{}.json'.format(i, FLAGS.n_epoch)), 'r') as f:
      program_shard = json.load(f)
    saved_programs.update(program_shard)
  saved_program_path = os.path.join(get_experiment_dir(),
                                    'saved_programs.json')
  with gfile.Open(saved_program_path, 'w') as f:
    json.dump(saved_programs, f)
  print 'saved programs are aggregated in {}'.format(saved_program_path)
def write_data(data, output_file_data, output_file_kb):
  """This function writes data into a text file."""
  f_data = gfile.Open(output_file_data, 'w')
  f_kb = gfile.Open(output_file_kb, 'w')
  for entry in data:
    f_kb.write(flatten_json(entry['kb']) + '\n')
    # Only boundaries1 is used; boundaries2 is not necessary here.
    new_arr = [
        entry['intent'], entry['action'], entry['dialogue'],
        entry['boundaries1']
    ]
    f_data.write('|'.join(new_arr) + '\n')
  f_data.close()
  f_kb.close()
def next_pairs_array(self):
  arr = numpy.load(gfile.Open(self.train_npy_files[self.next_idx]))
  # Shuffle the rows of the loaded array.
  indices = list(range(len(arr)))
  random.shuffle(indices)
  arr = arr[indices]
  self.next_idx = (self.next_idx + 1) % len(self.train_npy_files)
  return arr
def load_wikisql_tables(filepath):
  """Loads the WikiSQL tables from a path and reformats them as schema dicts."""
  dbs = dict()
  with gfile.Open(filepath) as infile:
    tables = [json.loads(line) for line in infile if line]

  for table in tables:
    db_dict = dict()
    # Prefer the section title as the table name, then fall back to the table
    # name and finally the page title.
    if table.get('section_title'):
      table_name = table['section_title']
    elif 'name' in table:
      table_name = table['name']
    else:
      table_name = table['page_title']
    table_name = normalize_entities(table_name)

    db_dict[table_name] = list()
    for column_name, column_type in zip(table['header'], table['types']):
      if column_type == 'real':
        column_type = 'number'
      assert column_type in {'text', 'number'}, column_type

      column_name = normalize_entities(column_name)

      db_dict[table_name].append({
          'field name': column_name,
          'is primary key': False,
          'is foreign key': False,
          'type': column_type
      })
    if table['id'] not in dbs:
      dbs[table['id']] = db_dict
  return dbs
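# Hypothetical usage sketch for load_wikisql_tables() above. The table below is
# made up; the exact keys in the result also depend on normalize_entities(),
# which is assumed to be in scope along with gfile and json.
import json
import tempfile

_demo_table = {
    'id': '1-0001',
    'section_title': 'Results',
    'header': ['Player', 'Score'],
    'types': ['text', 'real'],
}
with tempfile.NamedTemporaryFile('w', suffix='.jsonl', delete=False) as tmp:
  tmp.write(json.dumps(_demo_table) + '\n')
  demo_path = tmp.name
# demo_dbs maps table id -> {table name -> list of column dicts}; the 'real'
# column type is mapped to 'number'.
demo_dbs = load_wikisql_tables(demo_path)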
def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)

  paths = gfile.Glob(FLAGS.input_data_pattern)
  logging.info("Found %s files.", len(paths))

  for path in paths:
    with gfile.Open(path, "r") as f:
      first_read = True
      while True:
        length_raw = f.read(8)
        if not length_raw and first_read:
          logging.fatal("File %s has no data.", path)
          break
        elif not length_raw:
          logging.info("File %s looks good.", path)
          break
        else:
          first_read = False
        if len(length_raw) != 8:
          logging.fatal("File ends when reading record length: " + path)
          break
        length, = struct.unpack("L", length_raw)
        # +8 to include the crc values.
        record = f.read(length + 8)
        if len(record) != length + 8:
          logging.fatal("File ends in the middle of a record: " + path)
          break
def get_optimized_mols(model_dir, ckpt=80000):
  """Get optimized Molecules.

  Args:
    model_dir: String. model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules
  """
  hparams_file = os.path.join(model_dir, 'config.json')
  with gfile.Open(hparams_file, 'r') as f:
    hp_dict = json.load(f)
    hparams = deep_q_networks.get_hparams(**hp_dict)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)

  tf.reset_default_graph()
  optimized_mol = []
  with tf.Session() as sess:
    dqn.build()
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    for mol in all_mols:
      logging.info('Eval: %s', mol)
      environment = molecules_mdp.Molecule(
          atom_types=set(hparams.atom_types),
          init_mol=mol,
          allow_removal=hparams.allow_removal,
          allow_no_modification=hparams.allow_no_modification,
          allow_bonds_between_rings=hparams.allow_bonds_between_rings,
          allowed_ring_sizes=set(hparams.allowed_ring_sizes),
          max_steps=hparams.max_steps_per_episode,
          record_path=True)
      environment.initialize()
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      for _ in range(hparams.max_steps_per_episode):
        steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
        valid_actions = list(environment.get_valid_actions())
        observations = np.vstack([
            np.append(
                deep_q_networks.get_fingerprint(act, hparams), steps_left)
            for act in valid_actions
        ])
        action = valid_actions[dqn.get_action(
            observations, head=head, update_epsilon=0.0)]
        environment.step(action)
      optimized_mol.append(environment.get_path())
  return optimized_mol
def write_batch_as_jpg(batch, filename, pad=1):
  # Pad the images so the boundaries between grid cells are more visible.
  batch = tf.pad(
      batch,
      tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]),
      constant_values=1.)
  batch_size = len(batch)
  grid_shape = int(math.sqrt(batch_size)), int(math.sqrt(batch_size))
  image_shape = batch.shape[1:3]
  num_channels = batch.shape[-1]
  batch = eval.image_grid(
      batch,
      grid_shape=grid_shape,
      image_shape=image_shape,
      num_channels=num_channels)
  batch = tf.image.convert_image_dtype(batch, tf.uint8)
  batch = tf.squeeze(batch, 0)
  with gfile.Open(filename, 'wb+') as f:
    encoded_batch = tf.io.encode_jpeg(batch, name='batch')
    f.write(encoded_batch.numpy())
def download(uri, dst_dir):
  """Download the given URI.

  Args:
    uri: URI to copy (or download) from.
    dst_dir: path to the directory that will be used.

  Returns:
    The path to the downloaded file.
  """
  # Download the URI.
  # Should use a context manager with Py3
  # (with urllib.request.urlopen(uri) as response).
  response = urllib.request.urlopen(uri)
  filename = response.geturl().split('/')[-1]
  incomplete_path = os.path.join(dst_dir, '{}.incomplete'.format(filename))
  dst_path = os.path.join(dst_dir, filename)
  # TODO(epot): Could add a shared tqdm instance across parallel download
  # to display a single shared progression bar.
  # TODO(b/119663674): Add Google Drive support (cf Ryan code)
  with gfile.Open(incomplete_path, 'wb') as f:
    f.write(response.read())
  gfile.Rename(incomplete_path, dst_path)
  return dst_path
def __init__(self, latent_factor_indices=None):
  DSprites.__init__(self, latent_factor_indices)
  self.data_shape = [64, 64, 3]
  with gfile.Open(SCREAM_PATH, "rb") as f:
    scream = PIL.Image.open(f)
    scream.thumbnail((350, 274, 3))
    self.scream = np.array(scream) * 1. / 255.
def _sync_download(self, url, destination_path):
  """Synchronous version of `download` method."""
  checksum = self._checksumer()
  session = requests.Session()
  if _DRIVE_URL.match(url):
    url = self._get_drive_url(url, session)
  response = session.get(url, stream=True)
  if response.status_code != 200:
    raise DownloadError('Failed to get url %s. HTTP code: %d.' %
                        (url, response.status_code))
  fname = _get_filename(response)
  path = os.path.join(destination_path, fname)
  size = 0

  size_mb = 0
  unit_mb = units.MiB
  self._pbar_dl_size.update_total(
      int(response.headers.get('Content-length', 0)) // unit_mb)
  with gfile.Open(path, 'wb') as file_:
    for block in response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE):
      size += len(block)

      # Update the progress bar.
      size_mb += len(block)
      if size_mb > unit_mb:
        self._pbar_dl_size.update(size_mb // unit_mb)
        size_mb %= unit_mb

      checksum.update(block)
      # TODO(pierrot): Test this is faster than doing checksum in the end
      # and document results here.
      file_.write(block)
  self._pbar_url.update(1)
  return checksum.hexdigest(), size
def dump_object(object_to_dump, output_path):
  """Pickles `object_to_dump` to `output_path` with joblib, creating dirs."""
  if not gfile.Exists(output_path):
    gfile.MakeDirs(os.path.dirname(output_path))
  # joblib writes bytes, so the file must be opened in binary mode.
  with gfile.Open(output_path, 'wb') as wf:
    joblib.dump(object_to_dump, wf)
def main(argv):
  del argv  # unused.
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks_parent.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks_parent.get_hparams()

  environment = BARewardMolecule(
      discount_factor=hparams.discount_factor,
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=hparams.allow_bonds_between_rings,
      allowed_ring_sizes=set(hparams.allowed_ring_sizes),
      max_steps=hparams.max_steps_per_episode)

  dqn = deep_q_networks_parent.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks_parent.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn_parent.run_training(
      hparams=hparams, environment=environment, dqn=dqn)

  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config_sa.json'))
async def train(state, selfplay_processes):
  """Run training and write a new model to the fsdb models_dir.

  Args:
    state: the RL loop State instance.
    selfplay_processes: the self-play processes that generate the training
      examples to wait for.
  """
  wait_for_training_examples(state, selfplay_processes,
                             FLAGS.min_games_per_iteration)
  tf_records = await sample_training_examples(state)

  model_path = os.path.join(fsdb.models_dir(), state.train_model_name)

  await run(
      'python3', 'train.py',
      '--gpu_device_list={}'.format(','.join(FLAGS.train_devices)),
      '--flagfile={}'.format(os.path.join(FLAGS.flags_dir, 'train.flags')),
      '--work_dir={}'.format(fsdb.working_dir()),
      '--export_path={}'.format(model_path),
      '--use_extra_features={}'.format(FLAGS.use_extra_features),
      '--freeze=true', *tf_records)

  # Append the time elapsed from when the RL was started to when this model
  # was trained.
  elapsed = time.time() - state.start_time
  timestamps_path = os.path.join(fsdb.models_dir(), 'train_times.txt')
  with gfile.Open(timestamps_path, 'a') as f:
    print('{:.3f} {}'.format(elapsed, state.train_model_name), file=f)

  if FLAGS.validate and state.iter_num > 1:
    try:
      await validate(state)
    except Exception as e:
      logging.error(e)
def main(argv):
  del argv
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks.get_hparams()

  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=None,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      max_steps=hparams.max_steps_per_episode)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn.run_training(
      hparams=hparams,
      environment=environment,
      dqn=dqn,
  )

  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
async def run(*cmd):
  """Run the given subprocess command in a coroutine.

  Args:
    *cmd: the command to run and its arguments.

  Returns:
    The output that the command wrote to stdout as a list of strings, one line
    per element (stderr output is piped to stdout).

  Raises:
    RuntimeError: if the command returns a non-zero result.
  """
  stdout = await checked_run(*cmd)

  log_path = os.path.join(FLAGS.base_dir, get_cmd_name(cmd) + '.log')
  with gfile.Open(log_path, 'a') as f:
    f.write(await expand_cmd_str(cmd))
    f.write('\n')
    f.write(stdout)
    f.write('\n')

  # Split stdout into lines.
  return stdout.split('\n')
def setUpClass(cls):
  cache_dir = tf.test.get_temp_dir()

  # Create a dummy file.
  dummy_dir = os.path.join(cache_dir, 'dummy')
  dummy_filepath = os.path.join(dummy_dir, 'dummy.txt')

  gfile.MakeDirs(dummy_dir)
  dummy_file_contents = 'hello world'
  with gfile.Open(dummy_filepath, 'w') as f:
    f.write(dummy_file_contents)

  # Directory containing the compressed archives to extract.
  input_dir = os.path.join(cache_dir, 'to_extract')
  gfile.MakeDirs(input_dir)

  dl_manager = download_manager.DownloadManager(
      cache_dir=cache_dir,
      mode=util.GenerateMode.REUSE_CACHE_IF_EXISTS,
  )

  cls.dummy_dir = dummy_dir
  cls.dummy_filepath = dummy_filepath
  cls.dummy_file_contents = dummy_file_contents
  cls.input_dir = input_dir
  cls.dl_manager = dl_manager
def read_df_from_gcs(file_pattern):
  """Read CSV data from Google Cloud Storage into a single DataFrame.

  Assume that the data on GCS is in csv format without a header. The column
  names are provided through metadata.

  Args:
    file_pattern: (string) pattern of the files containing training data.
      For example: [gs://bucket/folder_name/prefix]

  Returns:
    pandas.DataFrame
  """
  # Read every file matching the pattern into a DataFrame.
  df_list = []

  for filepath in gfile.Glob(file_pattern):
    with gfile.Open(filepath, 'r') as f:
      # Assume there is no header.
      df_list.append(pd.read_csv(f, names=metadata.CSV_COLUMNS))

  data_df = pd.concat(df_list)

  return data_df
def get_latest_checkpoint(dirname: str):
  """Get the latest checkpoint in the directory.

  Args:
    dirname: Name of the directory.

  Returns:
    Checkpoint prefix string, or None if it cannot be determined.
  """
  chkpt_file = os.path.join(dirname, 'checkpoint')
  if not gfile.Exists(chkpt_file):
    logging.info('File %s does not exist', chkpt_file)
    return None
  chkpt_export_folder = os.path.join(dirname, 'export')
  if not gfile.Exists(chkpt_export_folder):
    logging.info('Eval export folder %s does not exist', chkpt_export_folder)
    return None
  with gfile.Open(chkpt_file) as f:
    for l in f:
      if l.startswith('model_checkpoint_path:'):
        # The value is quoted, e.g. model_checkpoint_path: "model.ckpt-1234".
        return os.path.basename(l.strip().split()[1][1:-1])
  return None
def restore_checkpoint(self, path):
  """Restores state from the checkpoint at `path`."""
  self.log_info('Restoring inference checkpoint: %s', path)
  with gfile.Open(path, 'r') as f:
    data = np.load(f)

    self.segmentation[:] = data['segmentation']
    self.seed[:] = data['seed']
    self.seg_prob[:] = data['seg_qprob']
    self.history_deleted = list(data['history_deleted'])
    self.history = list(data['history'])
    self.origins = data['origins'].item()
    if 'overlaps' in data:
      self.overlaps = data['overlaps'].item()

    segmented_voxels = np.sum(self.segmentation != 0)
    self.counters['voxels-segmented'].Set(segmented_voxels)
    self._max_id = np.max(self.segmentation)

    self.movement_policy.restore_state(data['movement_policy'])

    seed_policy_state = data['seed_policy_state']
    # When restoring the state of a previously unused Canvas, the seed
    # policy will not be defined. We just save the seed policy state here
    # for future use in .segment_all().
    self._seed_policy_state = seed_policy_state

    self.counters.loads(data['counters'].item())

  self.log_info('Inference checkpoint restored.')
def inference_loop(self, saver, model_ckpt_path):
  output_path = "{}.inference_predicts".format(model_ckpt_path.split('/')[-1])
  with tf.Session() as sess, gfile.Open(output_path, "w+") as out_file:
    sess.run(tf.local_variables_initializer())
    saver.restore(sess, model_ckpt_path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        res = sess.run(self.feed_out)
        num_examples_processed += res["predictions"].shape[0]
        logging.info("num examples processed: %d; elapsed seconds: %.2f " %
                     (num_examples_processed, time.time() - start_time))
        for line in format_lines(res["video_id"], res["predictions"],
                                 self.config.top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info('Done with inference. The output file was written to ' +
                   output_path)
    finally:
      coord.request_stop()
      coord.join(threads)