Example #1
# Assumed TF 1.x-era imports for this snippet; FLAGS and write_imagenet
# are expected to be defined elsewhere in the surrounding module.
import tensorflow as tf
from tensorflow.contrib.cluster_resolver import TPUClusterResolver


def create_labels(input_tfrecord_path,
                  output_tfrecord_path,
                  dataset_preprocess_fn,
                  embedding_fn,
                  label_fn,
                  write_fn=None,
                  batch_size=64,
                  parallel_calls=1):
    """Creates a new set of labels for a single chunk.

  Args:
    input_tfrecord_path: String with input TF Record file.
    output_tfrecord_path: String with input TF Record file.
    dataset_preprocess_fn: Preprocessing function applied to dataset.
    embedding_fn: Embedding function applied to the dataset tensor.
    label_fn: Label function applied to the (after sess.run).
    write_fn: Function to write TF Record to TF Record writer.
    batch_size: Optional integer with batch_size.
  """
    tf.logging.info("Input: {}\nOutput: {}".format(input_tfrecord_path,
                                                   output_tfrecord_path))
    if write_fn is None:
        write_fn = write_imagenet

    if FLAGS.tpu_name:
        cluster = TPUClusterResolver(tpu=[FLAGS.tpu_name])
    else:
        cluster = None
    config = tf.contrib.tpu.RunConfig(cluster=cluster)

    # Load the data in the chunk.
    input_dataset = tf.data.TFRecordDataset(input_tfrecord_path)
    input_dataset = input_dataset.map(dataset_preprocess_fn, parallel_calls)
    input_dataset = input_dataset.batch(batch_size)
    next_node = input_dataset.make_one_shot_iterator().get_next()
    embedding = embedding_fn(next_node)
    # Guard against a missing TPU: fall back to a local master.
    master = cluster.get_master() if cluster else ""
    with tf.Session(master, config=config.session_config) as sess:
        with tf.python_io.TFRecordWriter(output_tfrecord_path) as writer:
            sess.run(tf.global_variables_initializer())
            while True:
                try:
                    embedded = sess.run(embedding)
                    results = label_fn(embedded)
                    write_fn(writer, results)
                except tf.errors.OutOfRangeError:
                    break
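
A minimal usage sketch for create_labels. Everything below is illustrative: the
gs://bucket paths, the helper functions (parse_example, embed_batch,
label_batch, write_labels), and the assumed tf.Example schema are stand-ins,
not part of the original snippet; it also assumes FLAGS.tpu_name is defined by
the surrounding module (left unset here so the session runs locally).

def parse_example(serialized):
    # Decode a serialized tf.Example into a fixed-size float image
    # (hypothetical schema with a single JPEG-encoded "image" field).
    features = tf.parse_single_example(
        serialized, {"image": tf.FixedLenFeature([], tf.string)})
    image = tf.image.decode_jpeg(features["image"], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    return tf.image.resize_images(image, [224, 224])

def embed_batch(images):
    # Stand-in embedding: mean over spatial dims -> [batch, channels].
    return tf.reduce_mean(images, axis=[1, 2])

def label_batch(embeddings):
    # Stand-in labeling on the NumPy side: argmax over channels.
    return embeddings.argmax(axis=1)

def write_labels(writer, labels):
    # Stand-in writer: serialize one tiny tf.Example per label.
    for label in labels:
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(label)]))}))
        writer.write(example.SerializeToString())

create_labels("gs://bucket/input.tfrecord",
              "gs://bucket/output.tfrecord",
              dataset_preprocess_fn=parse_example,
              embedding_fn=embed_batch,
              label_fn=label_batch,
              write_fn=write_labels)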
Example #2
# Assumed imports for this snippet (TF 1.x-era APIs).
import logging

import tensorflow as tf
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]
log.info("Using TPU runtime")
USE_TPU = True
tpu_cluster_resolver = TPUClusterResolver(tpu='greek-bert',
                                          zone='us-central1-a')

# SETUP FOLDERS
# The session here is opened only to verify the TPU master is reachable;
# the folder parameters do not depend on it.
with tf.Session(tpu_cluster_resolver.get_master()) as session:
    print(tpu_cluster_resolver.get_master())
    HOME_PATH = "gs://greek_bert"  # @param {type:"string"}
    MODEL_DIR = "greek_bert"  # @param {type:"string"}
    PRETRAINING_DIR = "greek_tfrecords"  # @param {type:"string"}
    VOC_FNAME = "vocab.txt"  # @param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 256  # @param {type:"integer"}
MAX_PREDICTIONS = 75  # @param {type:"integer"}
MAX_SEQ_LENGTH = 512  # @param {type:"integer"}
MASKED_LM_PROB = 0.15  # @param

# Training procedure config
EVAL_BATCH_SIZE = 256
LEARNING_RATE = 1e-4
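
A hedged sketch of how the constants above might be wired into a TF 1.x
TPUEstimator run, in the style of BERT's pretraining scripts. The model_fn,
iterations_per_loop, and num_shards values are placeholders/assumptions, not
part of the original snippet.

# Hypothetical wiring of the config constants into a TPUEstimator.
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir="{}/{}".format(HOME_PATH, MODEL_DIR),
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,  # assumed value
        num_shards=8))             # assumed value

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,  # BERT pretraining model_fn, defined elsewhere
    config=run_config,
    use_tpu=USE_TPU,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)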