Example #1
def main(unused_argv):
    logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.json_prediction_files_pattern:
        raise ValueError(
            "The flag --json_prediction_files_pattern must be specified.")

    if not FLAGS.csv_output_file:
        raise ValueError("The flag --csv_output_file must be specified.")

    logging.info(
        "Looking for prediction files with pattern: %s",
        FLAGS.json_prediction_files_pattern,
    )

    file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
    logging.info("Found files: %s", file_paths)

    logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
    with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
        output_file.write(get_csv_header())

        for file_path in file_paths:
            logging.info("processing file: %s", file_path)

            with gfile.Open(file_path) as input_file:

                for line in input_file:
                    json_data = json.loads(line)
                    output_file.write(to_csv_row(json_data))

        output_file.flush()
    logging.info("done")
Example #2
def set_latest_checkpoint(dirname: str, chkpt: str):
    """Set the latest checkpoint in the checkpoint file.

  Args:
    dirname: Directory in which the checkpoint is located.
    chkpt: Checkpoint prefix.
  """
    chkpt_file = os.path.join(dirname, 'checkpoint')
    lines = []
    if gfile.Exists(chkpt_file):
        logging.info('Loading preexisting checkpoint file "%s"', chkpt_file)
        with gfile.Open(chkpt_file) as f:
            lines = [
                l.strip() for l in f.readlines()
                if l.startswith('all_model_checkpoint_paths:')
            ]
    else:
        logging.info('No preexisting checkpoint file "%s"', chkpt_file)
    with gfile.Open(chkpt_file, 'w') as f:
        lines = [
            '%s\n' % l.strip() for l in ([
                'model_checkpoint_path: "%s"' % chkpt,
                'all_model_checkpoint_paths: "%s"' % chkpt
            ] + lines)
        ]
        f.writelines(lines)
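A minimal usage sketch for set_latest_checkpoint, assuming the surrounding module imports (os, gfile, logging) are available; the directory and checkpoint prefix below are placeholders:

# Record "model.ckpt-1000" as the latest checkpoint in /tmp/train_dir/checkpoint.
set_latest_checkpoint('/tmp/train_dir', 'model.ckpt-1000')
# Any previously recorded all_model_checkpoint_paths entries are preserved after
# the new model_checkpoint_path / all_model_checkpoint_paths lines.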
Example #3
def main(unused_argv):
  tokenizer = FullTokenizer(FLAGS.tokenizer_vocabulary)

  print('Loading ' + str(FLAGS.dataset_name) + ' dataset from ' +
        FLAGS.input_filepath)

  # The debugging file saves all of the processed SQL queries.
  debugging_file = gfile.Open(
      os.path.join('/'.join(FLAGS.output_filepath.split('/')[:-1]),
                   FLAGS.dataset_name + '_'.join(FLAGS.splits) + '_gold.txt'),
      'w')

  # The output file will save a sequence of string-serialized JSON objects, one
  # line per object.
  output_file = gfile.Open(os.path.join(FLAGS.output_filepath), 'w')

  if FLAGS.dataset_name.lower() == 'spider':
    num_examples_created, num_examples_failed = process_spider(
        output_file, debugging_file, tokenizer)
  elif FLAGS.dataset_name.lower() == 'wikisql':
    num_examples_created, num_examples_failed = process_wikisql(
        output_file, debugging_file, tokenizer)
  else:
    num_examples_created, num_examples_failed = process_michigan_datasets(
        output_file, debugging_file, tokenizer)

  print('Wrote %s examples, could not annotate %s examples.' %
        (num_examples_created, num_examples_failed))
  debugging_file.write('Wrote %s examples, could not annotate %s examples.' %
                       (num_examples_created, num_examples_failed))
  debugging_file.close()
  output_file.close()
Example #4
def main(unused_argv):

    # Get one-hot encoding.
    mol_weights = pd.Series(_MOL_WEIGHTS)
    alphabet = [k for k in mol_weights.keys() if not k.startswith(_GROUP)]
    alphabet = sorted(alphabet)
    one_hot_encoding = pd.get_dummies(alphabet).astype(int).to_dict(
        orient='list')

    with gfile.Open(FLAGS.input_data) as inputf:
        input_data = pd.read_csv(inputf, sep=',')
    input_data.rename(columns={
        FLAGS.sequence_col: _MOD_SEQUENCE,
        FLAGS.charge_col: _CHARGE,
        FLAGS.fragmentation_col: _FRAGMENTATION,
        FLAGS.analyzer_col: _MASS_ANALYZER
    },
                      inplace=True)

    metadata, _ = preprocess_peptides(input_data, FLAGS.clean_peptides)
    metadata = metadata.reset_index()

    check_inputs(metadata, alphabet)

    # length.
    json_inputs = generate_json_inputs(metadata, one_hot_encoding)
    with gfile.Open(os.path.join(FLAGS.output_data_dir, 'input.json'),
                    'w') as outf:
        for json_input in json_inputs:
            outf.write(json.dumps(json_input) + '\n')
    with gfile.Open(os.path.join(FLAGS.output_data_dir, 'metadata.tsv'),
                    'w') as outf:
        metadata.to_csv(outf, sep='\t')
Example #5
def write_vocabulary(output_file, output_all_vocab_file, word_frequency,
                     frequency_cutoff, keep_non_ascii):
    special_chars = [
        unk_token, start_of_turn1, start_of_turn2, end_of_dialogue
    ]
    new_word_frequency = set([])
    # If output_file is not None, we write to the file; otherwise we don't.
    if output_file:
        f = gfile.Open(output_file, 'w')
    else:
        f = None

    # The first entry should always be unk_token.
    for special_char in special_chars:
        if f: f.write(special_char + '\n')
    for key in word_frequency:
        # We write to the vocabulary only when the key is not empty.
        # Otherwise tensorflow will complain.
        if word_frequency[key] >= frequency_cutoff and (
                key not in special_chars) and is_ascii(key,
                                                       keep_non_ascii) and key:
            if f:
                f.write(key + '\n')
            new_word_frequency.add(key)

    if f: f.close()
    # all vocab
    if output_all_vocab_file:
        with gfile.Open(output_all_vocab_file, 'w') as f2:
            f2.write(str(word_frequency))
    return new_word_frequency
Example #6
def write_completion(data, output_file_data_src, output_file_data_tar,
                     output_file_kb):
    """This function write both kb and main data into the files."""
    f_data_src = gfile.Open(output_file_data_src, 'w')
    f_data_tar = gfile.Open(output_file_data_tar, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        bd1 = entry['boundaries1'].split(' ')
        bd2 = entry['boundaries2'].split(' ')
        start = bd1[0:len(bd1) // 2] + bd2[0:len(bd2) // 2]
        end = bd1[len(bd1) // 2:] + bd2[len(bd2) // 2:]
        # random_turn = random.randint(0, len(start) - 1)
        for random_turn in range(len(start)):
            f_kb.write(flatten_json(entry['kb']) + '\n')
            # print len(start),len(end),len(bd),random_turn
            turn_start = int(start[random_turn])
            turn_end = int(end[random_turn])
            dialogue_split = entry['dialogue'].split(' ')
            # print turn_start,turn_end
            dialogue_src = dialogue_split[0:turn_start + 1]
            dialogue_tar = dialogue_split[turn_start + 1:turn_end + 1]
            src_arr = [entry['intent'], ' '.join(dialogue_src)]
            f_data_src.write('|'.join(src_arr) + '\n')
            tar_arr = [entry['action'], ' '.join(dialogue_tar)]
            f_data_tar.write('|'.join(tar_arr) + '\n')

    f_data_src.close()
    f_data_tar.close()
    f_kb.close()
Example #7
def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session() as sess, gfile.Open(out_file_location, "w+") as out_file, gfile.Open(out_file_location+"2", "w+") as out_file2:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")
    out_file2.write("VideoId,Preds\n")

    try:
      while not coord.should_stop():
          video_id_batch_val, video_batch_val,num_frames_batch_val = sess.run([video_id_batch, video_batch, num_frames_batch])
          predictions_val, = sess.run([predictions_tensor], feed_dict={input_tensor: video_batch_val, num_frames_tensor: num_frames_batch_val})
          now = time.time()
          num_examples_processed += len(video_batch_val)
          num_classes = predictions_val.shape[1]
          logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))
          for line in format_lines(video_id_batch_val, predictions_val, top_k):
            out_file.write(line)
            out_file2.write(line)
          out_file.flush()
          out_file2.flush()


    except tf.errors.OutOfRangeError:
        logging.info('Done with inference. The output file was written to ' + out_file_location)
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
Example #8
 def __init__(self, vocab_file, embedding_file, normalize_embeddings=True):
     with gfile.Open(embedding_file, 'rb') as f:
         self.embedding_mat = np.load(f)
     if normalize_embeddings:
         self.embedding_mat = self.embedding_mat / np.linalg.norm(
             self.embedding_mat, axis=1, keepdims=True)
     with gfile.Open(vocab_file, 'r') as f:
         tks = json.load(f)
     self.vocab = dict(zip(tks, range(len(tks))))
Example #9
def write_self_play(data, output_file_data, output_file_kb):
    f_data = gfile.Open(output_file_data, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        f_kb.write(flatten_json(entry['kb']) + '\n')
        new_arr = [entry['intent'], entry['expected_action']
                   ]  # intent and action are both needed
        f_data.write('|'.join(new_arr) + '\n')
    f_data.close()
    f_kb.close()
Example #10
def load_pickle(pickle_file):
    """Load a pickle file (py2/3 compatible)."""
    try:
        with gfile.Open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError as e:
        with gfile.Open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load {}: {}'.format(pickle_file, e))
        raise
    return pickle_data
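A short usage sketch, with a made-up path; load_pickle handles pickles written by either Python 2 or Python 3 (falling back to latin1 decoding for old 8-bit string pickles):

data = load_pickle('/tmp/example_data.pkl')  # hypothetical path
print(type(data))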
Example #11
def collect_programs():
  saved_programs = {}
  for i in range(FLAGS.id_start, FLAGS.id_end):
    with gfile.Open(
        os.path.join(get_experiment_dir(), 'program_shard_{}-{}.json'.format(
            i, FLAGS.n_epoch)), 'r') as f:
      program_shard = json.load(f)
      saved_programs.update(program_shard)
  saved_program_path = os.path.join(get_experiment_dir(), 'saved_programs.json')
  with gfile.Open(saved_program_path, 'w') as f:
    json.dump(saved_programs, f)
  print('saved programs are aggregated in {}'.format(saved_program_path))
Example #12
def write_data(data, output_file_data, output_file_kb):
    """This function writes data into a text file."""
    f_data = gfile.Open(output_file_data, 'w')
    f_kb = gfile.Open(output_file_kb, 'w')
    for entry in data:
        f_kb.write(flatten_json(entry['kb']) + '\n')
        new_arr = [
            entry['intent'], entry['action'], entry['dialogue'],
            entry['boundaries1']
        ]
        # only boundary1 is used but not 2 because it's not necessary.
        f_data.write('|'.join(new_arr) + '\n')
    f_data.close()
    f_kb.close()
Example #13
 def next_pairs_array(self):
   arr = numpy.load(gfile.Open(self.train_npy_files[self.next_idx]))
   indices = list(range(len(arr)))
   random.shuffle(indices)
   arr = arr[indices]
   self.next_idx = (self.next_idx + 1) % len(self.train_npy_files)
   return arr
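The index shuffle above can also be written with numpy.random.permutation, which avoids building and shuffling an explicit index list; this is an equivalent sketch, not the original code:

import numpy

def shuffle_rows(arr):
    # Row-shuffled copy of an array, same effect as shuffling range(len(arr))
    # and indexing with the shuffled indices.
    return arr[numpy.random.permutation(len(arr))]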
Example #14
def load_wikisql_tables(filepath):
    """Loads the WikiSQL tables from a path and reformats as the format."""
    dbs = dict()
    with gfile.Open(filepath) as infile:
        tables = [json.loads(line) for line in infile if line]

    for table in tables:
        db_dict = dict()
        if 'section_title' in table and table['section_title']:
            table_name = table['section_title']
        elif 'name' in table:
            table_name = table['name']
        else:
            table_name = table['page_title']

        table_name = normalize_entities(table_name)

        db_dict[table_name] = list()
        for column_name, column_type in zip(table['header'], table['types']):
            if column_type == 'real':
                column_type = 'number'
            assert column_type in {'text', 'number'}, column_type
            column_name = normalize_entities(column_name)

            db_dict[table_name].append({
                'field name': column_name,
                'is primary key': False,
                'is foreign key': False,
                'type': column_type
            })
        if table['id'] not in dbs:
            dbs[table['id']] = db_dict

    return dbs
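load_wikisql_tables expects one JSON object per line with at least 'id', 'header', 'types', and one of 'section_title', 'name', or 'page_title'. A hypothetical input line and call, assuming the surrounding module (gfile, normalize_entities) is available:

import json

sample_table = {
    'id': '1-0001',                # used as the key in the returned dbs dict
    'page_title': 'Example page',  # fallback when section_title/name are absent
    'header': ['Player', 'Goals'],
    'types': ['text', 'real'],     # 'real' is rewritten to 'number'
}
with open('/tmp/tables.jsonl', 'w') as f:  # hypothetical path
    f.write(json.dumps(sample_table) + '\n')

dbs = load_wikisql_tables('/tmp/tables.jsonl')
# dbs['1-0001'] maps the normalized table name to its list of column dicts.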
Example #15
def main(unused_argv):
    logging.set_verbosity(tf.logging.INFO)
    paths = gfile.Glob(FLAGS.input_data_pattern)
    logging.info("Found %s files.", len(paths))
    for path in paths:
        with gfile.Open(path, "r") as f:
            first_read = True
            while True:
                length_raw = f.read(8)
                if not length_raw and first_read:
                    logging.fatal("File %s has no data.", path)
                    break
                elif not length_raw:
                    logging.info("File %s looks good.", path)
                    break
                else:
                    first_read = False
                if len(length_raw) != 8:
                    logging.fatal("File ends when reading record length: " +
                                  path)
                    break
                length, = struct.unpack("L", length_raw)
                # +8 to include the crc values.
                record = f.read(length + 8)
                if len(record) != length + 8:
                    logging.fatal("File ends in the middle of a record: " +
                                  path)
                    break
Example #16
def get_optimized_mols(model_dir, ckpt=80000):
  """Get optimized Molecules.

  Args:
    model_dir: String. model directory.
    ckpt: the checkpoint to load.

  Returns:
    List of 800 optimized molecules
  """
  hparams_file = os.path.join(model_dir, 'config.json')
  with gfile.Open(hparams_file, 'r') as f:
    hp_dict = json.load(f)
    hparams = deep_q_networks.get_hparams(**hp_dict)

  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)

  tf.reset_default_graph()
  optimized_mol = []
  with tf.Session() as sess:
    dqn.build()
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    for mol in all_mols:
      logging.info('Eval: %s', mol)
      environment = molecules_mdp.Molecule(
          atom_types=set(hparams.atom_types),
          init_mol=mol,
          allow_removal=hparams.allow_removal,
          allow_no_modification=hparams.allow_no_modification,
          allow_bonds_between_rings=hparams.allow_bonds_between_rings,
          allowed_ring_sizes=set(hparams.allowed_ring_sizes),
          max_steps=hparams.max_steps_per_episode,
          record_path=True)
      environment.initialize()
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      for _ in range(hparams.max_steps_per_episode):
        steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
        valid_actions = list(environment.get_valid_actions())
        observations = np.vstack([
            np.append(
                deep_q_networks.get_fingerprint(act, hparams), steps_left)
            for act in valid_actions
        ])
        action = valid_actions[dqn.get_action(
            observations, head=head, update_epsilon=0.0)]
        environment.step(action)
      optimized_mol.append(environment.get_path())
  return optimized_mol
Example #17
def write_batch_as_jpg(batch, filename, pad=1):

    #  import pdb
    #  pdb.set_trace()

    # pad with zeros to make more visible
    batch = tf.pad(batch,
                   tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]),
                   constant_values=1.)

    batch_size = len(batch)
    grid_shape = int(math.sqrt(batch_size)), int(math.sqrt(batch_size))
    image_shape = batch.shape[1:3]
    num_channels = batch.shape[-1]

    batch = eval.image_grid(batch,
                            grid_shape=grid_shape,
                            image_shape=image_shape,
                            num_channels=num_channels)
    batch = tf.image.convert_image_dtype(batch, tf.uint8)
    batch = tf.squeeze(batch, 0)

    with gfile.Open(filename, 'wb+') as f:
        encoded_batch = tf.io.encode_jpeg(batch, name='batch')
        f.write(encoded_batch.numpy())

    return
Example #18
def download(uri, dst_dir):
  """Download the given URI.

  Args:
    uri: URI to copy (or download) from.
    dst_dir: path to the directory that will be used.

  Returns:
    The path to the downloaded file.
  """
  # Download the URI
  # Should use context manager with Py3 (with urllib2.urlopen(uri) as response)
  response = urllib.request.urlopen(uri)
  filename = response.geturl().split('/')[-1]
  incomplete_path = os.path.join(dst_dir, '{}.incomplete'.format(filename))
  dst_path = os.path.join(dst_dir, filename)

  # TODO(epot): Could add a shared tqdm instance across parallel download
  # to display a single shared progression bar.

  # TODO(b/119663674): Add Google Drive support (cf Ryan code)

  with gfile.Open(incomplete_path, 'wb') as f:
    f.write(response.read())
  gfile.Rename(incomplete_path, dst_path)

  return dst_path
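A usage sketch for download, with placeholder URL and directory; it assumes the destination directory already exists and that gfile can write there:

path = download('https://example.com/data/archive.zip', '/tmp/downloads')
print('Downloaded to', path)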
Example #19
 def __init__(self, latent_factor_indices=None):
     DSprites.__init__(self, latent_factor_indices)
     self.data_shape = [64, 64, 3]
     with gfile.Open(SCREAM_PATH, "rb") as f:
         scream = PIL.Image.open(f)
         scream.thumbnail((350, 274, 3))
         self.scream = np.array(scream) * 1. / 255.
Example #20
    def _sync_download(self, url, destination_path):
        """Synchronous version of `download` method."""
        checksum = self._checksumer()
        session = requests.Session()
        if _DRIVE_URL.match(url):
            url = self._get_drive_url(url, session)
        response = session.get(url, stream=True)
        if response.status_code != 200:
            raise DownloadError('Failed to get url %s. HTTP code: %d.' %
                                (url, response.status_code))
        fname = _get_filename(response)
        path = os.path.join(destination_path, fname)
        size = 0

        size_mb = 0
        unit_mb = units.MiB
        self._pbar_dl_size.update_total(
            int(response.headers.get('Content-length', 0)) // unit_mb)
        with gfile.Open(path, 'wb') as file_:
            for block in response.iter_content(
                    chunk_size=io.DEFAULT_BUFFER_SIZE):
                size += len(block)

                # Update the progress bar
                size_mb += len(block)
                if size_mb > unit_mb:
                    self._pbar_dl_size.update(size_mb // unit_mb)
                    size_mb %= unit_mb

                checksum.update(block)
                # TODO(pierrot): Test this is faster than doing checksum in the end
                # and document results here.
                file_.write(block)
        self._pbar_url.update(1)
        return checksum.hexdigest(), size
Example #21
def dump_object(object_to_dump, output_path):

    if not gfile.Exists(output_path):
        gfile.MakeDirs(os.path.dirname(output_path))

    with gfile.Open(output_path, 'w') as wf:
        joblib.dump(object_to_dump, wf)
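A usage sketch for dump_object, assuming joblib and scikit-learn are installed; the estimator and output path are illustrative only:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
# dump_object creates the parent directory if the output path does not exist yet.
dump_object(model, '/tmp/models/model.joblib')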
Example #22
def main(argv):
  del argv  # unused.
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks_parent.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks_parent.get_hparams()
  environment = BARewardMolecule(
      discount_factor=hparams.discount_factor,
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=hparams.allow_bonds_between_rings,
      allowed_ring_sizes=set(hparams.allowed_ring_sizes),
      max_steps=hparams.max_steps_per_episode)

  dqn = deep_q_networks_parent.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks_parent.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn_parent.run_training(
      hparams=hparams,
      environment=environment,
      dqn=dqn)

  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config_sa.json'))
Example #23
async def train(state, selfplay_processes):
    """Run training and write a new model to the fsdb models_dir.

    Args:
        state: the RL loop State instance.
        selfplay_processes: the selfplay processes that produce the training
            examples to wait for.
    """

    wait_for_training_examples(state, selfplay_processes, FLAGS.min_games_per_iteration)
    tf_records = await sample_training_examples(state)

    model_path = os.path.join(fsdb.models_dir(), state.train_model_name)

    await run(
        'python3', 'train.py',
        '--gpu_device_list={}'.format(','.join(FLAGS.train_devices)),
        '--flagfile={}'.format(os.path.join(FLAGS.flags_dir, 'train.flags')),
        '--work_dir={}'.format(fsdb.working_dir()),
        '--export_path={}'.format(model_path),
        '--use_extra_features={}'.format(FLAGS.use_extra_features),
        '--freeze=true',
        *tf_records)

    # Append the time elapsed from when the RL was started to when this model
    # was trained.
    elapsed = time.time() - state.start_time
    timestamps_path = os.path.join(fsdb.models_dir(), 'train_times.txt')
    with gfile.Open(timestamps_path, 'a') as f:
        print('{:.3f} {}'.format(elapsed, state.train_model_name), file=f)

    if FLAGS.validate and state.iter_num > 1:
        try:
            await validate(state)
        except Exception as e:
            logging.error(e)
Example #24
def main(argv):
    del argv
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as f:
            hparams = deep_q_networks.get_hparams(**json.load(f))
    else:
        hparams = deep_q_networks.get_hparams()

    environment = Molecule(atom_types=set(hparams.atom_types),
                           init_mol=None,
                           allow_removal=hparams.allow_removal,
                           allow_no_modification=hparams.allow_no_modification,
                           max_steps=hparams.max_steps_per_episode)

    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(
        hparams=hparams,
        environment=environment,
        dqn=dqn,
    )

    core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
Example #25
async def run(*cmd):
    """Run the given subprocess command in a coroutine.

    Args:
        *cmd: the command to run and its arguments.

    Returns:
        The output that the command wrote to stdout as a list of strings, one line
        per element (stderr output is piped to stdout).

    Raises:
        RuntimeError: if the command returns a non-zero result.
    """

    stdout = await checked_run(*cmd)

    log_path = os.path.join(FLAGS.base_dir, get_cmd_name(cmd) + '.log')
    with gfile.Open(log_path, 'a') as f:
        f.write(await expand_cmd_str(cmd))
        f.write('\n')
        f.write(stdout)
        f.write('\n')

    # Split stdout into lines.
    return stdout.split('\n')
Example #26
    def setUpClass(cls):
        cache_dir = tf.test.get_temp_dir()

        # Create a dummy file
        dummy_dir = os.path.join(cache_dir, 'dummy')
        dummy_filepath = os.path.join(dummy_dir, 'dummy.txt')

        gfile.MakeDirs(dummy_dir)
        dummy_file_contents = 'hello world'
        with gfile.Open(dummy_filepath, 'w') as f:
            f.write(dummy_file_contents)

        # File containing compressed archives
        input_dir = os.path.join(cache_dir, 'to_extract')
        gfile.MakeDirs(input_dir)

        dl_manager = download_manager.DownloadManager(
            cache_dir=cache_dir,
            mode=util.GenerateMode.REUSE_CACHE_IF_EXISTS,
        )

        cls.dummy_dir = dummy_dir
        cls.dummy_filepath = dummy_filepath
        cls.dummy_file_contents = dummy_file_contents
        cls.input_dir = input_dir
        cls.dl_manager = dl_manager
Example #27
def read_df_from_gcs(file_pattern):
  """Read data from Google Cloud Storage, split into train and validation sets.

  Assumes that the data on GCS is in CSV format without a header.
  The column names are provided through the metadata module.

  Args:
    file_pattern: (string) pattern of the files containing training data.
    For example: [gs://bucket/folder_name/prefix]

  Returns:
    pandas.DataFrame
  """

  # Download the files to local /tmp/ folder
  df_list = []

  for filepath in gfile.Glob(file_pattern):
    with gfile.Open(filepath, 'r') as f:
      # Assume there is no header
      df_list.append(pd.read_csv(f, names=metadata.CSV_COLUMNS))

  data_df = pd.concat(df_list)

  return data_df
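A usage sketch for read_df_from_gcs, assuming metadata.CSV_COLUMNS matches the files; the GCS pattern is a placeholder:

train_df = read_df_from_gcs('gs://my-bucket/training-data/part-*')
print(train_df.shape)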
Example #28
def get_latest_checkpoint(dirname: str):
    """Get the latest checkpoint in the directory.

  Args:
    dirname: Name of the directory.

  Returns:
    Checkpoint prefix string.
  """
    chkpt_file = os.path.join(dirname, 'checkpoint')
    if not gfile.Exists(chkpt_file):
        logging.info('File %s does not exist', chkpt_file)
        return None
    chkpt_export_folder = os.path.join(dirname, 'export')
    if not gfile.Exists(chkpt_export_folder):
        logging.info('Eval export folder %s does not exist',
                     chkpt_export_folder)
        return None
    num_lines = 0
    with gfile.Open(chkpt_file) as f:
        for l in f:
            num_lines += 1
            if l.startswith('model_checkpoint_path:'):
                return os.path.basename(l.strip().split()[1][1:-1])
    return None
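A round-trip sketch combining get_latest_checkpoint with set_latest_checkpoint from Example #2; the directories are placeholders and must contain the expected 'checkpoint' file and 'export' folder:

latest = get_latest_checkpoint('/tmp/train_dir')
if latest is not None:
    set_latest_checkpoint('/tmp/other_dir', latest)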
Example #29
  def restore_checkpoint(self, path):
    """Restores state from the checkpoint at `path`."""
    self.log_info('Restoring inference checkpoint: %s', path)
    with gfile.Open(path, 'rb') as f:
      data = np.load(f)

      self.segmentation[:] = data['segmentation']
      self.seed[:] = data['seed']
      self.seg_prob[:] = data['seg_qprob']
      self.history_deleted = list(data['history_deleted'])
      self.history = list(data['history'])
      self.origins = data['origins'].item()
      if 'overlaps' in data:
        self.overlaps = data['overlaps'].item()

      segmented_voxels = np.sum(self.segmentation != 0)
      self.counters['voxels-segmented'].Set(segmented_voxels)
      self._max_id = np.max(self.segmentation)

      self.movement_policy.restore_state(data['movement_policy'])

      seed_policy_state = data['seed_policy_state']
      # When restoring the state of a previously unused Canvas, the seed
      # policy will not be defined. We just save the seed policy state here
      # for future use in .segment_all().
      self._seed_policy_state = seed_policy_state

      self.counters.loads(data['counters'].item())

    self.log_info('Inference checkpoint restored.')
Example #30
def inference_loop(self, saver, model_ckpt_path):
    output_path = "{}.inference_predicts".format(
        model_ckpt_path.split('/')[-1])
    with tf.Session() as sess, gfile.Open(output_path, "w+") as out_file:
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, model_ckpt_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                res = sess.run(self.feed_out)
                num_examples_processed += res["predictions"].shape[0]
                logging.info(
                    "num examples processed: %d; elapsed seconds: %.2f " %
                    (num_examples_processed, time.time() - start_time))
                for line in format_lines(res["video_id"], res["predictions"],
                                         self.config.top_k):
                    out_file.write(line)
                out_file.flush()
        except tf.errors.OutOfRangeError:
            logging.info(
                'Done with inference. The output file was written to ' +
                output_path)
        finally:
            coord.request_stop()

        coord.join(threads)