def main(_): """ Builds the model and runs. """ tf.logging.set_verbosity(tf.logging.INFO) tx.utils.maybe_create_dir(FLAGS.output_dir) bert_pretrain_dir = 'bert_pretrained_models/%s' % FLAGS.config_bert_pretrain # Loads BERT model configuration if FLAGS.config_format_bert == "json": bert_config = model_utils.transform_bert_to_texar_config( os.path.join(bert_pretrain_dir, 'bert_config.json')) elif FLAGS.config_format_bert == 'texar': bert_config = importlib.import_module( 'bert_config_lib.config_model_%s' % FLAGS.config_bert_pretrain) else: raise ValueError('Unknown config_format_bert.') # Loads data processors = { "cola": data_utils.ColaProcessor, "mnli": data_utils.MnliProcessor, "mrpc": data_utils.MrpcProcessor, "xnli": data_utils.XnliProcessor, 'sst': data_utils.SSTProcessor } processor = processors[FLAGS.task.lower()]() num_classes = len(processor.get_labels()) num_train_data = len(processor.get_train_examples(config_data.data_dir)) tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join( bert_pretrain_dir, 'vocab.txt'), do_lower_case=FLAGS.do_lower_case) train_dataset = data_utils.get_dataset(processor, tokenizer, config_data.data_dir, config_data.max_seq_length, config_data.train_batch_size, mode='train', output_dir=FLAGS.output_dir) eval_dataset = data_utils.get_dataset(processor, tokenizer, config_data.data_dir, config_data.max_seq_length, config_data.eval_batch_size, mode='eval', output_dir=FLAGS.output_dir) test_dataset = data_utils.get_dataset(processor, tokenizer, config_data.data_dir, config_data.max_seq_length, config_data.test_batch_size, mode='test', output_dir=FLAGS.output_dir) iterator = tx.data.FeedableDataIterator({ 'train': train_dataset, 'eval': eval_dataset, 'test': test_dataset }) batch = iterator.get_next() input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] batch_size = tf.shape(input_ids)[0] input_length = tf.reduce_sum(1 - tf.to_int32(tf.equal(input_ids, 0)), axis=1) # Builds BERT with tf.variable_scope('bert'): embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size, hparams=bert_config.embed) word_embeds = embedder(input_ids) # Creates segment embeddings for each type of tokens. segment_embedder = tx.modules.WordEmbedder( vocab_size=bert_config.type_vocab_size, hparams=bert_config.segment_embed) segment_embeds = segment_embedder(segment_ids) input_embeds = word_embeds + segment_embeds # The BERT model (a TransformerEncoder) encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder) output = encoder(input_embeds, input_length) # Builds layers for downstream classification, which is also initialized # with BERT pre-trained checkpoint. 
with tf.variable_scope("pooler"): # Uses the projection of the 1st-step hidden vector of BERT output # as the representation of the sentence bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1) bert_sent_output = tf.layers.dense(bert_sent_hidden, config_downstream.hidden_dim, activation=tf.tanh) output = tf.layers.dropout(bert_sent_output, rate=0.1, training=tx.global_mode_train()) # Adds the final classification layer logits = tf.layers.dense( output, num_classes, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)) preds = tf.argmax(logits, axis=-1, output_type=tf.int32) accu = tx.evals.accuracy(batch['label_ids'], preds) # Optimization loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"], logits=logits) global_step = tf.Variable(0, trainable=False) # Builds learning rate decay scheduler static_lr = config_downstream.lr['static_lr'] num_train_steps = int(num_train_data / config_data.train_batch_size * config_data.max_train_epoch) num_warmup_steps = int(num_train_steps * config_data.warmup_proportion) lr = model_utils.get_lr( global_step, num_train_steps, # lr is a Tensor num_warmup_steps, static_lr) train_op = tx.core.get_train_op(loss, global_step=global_step, learning_rate=lr, hparams=config_downstream.opt) # Train/eval/test routine def _run(sess, mode): fetches = { 'accu': accu, 'batch_size': batch_size, 'step': global_step, 'loss': loss, } if mode == 'train': fetches['train_op'] = train_op while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'train'), tx.global_mode(): tf.estimator.ModeKeys.TRAIN, } rets = sess.run(fetches, feed_dict) if rets['step'] % 50 == 0: tf.logging.info('step:%d loss:%f' % (rets['step'], rets['loss'])) if rets['step'] == num_train_steps: break except tf.errors.OutOfRangeError: break if mode == 'eval': cum_acc = 0.0 nsamples = 0 while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'eval'), tx.context.global_mode(): tf.estimator.ModeKeys.EVAL, } rets = sess.run(fetches, feed_dict) cum_acc += rets['accu'] * rets['batch_size'] nsamples += rets['batch_size'] except tf.errors.OutOfRangeError: break tf.logging.info('dev accu: {}'.format(cum_acc / nsamples)) if mode == 'test': _all_preds = [] while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'test'), tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT, } _preds = sess.run(preds, feed_dict=feed_dict) _all_preds.extend(_preds.tolist()) except tf.errors.OutOfRangeError: break output_file = os.path.join(FLAGS.output_dir, "test_results.tsv") with tf.gfile.GFile(output_file, "w") as writer: writer.write('\n'.join(str(p) for p in _all_preds)) with tf.Session() as sess: # Loads pretrained BERT model parameters init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt') model_utils.init_bert_checkpoint(init_checkpoint) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) # Restores trained model if specified saver = tf.train.Saver() if FLAGS.checkpoint: saver.restore(sess, FLAGS.checkpoint) iterator.initialize_dataset(sess) if FLAGS.do_train: iterator.restart_dataset(sess, 'train') _run(sess, mode='train') saver.save(sess, FLAGS.output_dir + '/model.ckpt') if FLAGS.do_eval: iterator.restart_dataset(sess, 'eval') _run(sess, mode='eval') if FLAGS.do_test: iterator.restart_dataset(sess, 'test') _run(sess, mode='test')
def main(): """ Builds the model and runs. """ tx.utils.maybe_create_dir(args.output_dir) # Loads data num_train_data = config_data.num_train_data # Builds BERT bert_pretrain_dir = f'bert_pretrained_models/{args.config_bert_pretrain}' if args.config_format_bert == "json": bert_config = model_utils.transform_bert_to_texar_config( os.path.join(bert_pretrain_dir, 'bert_config.json')) elif args.config_format_bert == 'texar': bert_config = importlib.import_module( f'bert_config_lib.config_model_{args.config_bert_pretrain}') else: raise ValueError('Unknown config_format_bert.') bert_hparams = BertClassifier.default_hparams() for key in bert_config.keys(): bert_hparams[key] = bert_config[key] for key in config_downstream.keys(): bert_hparams[key] = config_downstream[key] model = BertClassifier(hparams=bert_hparams) init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt') model_utils.init_bert_checkpoint(model, init_checkpoint) if torch.cuda.is_available(): model = model.cuda() print(f"Pretrained model loaded from {init_checkpoint}") # Builds learning rate decay scheduler static_lr = 2e-5 num_train_steps = int(num_train_data / config_data.train_batch_size * config_data.max_train_epoch) num_warmup_steps = int(num_train_steps * config_data.warmup_proportion) vars_with_decay = [] vars_without_decay = [] for name, param in model.named_parameters(): if 'layer_norm' in name or name.endswith('bias'): vars_without_decay.append(param) else: vars_with_decay.append(param) opt_params = [{ 'params': vars_with_decay, 'weight_decay': 0.01, }, { 'params': vars_without_decay, 'weight_decay': 0.0, }] optim = BertAdam(opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr) scheduler = torch.optim.lr_scheduler.LambdaLR( optim, functools.partial(model_utils.get_lr_multiplier, total_steps=num_train_steps, warmup_steps=num_warmup_steps)) train_dataset = tx.data.RecordData(hparams=config_data.train_hparam) eval_dataset = tx.data.RecordData(hparams=config_data.eval_hparam) test_dataset = tx.data.RecordData(hparams=config_data.test_hparam) iterator = tx.data.DataIterator({ "train": train_dataset, "eval": eval_dataset, "test": test_dataset }) def _train_epoch(): """Trains on the training set, and evaluates on the dev set periodically. """ iterator.switch_to_dataset("train") model.train() for batch in iterator: optim.zero_grad() input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] labels = batch["label_ids"] if torch.cuda.is_available(): input_ids = input_ids.cuda() segment_ids = segment_ids.cuda() labels = labels.cuda() input_length = (1 - (input_ids == 0).int()).sum(dim=1) _, _, loss = model(inputs=input_ids, sequence_length=input_length, segment_ids=segment_ids, labels=labels) loss.backward() optim.step() scheduler.step() step = scheduler.last_epoch dis_steps = config_data.display_steps if dis_steps > 0 and step % dis_steps == 0: logging.info("step: %d; loss: %d", step, loss) eval_steps = config_data.eval_steps if eval_steps > 0 and step % eval_steps == 0: _eval_epoch() @torch.no_grad() def _eval_epoch(): """Evaluates on the dev set. 
""" iterator.switch_to_dataset("eval") model.eval() cum_acc = 0.0 cum_loss = 0.0 nsamples = 0 for batch in iterator: input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] labels = batch["label_ids"] if torch.cuda.is_available(): input_ids = input_ids.cuda() segment_ids = segment_ids.cuda() labels = labels.cuda() batch_size = input_ids.size()[0] input_length = (1 - (input_ids == 0).int()).sum(dim=1) _, preds, loss = model( inputs=input_ids, sequence_length=input_length, segment_ids=segment_ids, labels=labels, ) accu = tx.evals.accuracy(labels, preds) cum_acc += accu * batch_size cum_loss += loss * batch_size nsamples += batch_size logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d", cum_acc / nsamples, cum_loss / nsamples, nsamples) @torch.no_grad() def _test_epoch(): """Does predictions on the test set. """ iterator.switch_to_dataset("test") model.eval() _all_preds = [] for batch in iterator: input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] if torch.cuda.is_available(): input_ids = input_ids.cuda() segment_ids = segment_ids.cuda() input_length = (1 - (input_ids == 0).int()).sum(dim=1) _, preds, _ = model( inputs=input_ids, sequence_length=input_length, segment_ids=segment_ids, ) _all_preds.extend(preds.tolist()) output_file = os.path.join(args.output_dir, "test_results.tsv") with open(output_file, "w+") as writer: writer.write("\n".join(str(p) for p in _all_preds)) if args.checkpoint: ckpt = torch.load(args.checkpoint) model.load_state_dict(ckpt['model']) optim.load_state_dict(ckpt['optimizer']) scheduler.load_state_dict(ckpt['scheduler']) if args.do_train: for _ in range(config_data.max_train_epoch): _train_epoch() states = { 'model': model.state_dict(), 'optimizer': optim.state_dict(), 'scheduler': scheduler.state_dict(), } torch.save(states, os.path.join(args.output_dir + '/model.ckpt')) if args.do_eval: _eval_epoch() if args.do_test: _test_epoch()
def main(_): """ Builds the model and runs. """ if FLAGS.distributed: import horovod.tensorflow as hvd hvd.init() tf.logging.set_verbosity(tf.logging.INFO) tx.utils.maybe_create_dir(FLAGS.output_dir) bert_pretrain_dir = ('bert_pretrained_models' '/%s') % FLAGS.config_bert_pretrain # Loads BERT model configuration if FLAGS.config_format_bert == "json": bert_config = model_utils.transform_bert_to_texar_config( os.path.join(bert_pretrain_dir, 'bert_config.json')) elif FLAGS.config_format_bert == 'texar': bert_config = importlib.import_module( ('bert_config_lib.' 'config_model_%s') % FLAGS.config_bert_pretrain) else: raise ValueError('Unknown config_format_bert.') # Loads data num_classes = config_data.num_classes num_train_data = config_data.num_train_data # Configures distribued mode if FLAGS.distributed: config_data.train_hparam["dataset"]["num_shards"] = hvd.size() config_data.train_hparam["dataset"]["shard_id"] = hvd.rank() config_data.train_hparam["batch_size"] //= hvd.size() train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam) eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam) test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam) iterator = tx.data.FeedableDataIterator({ 'train': train_dataset, 'eval': eval_dataset, 'test': test_dataset }) batch = iterator.get_next() input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] batch_size = tf.shape(input_ids)[0] input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32), axis=1) # Builds BERT with tf.variable_scope('bert'): # Word embedding embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size, hparams=bert_config.embed) word_embeds = embedder(input_ids) # Segment embedding for each type of tokens segment_embedder = tx.modules.WordEmbedder( vocab_size=bert_config.type_vocab_size, hparams=bert_config.segment_embed) segment_embeds = segment_embedder(segment_ids) # Position embedding position_embedder = tx.modules.PositionEmbedder( position_size=bert_config.position_size, hparams=bert_config.position_embed) seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1] pos_embeds = position_embedder(sequence_length=seq_length) # Aggregates embeddings input_embeds = word_embeds + segment_embeds + pos_embeds # The BERT model (a TransformerEncoder) encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder) output = encoder(input_embeds, input_length) # Builds layers for downstream classification, which is also # initialized with BERT pre-trained checkpoint. 
with tf.variable_scope("pooler"): # Uses the projection of the 1st-step hidden vector of BERT output # as the representation of the sentence bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1) bert_sent_output = tf.layers.dense(bert_sent_hidden, config_downstream.hidden_dim, activation=tf.tanh) output = tf.layers.dropout(bert_sent_output, rate=0.1, training=tx.global_mode_train()) # Adds the final classification layer logits = tf.layers.dense( output, num_classes, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)) preds = tf.argmax(logits, axis=-1, output_type=tf.int32) accu = tx.evals.accuracy(batch['label_ids'], preds) # Optimization loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"], logits=logits) global_step = tf.Variable(0, trainable=False) # Builds learning rate decay scheduler static_lr = config_downstream.lr['static_lr'] num_train_steps = int(num_train_data / config_data.train_batch_size * config_data.max_train_epoch) num_warmup_steps = int(num_train_steps * config_data.warmup_proportion) lr = model_utils.get_lr( global_step, num_train_steps, # lr is a Tensor num_warmup_steps, static_lr) opt = tx.core.get_optimizer(global_step=global_step, learning_rate=lr, hparams=config_downstream.opt) if FLAGS.distributed: opt = hvd.DistributedOptimizer(opt) train_op = tf.contrib.layers.optimize_loss(loss=loss, global_step=global_step, learning_rate=None, optimizer=opt) # Train/eval/test routine def _is_head(): if not FLAGS.distributed: return True else: return hvd.rank() == 0 def _train_epoch(sess): """Trains on the training set, and evaluates on the dev set periodically. """ iterator.restart_dataset(sess, 'train') fetches = { 'train_op': train_op, 'loss': loss, 'batch_size': batch_size, 'step': global_step } while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'train'), tx.global_mode(): tf.estimator.ModeKeys.TRAIN, } rets = sess.run(fetches, feed_dict) step = rets['step'] dis_steps = config_data.display_steps if _is_head() and dis_steps > 0 and step % dis_steps == 0: tf.logging.info('step:%d; loss:%f' % (step, rets['loss'])) eval_steps = config_data.eval_steps if _is_head() and eval_steps > 0 and step % eval_steps == 0: _eval_epoch(sess) except tf.errors.OutOfRangeError: break def _eval_epoch(sess): """Evaluates on the dev set. """ iterator.restart_dataset(sess, 'eval') cum_acc = 0.0 cum_loss = 0.0 nsamples = 0 fetches = { 'accu': accu, 'loss': loss, 'batch_size': batch_size, } while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'eval'), tx.context.global_mode(): tf.estimator.ModeKeys.EVAL, } rets = sess.run(fetches, feed_dict) cum_acc += rets['accu'] * rets['batch_size'] cum_loss += rets['loss'] * rets['batch_size'] nsamples += rets['batch_size'] except tf.errors.OutOfRangeError: break tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format( cum_acc / nsamples, cum_loss / nsamples, nsamples)) def _test_epoch(sess): """Does predictions on the test set. 
""" iterator.restart_dataset(sess, 'test') _all_preds = [] while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'test'), tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT, } _preds = sess.run(preds, feed_dict=feed_dict) _all_preds.extend(_preds.tolist()) except tf.errors.OutOfRangeError: break output_file = os.path.join(FLAGS.output_dir, "test_results.tsv") with tf.gfile.GFile(output_file, "w") as writer: writer.write('\n'.join(str(p) for p in _all_preds)) # Loads pretrained BERT model parameters init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt') model_utils.init_bert_checkpoint(init_checkpoint) # Broadcasts global variables from rank-0 process if FLAGS.distributed: bcast = hvd.broadcast_global_variables(0) session_config = tf.ConfigProto() if FLAGS.distributed: session_config.gpu_options.visible_device_list = str(hvd.local_rank()) with tf.Session(config=session_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) if FLAGS.distributed: bcast.run() # Restores trained model if specified saver = tf.train.Saver() if FLAGS.checkpoint: saver.restore(sess, FLAGS.checkpoint) iterator.initialize_dataset(sess) if FLAGS.do_train: for i in range(config_data.max_train_epoch): _train_epoch(sess) saver.save(sess, FLAGS.output_dir + '/model.ckpt') if FLAGS.do_eval: _eval_epoch(sess) if FLAGS.do_test: _test_epoch(sess)