def main(_):
    """Train/evaluate or run prediction with an ALBERT sentence classifier.

    Reads the input metadata JSON and the ALBERT config file named by FLAGS,
    builds a distribution strategy plus train/eval dataset functions, then
    dispatches on ``FLAGS.mode``:

      * 'train_and_eval' -> ``run_classifier_bert.run_bert``
      * 'predict'        -> ``predict``

    Args:
      _: Unused positional argument (presumably the argv list supplied by the
        absl app runner -- confirm against the caller).

    Raises:
      ValueError: If ``FLAGS.mode`` is neither 'train_and_eval' nor 'predict'.
    """
    # Validate the mode up front so a typo in --mode fails immediately instead
    # of after reading inputs and initializing the distribution strategy.
    if FLAGS.mode not in ('train_and_eval', 'predict'):
        raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    # Default output directory; note this overwrites the global flag value
    # rather than a local copy.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    max_seq_length = input_meta_data['max_seq_length']
    train_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.train_data_path,
                                                        max_seq_length,
                                                        FLAGS.train_batch_size,
                                                        is_training=True)
    eval_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.eval_data_path,
                                                       max_seq_length,
                                                       FLAGS.eval_batch_size,
                                                       is_training=False)

    albert_config = albert_configs.AlbertConfig.from_json_file(
        FLAGS.bert_config_file)
    if FLAGS.mode == 'train_and_eval':
        run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                     train_input_fn, eval_input_fn)
    else:  # 'predict' -- the only other mode accepted by the check above.
        predict(strategy, albert_config, input_meta_data, eval_input_fn)
Example #2
0
def main(_):
    """Train and evaluate an ALBERT sentence classifier.

    Reads the input metadata JSON and the ALBERT config file named by FLAGS,
    builds a distribution strategy plus train/eval dataset functions, and hands
    everything to ``run_classifier_bert.run_bert``.

    Args:
      _: Unused positional argument (presumably the argv list supplied by the
        absl app runner -- confirm against the caller).

    Raises:
      RuntimeError: If the installed TensorFlow is not a 2.x release.
    """
    # Users should always run this script under TF 2.x. An explicit check is
    # used instead of `assert`, which Python removes when run with -O.
    if not tf.version.VERSION.startswith('2.'):
        raise RuntimeError(
            'This script requires TensorFlow 2.x, but found version %s' %
            tf.version.VERSION)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    # Default output directory; note this overwrites the global flag value
    # rather than a local copy.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    max_seq_length = input_meta_data['max_seq_length']
    train_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.train_data_path,
                                                        max_seq_length,
                                                        FLAGS.train_batch_size,
                                                        is_training=True)
    eval_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.eval_data_path,
                                                       max_seq_length,
                                                       FLAGS.eval_batch_size,
                                                       is_training=False)

    albert_config = albert_configs.AlbertConfig.from_json_file(
        FLAGS.bert_config_file)
    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                 train_input_fn, eval_input_fn)