def main(_):
  """Runs ALBERT classifier training/evaluation or prediction.

  Reads the input metadata JSON, defaults the model directory when none was
  given, builds a distribution strategy from flags, constructs the train and
  eval input functions, loads the ALBERT config, and then dispatches on
  `FLAGS.mode`.

  Args:
    _: Unused positional argument (presumably the argv list supplied by the
      app runner; it is never read here).

  Raises:
    ValueError: If `FLAGS.mode` is neither 'train_and_eval' nor 'predict'.
  """
  # The metadata file is opened in binary mode and decoded explicitly as UTF-8.
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    meta_data = json.loads(meta_file.read().decode('utf-8'))

  # Fall back to a fixed scratch directory only when no model dir was given.
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  seq_len = meta_data['max_seq_length']

  def _input_fn_for(data_path, batch_size, training):
    # Train and eval input fns differ only in path, batch size and mode.
    return run_classifier_bert.get_dataset_fn(
        data_path, seq_len, batch_size, is_training=training)

  train_input_fn = _input_fn_for(
      FLAGS.train_data_path, FLAGS.train_batch_size, True)
  eval_input_fn = _input_fn_for(
      FLAGS.eval_data_path, FLAGS.eval_batch_size, False)

  albert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)

  mode = FLAGS.mode
  if mode == 'train_and_eval':
    run_classifier_bert.run_bert(strategy, meta_data, albert_config,
                                 train_input_fn, eval_input_fn)
  elif mode == 'predict':
    # Prediction reuses the eval input pipeline.
    predict(strategy, albert_config, meta_data, eval_input_fn)
  else:
    raise ValueError('Unsupported mode is specified: %s' % mode)
def main(_):
  """Trains and evaluates an ALBERT classifier.

  Verifies that TensorFlow 2.x is installed, then:
  - reads the input metadata JSON;
  - defaults the model directory when none was given;
  - builds a distribution strategy from flags;
  - constructs the train and eval input functions;
  - loads the ALBERT config;
  - hands everything to `run_classifier_bert.run_bert`.

  Args:
    _: Unused positional argument (presumably the argv list supplied by the
      app runner; it is never read here).

  Raises:
    RuntimeError: If the installed TensorFlow version does not start with
      '2.'.
  """
  # Users should always run this script under TF 2.x. This is checked
  # explicitly instead of with `assert`, because asserts are stripped when
  # Python runs with -O, which would silently skip the check.
  tf_version = tf.version.VERSION
  if not tf_version.startswith('2.'):
    raise RuntimeError(
        'This script requires TensorFlow 2.x, but found version %s.' %
        tf_version)

  # The metadata file is opened in binary mode and decoded explicitly as UTF-8.
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  # Fall back to a fixed scratch directory only when no model dir was given.
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  max_seq_length = input_meta_data['max_seq_length']
  train_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)
  eval_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)

  albert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)

  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                               train_input_fn, eval_input_fn)