def copy_to_gcs(src, dst):
    assert gfile.Exists(src), src
    assert not gfile.Exists(dst), dst

    print("Saving to", dst)
    with gfile.GFile(src, "rb") as src_f, gfile.GFile(dst, "wb") as dst_f:
        shutil.copyfileobj(src_f, dst_f)
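
A minimal usage sketch for copy_to_gcs. It assumes the TF1-style gfile module these snippets rely on (for example, from tensorflow.python.platform import gfile); the local path and bucket name below are placeholders.

# Usage sketch (assumed imports; paths and bucket are placeholders).
import shutil
from tensorflow.python.platform import gfile

copy_to_gcs('/tmp/000123-model.pb', 'gs://my-bucket/models/000123-model.pb')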
Example #2
def launch_eval_job(tag, m1_path, m2_path, job_name, completions):
    """Launches an evaluator job.

  Args:
    tag: name for this eval job (used as top level folder name)
    m1_path, m2_path: full gs:// paths to the .pb files to match up
    job_name: string, appended to the container, used to differentiate the job
      names (e.g. 'minigo-cc-evaluator-v5-123-v7-456')
    completions: the number of completions desired (each completion is 2 games)
  """
    print()
    if not re.match(r'[a-z0-9-]*$', tag, re.I):
        print('{} is not a valid tag'.format(tag))
        return

    # Change to minigo-pub
    sgf_bucket_path = 'sethtroisi-sandbox/experiments/eval/' + tag
    assert not sgf_bucket_path.startswith('gs://'), sgf_bucket_path
    bucket_path = 'gs://' + sgf_bucket_path

    metadata_path = os.path.join(bucket_path, 'metadata')
    assert not gfile.Exists(metadata_path), 'Already exists'

    TS = str(int(time.time()))
    metadata = {
        'timestamp': TS,
        'date': datetime.datetime.now().isoformat(' '),
        'model1': os.path.basename(m1_path),
        'model2': os.path.basename(m2_path),
        'model1_path': m1_path,
        'model2_path': m2_path,
        'job_name': job_name,
        'completions': completions,
        'launch_eval_version': LAUNCH_EVAL_VERSION,
    }

    job_conf, resp_bw, resp_wb = launch_eval.launch_eval_job(
        m1_path, m2_path, job_name, sgf_bucket_path, completions)

    if not (resp_bw and resp_wb):
        print('launch_eval.py failed')
        print(job_conf)
        print(resp_bw)
        print(resp_wb)
        print()
        assert False

    # Jobs were launched, record metadata to GCS.
    with gfile.GFile(metadata_path, 'w') as metadata_file:
        json.dump(metadata, metadata_file)

    with gfile.GFile(os.path.join(bucket_path, 'commands'), 'w') as f:
        f.write(str(sys.argv) + '\n')

    with gfile.GFile(os.path.join(bucket_path, 'job_conf'), 'w') as f:
        f.write(str(job_conf) + '\n')
Example #3
def extract_holdout_model(model):
    game_output_path = OUTPUT_PATH.format(FLAGS.base_dir, 'games', model)
    move_output_path = OUTPUT_PATH.format(FLAGS.base_dir, 'moves', model)
    # Make sure the containing output directories exist.
    gfile.MakeDirs(os.path.dirname(game_output_path))
    gfile.MakeDirs(os.path.dirname(move_output_path))

    with gfile.GFile(game_output_path, 'w') as game_f, \
            gfile.GFile(move_output_path, 'w') as move_f:
        for sgf_name in tqdm(get_sgf_names(model)):
            game_data, move_data = extract_data(sgf_name)
            game_f.write(json.dumps(game_data) + '\n')
            for move_datum in move_data:
                move_f.write(json.dumps(move_datum) + '\n')
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-t',
        '--tagged-dataset-path',
        default=os.path.join('.', 'tagged', 'data'),
        help='Directory containing CoreNLP-tagged dataset TSV file')
    parser.add_argument('prediction_path',
                        help='Path to the prediction file. Each line contains '
                        'ex_id <tab> item1 <tab> item2 <tab> ...')
    args = parser.parse_args()

    # ID string --> list[Value]
    target_values_map = {}
    for filename in os.listdir(args.tagged_dataset_path):
        filename = os.path.join(args.tagged_dataset_path, filename)
        print('Reading dataset from', filename, file=sys.stderr)
        with codecs.getreader('utf-8')(gfile.GFile(filename, 'rb')) as fin:
            header = fin.readline().rstrip('\n').split('\t')
            for line in fin:
                stuff = dict(zip(header, line.rstrip('\n').split('\t')))
                ex_id = stuff['id']
                original_strings = tsv_unescape_list(stuff['targetValue'])
                canon_strings = tsv_unescape_list(stuff['targetCanon'])
                target_values_map[ex_id] = to_value_list(
                    original_strings, canon_strings)
    print('Read', len(target_values_map), 'examples', file=sys.stderr)

    print('Reading predictions from', args.prediction_path, file=sys.stderr)
    num_examples, num_correct = 0, 0
    with codecs.getreader('utf-8')(gfile.GFile(args.prediction_path,
                                               'rb')) as fin:
        for line in fin:
            line = line.rstrip('\n').split('\t')
            ex_id = line[0]
            if ex_id not in target_values_map:
                print('WARNING: Example ID "%s" not found' % ex_id)
            else:
                target_values = target_values_map[ex_id]
                predicted_values = to_value_list(line[1:])
                correct = check_denotation(target_values, predicted_values)
                print('%s\t%s\t%s\t%s' % (ex_id, correct, target_values,
                                          predicted_values))
                num_examples += 1
                if correct:
                    num_correct += 1
    print('Examples:', num_examples, file=sys.stderr)
    print('Correct:', num_correct, file=sys.stderr)
    print('Accuracy:', round(
        (num_correct + 1e-9) / (num_examples + 1e-9), 4), file=sys.stderr)
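
A hedged invocation sketch for the evaluator above; the script and file names are illustrative, not taken from the original project.

# Hypothetical command line (script and file names are placeholders):
#   python evaluator.py --tagged-dataset-path ./tagged/data predictions.tsv
# where each line of predictions.tsv is: ex_id<TAB>item1<TAB>item2<TAB>...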
Example #5
 def _preprocess_rsa_psi_follower(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     for partition_id in range(
             self._data_source_f.data_source_meta.partition_num):
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='follower-rsa-psi-processor',
             role=common_pb.FLRole.Follower,
             rsa_key_pem=rsa_key_pem,
             input_file_paths=[self._psi_raw_data_fpaths_f[partition_id]],
             output_file_dir=self._pre_processor_ouput_dir_f,
             raw_data_publish_dir=self._raw_data_pub_dir_f,
             partition_id=partition_id,
             leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
             offload_processor_number=1,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14))
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self._etcd_name, self._etcd_addrs,
             self._etcd_base_dir_f, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #6
def load_spider_examples(filenames):
    """Loads examples from the Spider dataset from the specified files."""
    examples = []
    for filename in filenames.split(','):
        with gfile.GFile(filename) as training_file:
            examples += json.load(training_file)
    return examples
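
A short usage sketch; load_spider_examples takes a comma-separated list of files, and the paths below are placeholders.

# Usage sketch (placeholder paths to the standard Spider training files).
examples = load_spider_examples('spider/train_spider.json,spider/train_others.json')
print(len(examples), 'Spider examples loaded')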
Example #7
 def _preprocess_rsa_psi_follower(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     for partition_id in range(
             self._data_source_f.data_source_meta.partition_num):
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='follower-rsa-psi-processor',
             role=common_pb.FLRole.Follower,
             rsa_key_pem=rsa_key_pem,
             input_file_paths=[self._psi_raw_data_fpaths_f[partition_id]],
             output_file_dir=self._pre_processor_ouput_dir_f,
             raw_data_publish_dir=self._raw_data_pub_dir_f,
             partition_id=partition_id,
             leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             max_flying_sign_rpc=64,
             sign_rpc_timeout_ms=100000,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             rpc_sync_mode=(partition_id % 2 == 0),
             rpc_thread_pool_size=16,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14))
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self._etcd_name, self._etcd_addrs,
             self._etcd_base_dir_f, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #8
 def _preprocess_rsa_psi_leader(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     for partition_id in range(
             self._data_source_l.data_source_meta.partition_num):
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='leader-rsa-psi-processor',
             role=common_pb.FLRole.Leader,
             rsa_key_pem=rsa_key_pem,
             input_file_paths=[self._psi_raw_data_fpaths_l[partition_id]],
             output_file_dir=self._pre_processor_ouput_dir_l,
             raw_data_publish_dir=self._raw_data_pub_dir_l,
             partition_id=partition_id,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14))
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self._etcd_name, self._etcd_addrs,
             self._etcd_base_dir_l, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #9
 def finish_map(self):
     for writer in self._writers:
         writer.finish()
     succ_tag_fpath = common.encode_portal_hourly_finish_tag(
         self._potral_manifest.output_data_base_dir, self._date_time)
     with gfile.GFile(succ_tag_fpath, 'w') as fh:
         fh.write('')
Example #10
def load_spider_tables(filenames):
    """Loads database schemas from the specified filenames."""
    examples = dict()
    for filename in filenames.split(','):
        with gfile.GFile(filename) as training_file:
            examples.update(process_dbs(json.load(training_file)))
    return examples
Example #11
 def __init__(self, rsa_private_key_path, offload_processor_number):
     with gfile.GFile(rsa_private_key_path, 'rb') as f:
         file_content = f.read()
         self._prv_key = rsa.PrivateKey.load_pkcs1(file_content)
     self._process_pool_executor = None
     if offload_processor_number > 0:
         self._process_pool_executor = \
                 futures.ProcessPoolExecutor(offload_processor_number)
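
The PKCS#1 PEM files read above can be produced with the same rsa package. A minimal sketch, assuming the key size and output paths shown here (neither is taken from the original project):

import rsa
from tensorflow.python.platform import gfile

# Generate a key pair and write both halves as PKCS#1 PEM (placeholder paths).
pub_key, prv_key = rsa.newkeys(2048)
with gfile.GFile('/tmp/rsa_psi.pub', 'wb') as f:
    f.write(pub_key.save_pkcs1())
with gfile.GFile('/tmp/rsa_psi.key', 'wb') as f:
    f.write(prv_key.save_pkcs1())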
Example #12
def run_game(load_file,
             selfplay_dir=None,
             holdout_dir=None,
             sgf_dir=None,
             holdout_pct=0.05):
    """Takes a played game and record results and game data."""
    if sgf_dir is not None:
        minimal_sgf_dir = os.path.join(sgf_dir, 'clean')
        full_sgf_dir = os.path.join(sgf_dir, 'full')
        utils.ensure_dir_exists(minimal_sgf_dir)
        utils.ensure_dir_exists(full_sgf_dir)
    if selfplay_dir is not None:
        utils.ensure_dir_exists(selfplay_dir)
        utils.ensure_dir_exists(holdout_dir)

    with utils.logged_timer('Loading weights from %s ... ' % load_file):
        network = dual_net.DualNetwork(load_file)

    with utils.logged_timer('Playing game'):
        player = play(network)

    output_name = '{}-{}'.format(int(time.time()), socket.gethostname())
    game_data = player.extract_data()
    if sgf_dir is not None:
        with gfile.GFile(
                os.path.join(minimal_sgf_dir, '{}.sgf'.format(output_name)),
                'w') as f:
            f.write(player.to_sgf(use_comments=False))
        with gfile.GFile(
                os.path.join(full_sgf_dir, '{}.sgf'.format(output_name)),
                'w') as f:
            f.write(player.to_sgf())

    tf_examples = preprocessing.make_dataset_from_selfplay(game_data)

    if selfplay_dir is not None:
        # Hold out a holdout_pct fraction of games (default 5%) for validation.
        if random.random() < holdout_pct:
            fname = os.path.join(holdout_dir,
                                 '{}.tfrecord.zz'.format(output_name))
        else:
            fname = os.path.join(selfplay_dir,
                                 '{}.tfrecord.zz'.format(output_name))

        preprocessing.write_tf_examples(fname, tf_examples)
Example #13
 def _launch_rsa_psi_signer(self):
     self._rsa_psi_signer_addr = 'localhost:6171'
     rsa_private_key_pem = None
     with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
         rsa_private_key_pem = f.read()
     rsa_private_key = rsa.PrivateKey.load_pkcs1(rsa_private_key_pem)
     self._rsa_psi_signer = rsa_psi_signer.RsaPsiSigner(rsa_private_key, 1)
     self._rsa_psi_signer.start(int(
         self._rsa_psi_signer_addr.split(':')[1]))
Example #14
def extract_data(filename):
    with gfile.GFile(filename) as f:
        contents = f.read()
    root_node = sgf_wrapper.get_sgf_root_node(contents)
    game_data = extract_game_data(filename, root_node)
    move_data = extract_move_data(root_node, game_data['worker_id'],
                                  game_data['completed_time'],
                                  game_data['board_size'])
    return game_data, move_data
Example #15
 def finish(self):
     self._finish_csv_dict_writer()
     logging.warning("merge %d record in %d sort run merger: "\
                     "for partition %d", self._merged_num,
                     self._process_index, self._partition_id)
     finish_tag_fpath = os.path.join(self._merged_dir, '_SUCCESS')
     with gfile.GFile(finish_tag_fpath, 'w') as fh:
         fh.write('\n')
     return self._merged_fpaths
Example #16
def _get_vocab_symbols(filename):
    """Returns a list of symbols in a vocabularly file."""
    vocab = []
    if not gfile.Exists(filename):
        raise ValueError("File does not exist: {}".format(filename))
    with gfile.GFile(filename) as fp:
        for line in fp:
            vocab.append(line.rstrip("\n"))
    return vocab
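
A small round-trip sketch for the vocabulary helper; the path and symbols are placeholders.

# Write a tiny vocabulary file, then read it back (placeholder path).
with gfile.GFile('/tmp/vocab.txt', 'w') as f:
    f.write('<unk>\n<s>\n</s>\nhello\nworld\n')
assert _get_vocab_symbols('/tmp/vocab.txt') == ['<unk>', '<s>', '</s>', 'hello', 'world']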
Example #17
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(etcd, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._id_batch_fetcher = IdBatchFetcher(self._options)
     max_flying_item = options.batch_processor_options.max_flying_item
     if self._options.role == common_pb.FLRole.Leader:
         private_key = None
         with gfile.GFile(options.rsa_key_file_path, 'rb') as f:
             file_content = f.read()
             private_key = rsa.PrivateKey.load_pkcs1(file_content)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             max_flying_item,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = None
         with gfile.GFile(options.rsa_key_file_path, 'rb') as f:
             file_content = f.read()
             public_key = rsa.PublicKey.load_pkcs1(file_content)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, max_flying_item,
             self._process_pool_executor, public_key,
             self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
         self._sort_run_dumper.sort_run_dump_dir, self._options)
     self._worker_map = {}
     self._started = False
Example #18
 def _generate_input_data(self):
     self._partition_item_num = 1 << 16
     self._clean_up()
     gfile.MakeDirs(self._input_dir)
     success_flag_fpath = "{}/_SUCCESS".format(self._input_dir)
     example_id = 1000001
     for partition_id in range(self._input_partition_num):
         example_id = self._generate_one_partition(partition_id, example_id, self._partition_item_num)
     
     with gfile.GFile(success_flag_fpath, 'w') as fh:
         fh.write('')
Example #19
def from_json_file(path):
  """Reads from a json file.

  Args:
    path: Path to read from.

  Returns:
    The object read from the json file.
  """
  with gfile.GFile(str(path)) as fin:
    return json.loads(fin.read())
Example #20
def to_json_file(path, obj, indent=4):
  """Saves to a json file.

  Args:
    path: Where to save.
    obj: The object to save
    indent: Number of spaces to use as indentation.

  Returns:
    None
  """
  with gfile.GFile(str(path), "w") as fout:
    fout.write(json.dumps(obj, indent=indent))
Example #21
def to_json_file(path: PathType, obj, indent: int = 4):
  """Saves to a json file.

  Args:
    path: Where to save.
    obj: The object to save
    indent: Number of spaces to use as indentation.

  Returns:
    None
  """
  with gfile.GFile(str(path), "w") as fout:
    fout.write(json.dumps(obj, indent=indent))
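
A round-trip sketch for the two JSON helpers above; the path is a placeholder and works for local files as well as gs:// URIs.

# Save a small object and read it back (placeholder path).
config = {'learning_rate': 0.001, 'batch_size': 64}
to_json_file('/tmp/config.json', config)
assert from_json_file('/tmp/config.json') == config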
Example #22
def load_contents(file_path):
  """Load contents from pickle file.

  Args:
    file_path (string): location of file to load.

  Returns:
    contents: contents from pickle file.
  """
  with gfile.GFile(file_path, mode="rb") as f:
    contents = f.read()
    contents = pickle.loads(contents)
  return contents
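
A hedged companion for load_contents: the save_contents helper below is hypothetical (not part of the original API) and writes the pickle in binary mode.

import pickle
from tensorflow.python.platform import gfile

def save_contents(file_path, contents):
  """Hypothetical counterpart to load_contents: pickle and write in binary mode."""
  with gfile.GFile(file_path, mode="wb") as f:
    f.write(pickle.dumps(contents))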
Example #23
def ablation_visualization(
    x1, x2, gen, z_dim, basedir, global_step, figsize=(20, 20), show=False):
  images = generate(x1, x2, gen, z_dim, 3, 12)
  plt.figure(figsize=figsize)
  plt.imshow(grid(images, 12, 1), cmap='Greys_r', interpolation=None)
  plt.axis('off')
  if show:
    plt.show()

  filename = os.path.join(basedir, 'ablation_{:09d}.png'.format(global_step))
  with gfile.GFile(filename, mode='wb') as f:
    plt.savefig(f, dpi=100, bbox_inches='tight')

  plt.close()
Example #24
 def _generate_portal_input_data(self, date_time, event_time_filter,
                                 start_index, total_item_num,
                                 portal_manifest):
     self.assertEqual(total_item_num % portal_manifest.input_partition_num,
                      0)
     item_step = portal_manifest.input_partition_num
     for partition_id in range(portal_manifest.input_partition_num):
         cands = list(range(partition_id, total_item_num, item_step))
         for i in range(len(cands)):
             if random.randint(1, 4) > 1:
                 continue
             a = random.randint(i - 16, i + 16)
             b = random.randint(i - 16, i + 16)
             if a < 0:
                 a = 0
             if a >= len(cands):
                 a = len(cands) - 1
             if b < 0:
                 b = 0
             if b >= len(cands):
                 b = len(cands) - 1
             if abs(cands[a] // item_step -
                    b) <= 16 and abs(cands[b] // item_step - a) <= 16:
                 cands[a], cands[b] = cands[b], cands[a]
         fpath = common.encode_portal_hourly_fpath(
             portal_manifest.input_data_base_dir, date_time, partition_id)
         if not gfile.Exists(os.path.dirname(fpath)):
             gfile.MakeDirs(os.path.dirname(fpath))
         with tf.io.TFRecordWriter(fpath) as writer:
             for lid in cands:
                 real_id = lid + start_index
                 feat = {}
                 example_id = '{}'.format(real_id).encode()
                 feat['example_id'] = tf.train.Feature(
                     bytes_list=tf.train.BytesList(value=[example_id]))
                  # only attach event_time to some records, so the basic
                  # example_validator's handling of invalid event time is exercised
                 if real_id == 0 or not event_time_filter(real_id):
                     event_time = 150000000 + real_id
                     feat['event_time'] = tf.train.Feature(
                         int64_list=tf.train.Int64List(value=[event_time]))
                 example = tf.train.Example(features=tf.train.Features(
                     feature=feat))
                 writer.write(example.SerializeToString())
     succ_tag_fpath = common.encode_portal_hourly_finish_tag(
         portal_manifest.input_data_base_dir, date_time)
     with gfile.GFile(succ_tag_fpath, 'w') as fh:
         fh.write('')
Example #25
def read_dataframe_from_hdf5(path, key='data'):
    """Read a DataFrame from the given HDF5 file.

  Args:
    path: string path where the DataFrame is saved.
    key: optional string name for the DataFrame in the HDF5 file.

  Returns:
    pandas.DataFrame loaded from the HDF5 file.
  """
    with gfile.GFile(path, 'rb') as f:
        with pandas.HDFStore('in_memory',
                             mode='r',
                             driver='H5FD_CORE',
                             driver_core_backing_store=0,
                             driver_core_image=f.read()) as store:
            return store[key]
Example #26
 def finish(self):
     self._csv_dict_writer.close()
     if self._csv_dict_writer.write_raw_num() == 0:
         logging.warning("no record in sort run merger %s at" \
                         "partition %d. reomve the tmp file %s" \
                         "create finish tag", self._fpath,
                         self._partition_id, self._tmp_fpath)
         gfile.Remove(self._tmp_fpath)
         finish_tag_fpath = os.path.join(self._get_output_dir(),
                                         '_SUCCESS')
         with gfile.GFile(finish_tag_fpath, 'w') as fh:
             fh.write('')
     else:
         gfile.Rename(self._tmp_fpath, self._fpath, True)
         logging.warning("dump %d record in sort run merger: "\
                         "%s at partition %d",
                         self._csv_dict_writer.write_raw_num(),
                         self._fpath, self._partition_id)
Example #27
 def _generate_input_data(self):
     self._total_item_num = 1 << 16
     self.assertEqual(
         self._total_item_num % self._portal_manifest.input_partition_num,
         0)
     if gfile.Exists(self._portal_manifest.input_data_base_dir):
         gfile.DeleteRecursively(self._portal_manifest.input_data_base_dir)
     if gfile.Exists(self._portal_manifest.output_data_base_dir):
         gfile.DeleteRecursively(self._portal_manifest.output_data_base_dir)
     hourly_dir = common.encode_portal_hourly_dir(
         self._portal_manifest.input_data_base_dir, self._date_time)
     gfile.MakeDirs(hourly_dir)
     for partition_id in range(self._portal_manifest.input_partition_num):
         self._generate_one_part(partition_id)
     succ_tag_fpath = common.encode_portal_hourly_finish_tag(
         self._portal_manifest.input_data_base_dir, self._date_time)
     with gfile.GFile(succ_tag_fpath, 'w') as fh:
         fh.write('')
Example #28
 def _preprocess_rsa_psi_follower(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     self._follower_rsa_psi_sub_dir = 'follower_rsa_psi_sub_dir'
     rd_publisher = raw_data_publisher.RawDataPublisher(
         self._kvstore_f, self._follower_rsa_psi_sub_dir)
     for partition_id in range(
             self._data_source_f.data_source_meta.partition_num):
         rd_publisher.publish_raw_data(
             partition_id, [self._psi_raw_data_fpaths_f[partition_id]])
         rd_publisher.finish_raw_data(partition_id)
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='follower-rsa-psi-processor',
             role=common_pb.FLRole.Follower,
             rsa_key_pem=rsa_key_pem,
             input_file_subscribe_dir=self._follower_rsa_psi_sub_dir,
             output_file_dir=self._pre_processor_ouput_dir_f,
             raw_data_publish_dir=self._raw_data_pub_dir_f,
             partition_id=partition_id,
             leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             max_flying_sign_rpc=64,
             sign_rpc_timeout_ms=100000,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             sort_run_merger_read_batch_size=128,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14),
             input_raw_data=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                 read_ahead_size=1 << 20),
             writer_options=dj_pb.WriterOptions(output_writer='CSV_DICT'))
         os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self.kvstore_type, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #29
def load_model(model):
    # Check if the model is a model directory (containing a metagraph and a checkpoint file)
    #  or if it is a protobuf file with a frozen graph
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        with gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.compat.v1.train.import_meta_graph(
            os.path.join(model_exp, meta_file))
        saver.restore(tf.compat.v1.get_default_session(),
                      os.path.join(model_exp, ckpt_file))
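
A usage sketch for load_model under a TF1-compat graph and session; the model path is a placeholder (either a frozen .pb file or a checkpoint directory works).

import tensorflow as tf

# Usage sketch: load a frozen graph (or checkpoint dir) inside a default graph/session.
with tf.Graph().as_default():
    with tf.compat.v1.Session():
        load_model('/path/to/frozen_model.pb')
        graph = tf.compat.v1.get_default_graph()
        # Tensors can now be looked up by name, e.g. graph.get_tensor_by_name(...).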
Example #30
def write_dataframe_to_hdf5(df, path, complib='zlib', complevel=5, key='data'):
    """Write a DataFrame to the given path as an HDF5 file.

  Args:
    df: pandas.DataFrame to save.
    path: string path to which to save the DataFrame.
    complib: optional string giving the compression library to use.
    complevel: optional integer giving the desired level of compression.
    key: optional string name for the DataFrame in the HDF5 file.
  """
    if not isinstance(df, pandas.DataFrame):
        raise TypeError('write_dataframe_to_hdf5 input must be a DataFrame.')
    with pandas.HDFStore('in_memory',
                         mode='w',
                         complib=complib,
                         complevel=complevel,
                         driver='H5FD_CORE',
                         driver_core_backing_store=0) as store:
        store[key] = df
        # pylint: disable=protected-access
        buf = store._handle.get_file_image()
        with gfile.GFile(path, 'wb') as f:
            f.write(buf)
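
A round-trip sketch for the two DataFrame helpers (read_dataframe_from_hdf5 appears earlier in this listing); it requires pytables, and the path is a placeholder.

import pandas

# Write a small DataFrame and read it back; gs:// paths work the same way.
df = pandas.DataFrame({'x': [1, 2, 3], 'y': [0.1, 0.2, 0.3]})
write_dataframe_to_hdf5(df, '/tmp/example.h5')
restored = read_dataframe_from_hdf5('/tmp/example.h5')
assert restored.equals(df)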