def _run_bert_classifier(self, callbacks=None, use_ds=True):
  """Runs the BERT classification task configured through FLAGS.

  Loads the input metadata JSON and the BERT config named by the flags,
  derives the step counts, builds a distribution strategy and hands
  everything to `run_classifier.run_bert_classifier`.

  Args:
    callbacks: Forwarded to `run_classifier.run_bert_classifier` as
      `custom_callbacks`.
    use_ds: When truthy the 'mirrored' distribution strategy is requested,
      otherwise 'off'.
  """
  # The metadata file is opened in binary mode and decoded explicitly as
  # UTF-8 before being parsed as JSON.
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    meta = json.loads(meta_file.read().decode('utf-8'))

  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Values set on the instance take precedence over flags / metadata.
  num_epochs = self.num_epochs or FLAGS.num_train_epochs
  train_steps = self.num_steps_per_epoch
  if not train_steps:
    train_steps = int(meta['train_data_size'] / FLAGS.train_batch_size)

  # Warm-up covers one tenth of the total number of training steps.
  num_warmup_steps = int(num_epochs * train_steps * 0.1)
  num_eval_steps = int(
      math.ceil(meta['eval_data_size'] / FLAGS.eval_batch_size))

  if use_ds:
    strategy_name = 'mirrored'
  else:
    strategy_name = 'off'
  dist_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=strategy_name, num_gpus=self.num_gpus)

  # This variant always uses a single step per loop.
  loop_steps = 1
  run_classifier.run_bert_classifier(
      dist_strategy,
      config,
      meta,
      FLAGS.model_dir,
      num_epochs,
      train_steps,
      loop_steps,
      num_eval_steps,
      num_warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      custom_callbacks=callbacks)
def _run_bert_classifier(self, callbacks=None, use_ds=True):
  """Runs the BERT classification task configured through FLAGS.

  Loads the input metadata JSON and the BERT config named by the flags,
  derives the step counts, builds a distribution strategy (TPU when
  `self.tpu` is set) plus the train/eval dataset functions, and hands
  everything to `run_classifier.run_bert_classifier`.

  Args:
    callbacks: Forwarded to `run_classifier.run_bert_classifier` as
      `custom_callbacks`.
    use_ds: Only consulted when `self.tpu` is falsy. When truthy the
      'mirrored' distribution strategy is requested, otherwise 'off'.

  Returns:
    The second element of the pair returned by
    `run_classifier.run_bert_classifier`.
  """
  # The metadata file is opened in binary mode and decoded explicitly as
  # UTF-8 before being parsed as JSON.
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    meta = json.loads(meta_file.read().decode('utf-8'))

  config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Values set on the instance take precedence over flags / metadata.
  num_epochs = self.num_epochs or FLAGS.num_train_epochs
  train_steps = self.num_steps_per_epoch
  if not train_steps:
    train_steps = int(meta['train_data_size'] / FLAGS.train_batch_size)

  # Warm-up covers one tenth of the total number of training steps.
  num_warmup_steps = int(num_epochs * train_steps * 0.1)
  num_eval_steps = int(
      math.ceil(meta['eval_data_size'] / FLAGS.eval_batch_size))

  # A configured TPU address wins over the GPU / no-strategy choice.
  if self.tpu:
    strategy_kwargs = {
        'distribution_strategy': 'tpu',
        'tpu_address': self.tpu,
    }
  else:
    strategy_kwargs = {
        'distribution_strategy': 'mirrored' if use_ds else 'off',
        'num_gpus': self.num_gpus,
    }
  dist_strategy = distribution_utils.get_distribution_strategy(
      **strategy_kwargs)

  seq_len = meta['max_seq_length']
  train_fn = run_classifier.get_dataset_fn(
      FLAGS.train_data_path,
      seq_len,
      FLAGS.train_batch_size,
      is_training=True)
  eval_fn = run_classifier.get_dataset_fn(
      FLAGS.eval_data_path,
      seq_len,
      FLAGS.eval_batch_size,
      is_training=False)

  # Only the summary half of the returned pair is handed back to the caller.
  _, run_summary = run_classifier.run_bert_classifier(
      dist_strategy,
      config,
      meta,
      FLAGS.model_dir,
      num_epochs,
      train_steps,
      FLAGS.steps_per_loop,
      num_eval_steps,
      num_warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      train_fn,
      eval_fn,
      training_callbacks=False,
      custom_callbacks=callbacks)
  return run_summary