Example #1
def main(_):
    """
    Builds the model and runs.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tx.utils.maybe_create_dir(FLAGS.output_dir)
    bert_pretrain_dir = 'bert_pretrained_models/%s' % FLAGS.config_bert_pretrain

    # Loads BERT model configuration
    if FLAGS.config_format_bert == "json":
        bert_config = model_utils.transform_bert_to_texar_config(
            os.path.join(bert_pretrain_dir, 'bert_config.json'))
    elif FLAGS.config_format_bert == 'texar':
        bert_config = importlib.import_module(
            'bert_config_lib.config_model_%s' % FLAGS.config_bert_pretrain)
    else:
        raise ValueError('Unknown config_format_bert.')

    # Loads data
    processors = {
        "cola": data_utils.ColaProcessor,
        "mnli": data_utils.MnliProcessor,
        "mrpc": data_utils.MrpcProcessor,
        "xnli": data_utils.XnliProcessor,
        "sst": data_utils.SSTProcessor,
    }

    processor = processors[FLAGS.task.lower()]()

    num_classes = len(processor.get_labels())
    num_train_data = len(processor.get_train_examples(config_data.data_dir))

    tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_pretrain_dir, 'vocab.txt'),
        do_lower_case=FLAGS.do_lower_case)

    train_dataset = data_utils.get_dataset(processor,
                                           tokenizer,
                                           config_data.data_dir,
                                           config_data.max_seq_length,
                                           config_data.train_batch_size,
                                           mode='train',
                                           output_dir=FLAGS.output_dir)
    eval_dataset = data_utils.get_dataset(processor,
                                          tokenizer,
                                          config_data.data_dir,
                                          config_data.max_seq_length,
                                          config_data.eval_batch_size,
                                          mode='eval',
                                          output_dir=FLAGS.output_dir)
    test_dataset = data_utils.get_dataset(processor,
                                          tokenizer,
                                          config_data.data_dir,
                                          config_data.max_seq_length,
                                          config_data.test_batch_size,
                                          mode='test',
                                          output_dir=FLAGS.output_dir)

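    # A feedable iterator switches among the train/eval/test splits within
    # one graph by feeding a string handle (see iterator.handle below).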
    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset,
        'eval': eval_dataset,
        'test': test_dataset
    })
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
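    # Sequence lengths are computed by counting non-padding (non-zero-id) tokens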
    input_length = tf.reduce_sum(1 - tf.to_int32(tf.equal(input_ids, 0)),
                                 axis=1)

    # Builds BERT
    with tf.variable_scope('bert'):
        embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size,
                                           hparams=bert_config.embed)
        word_embeds = embedder(input_ids)

        # Creates segment embeddings for each token type.
        segment_embedder = tx.modules.WordEmbedder(
            vocab_size=bert_config.type_vocab_size,
            hparams=bert_config.segment_embed)
        segment_embeds = segment_embedder(segment_ids)

        input_embeds = word_embeds + segment_embeds

        # The BERT model (a TransformerEncoder)
        encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
        output = encoder(input_embeds, input_length)

        # Builds layers for downstream classification, which are also
        # initialized with the BERT pre-trained checkpoint.
        with tf.variable_scope("pooler"):
            # Uses the projection of the hidden state of the first token
            # ([CLS]) of the BERT output as the sentence representation
            bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1)
            bert_sent_output = tf.layers.dense(bert_sent_hidden,
                                               config_downstream.hidden_dim,
                                               activation=tf.tanh)
            output = tf.layers.dropout(bert_sent_output,
                                       rate=0.1,
                                       training=tx.global_mode_train())

    # Adds the final classification layer
    logits = tf.layers.dense(
        output,
        num_classes,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
    preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization

    loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"],
                                                  logits=logits)
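    # global_step is incremented by the train op and drives the LR schedule below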
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler
    static_lr = config_downstream.lr['static_lr']
    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(  # lr is a Tensor
        global_step,
        num_train_steps,
        num_warmup_steps,
        static_lr)

    train_op = tx.core.get_train_op(loss,
                                    global_step=global_step,
                                    learning_rate=lr,
                                    hparams=config_downstream.opt)

    # Train/eval/test routine

    def _run(sess, mode):
        fetches = {
            'accu': accu,
            'batch_size': batch_size,
            'step': global_step,
            'loss': loss,
        }

        if mode == 'train':
            fetches['train_op'] = train_op
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'train'),
                        tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                    }
                    rets = sess.run(fetches, feed_dict)
                    if rets['step'] % 50 == 0:
                        tf.logging.info('step:%d loss:%f' %
                                        (rets['step'], rets['loss']))
                    if rets['step'] == num_train_steps:
                        break
                except tf.errors.OutOfRangeError:
                    break

        if mode == 'eval':
            cum_acc = 0.0
            nsamples = 0
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'eval'),
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                    }
                    rets = sess.run(fetches, feed_dict)

                    cum_acc += rets['accu'] * rets['batch_size']
                    nsamples += rets['batch_size']
                except tf.errors.OutOfRangeError:
                    break

            tf.logging.info('dev accu: {}'.format(cum_acc / nsamples))

        if mode == 'test':
            _all_preds = []
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, 'test'),
                        tx.context.global_mode():
                        tf.estimator.ModeKeys.PREDICT,
                    }
                    _preds = sess.run(preds, feed_dict=feed_dict)
                    _all_preds.extend(_preds.tolist())
                except tf.errors.OutOfRangeError:
                    break

            output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
            with tf.gfile.GFile(output_file, "w") as writer:
                writer.write('\n'.join(str(p) for p in _all_preds))

    with tf.Session() as sess:
        # Loads pretrained BERT model parameters
        init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
        model_utils.init_bert_checkpoint(init_checkpoint)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            iterator.restart_dataset(sess, 'train')
            _run(sess, mode='train')
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            iterator.restart_dataset(sess, 'eval')
            _run(sess, mode='eval')

        if FLAGS.do_test:
            iterator.restart_dataset(sess, 'test')
            _run(sess, mode='test')
Example #2
def main():
    """
    Builds the model and runs.
    """

    tx.utils.maybe_create_dir(args.output_dir)

    # Loads data
    num_train_data = config_data.num_train_data

    # Builds BERT
    bert_pretrain_dir = f'bert_pretrained_models/{args.config_bert_pretrain}'
    if args.config_format_bert == "json":
        bert_config = model_utils.transform_bert_to_texar_config(
            os.path.join(bert_pretrain_dir, 'bert_config.json'))
    elif args.config_format_bert == 'texar':
        bert_config = importlib.import_module(
            f'bert_config_lib.config_model_{args.config_bert_pretrain}')
    else:
        raise ValueError('Unknown config_format_bert.')

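    # Overlay the pre-trained BERT config and the downstream-task config onto
    # the classifier's default hparams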
    bert_hparams = BertClassifier.default_hparams()
    for key in bert_config.keys():
        bert_hparams[key] = bert_config[key]
    for key in config_downstream.keys():
        bert_hparams[key] = config_downstream[key]

    model = BertClassifier(hparams=bert_hparams)
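    # Initializes the classifier with the pre-trained BERT checkpoint weights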
    init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
    model_utils.init_bert_checkpoint(model, init_checkpoint)
    if torch.cuda.is_available():
        model = model.cuda()
    print(f"Pretrained model loaded from {init_checkpoint}")

    # Builds learning rate decay scheduler
    static_lr = 2e-5

    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)

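    # Exclude LayerNorm parameters and biases from weight decay, as in the
    # original BERT training setup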
    vars_with_decay = []
    vars_without_decay = []
    for name, param in model.named_parameters():
        if 'layer_norm' in name or name.endswith('bias'):
            vars_without_decay.append(param)
        else:
            vars_with_decay.append(param)

    opt_params = [{
        'params': vars_with_decay,
        'weight_decay': 0.01,
    }, {
        'params': vars_without_decay,
        'weight_decay': 0.0,
    }]
    optim = BertAdam(opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr)

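    # LR multiplier schedule: linear warmup over num_warmup_steps, then decay
    # for the remainder of training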
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim,
        functools.partial(model_utils.get_lr_multiplier,
                          total_steps=num_train_steps,
                          warmup_steps=num_warmup_steps))

    train_dataset = tx.data.RecordData(hparams=config_data.train_hparam)
    eval_dataset = tx.data.RecordData(hparams=config_data.eval_hparam)
    test_dataset = tx.data.RecordData(hparams=config_data.test_hparam)

    iterator = tx.data.DataIterator({
        "train": train_dataset,
        "eval": eval_dataset,
        "test": test_dataset
    })

    def _train_epoch():
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.switch_to_dataset("train")

        model.train()
        for batch in iterator:
            optim.zero_grad()
            input_ids = batch["input_ids"]
            segment_ids = batch["segment_ids"]
            labels = batch["label_ids"]

            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                segment_ids = segment_ids.cuda()
                labels = labels.cuda()

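            # Count non-padding (non-zero-id) tokens to get each example's length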
            input_length = (1 - (input_ids == 0).int()).sum(dim=1)

            _, _, loss = model(inputs=input_ids,
                               sequence_length=input_length,
                               segment_ids=segment_ids,
                               labels=labels)
            loss.backward()
            optim.step()
            scheduler.step()
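            # LambdaLR.last_epoch counts scheduler.step() calls, i.e. training steps here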
            step = scheduler.last_epoch

            dis_steps = config_data.display_steps
            if dis_steps > 0 and step % dis_steps == 0:
                logging.info("step: %d; loss: %d", step, loss)

            eval_steps = config_data.eval_steps
            if eval_steps > 0 and step % eval_steps == 0:
                _eval_epoch()

    @torch.no_grad()
    def _eval_epoch():
        """Evaluates on the dev set.
        """
        iterator.switch_to_dataset("eval")
        model.eval()
        cum_acc = 0.0
        cum_loss = 0.0
        nsamples = 0
        for batch in iterator:
            input_ids = batch["input_ids"]
            segment_ids = batch["segment_ids"]
            labels = batch["label_ids"]

            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                segment_ids = segment_ids.cuda()
                labels = labels.cuda()

            batch_size = input_ids.size()[0]
            input_length = (1 - (input_ids == 0).int()).sum(dim=1)

            _, preds, loss = model(
                inputs=input_ids,
                sequence_length=input_length,
                segment_ids=segment_ids,
                labels=labels,
            )

            accu = tx.evals.accuracy(labels, preds)
            cum_acc += accu * batch_size
            cum_loss += loss * batch_size
            nsamples += batch_size
        logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d",
                     cum_acc / nsamples, cum_loss / nsamples, nsamples)

    @torch.no_grad()
    def _test_epoch():
        """Does predictions on the test set.
        """
        iterator.switch_to_dataset("test")
        model.eval()
        _all_preds = []
        for batch in iterator:
            input_ids = batch["input_ids"]
            segment_ids = batch["segment_ids"]
            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                segment_ids = segment_ids.cuda()

            input_length = (1 - (input_ids == 0).int()).sum(dim=1)

            _, preds, _ = model(
                inputs=input_ids,
                sequence_length=input_length,
                segment_ids=segment_ids,
            )
            _all_preds.extend(preds.tolist())

        output_file = os.path.join(args.output_dir, "test_results.tsv")
        with open(output_file, "w+") as writer:
            writer.write("\n".join(str(p) for p in _all_preds))

    if args.checkpoint:
        ckpt = torch.load(args.checkpoint)
        model.load_state_dict(ckpt['model'])
        optim.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])

    if args.do_train:
        for _ in range(config_data.max_train_epoch):
            _train_epoch()
        states = {
            'model': model.state_dict(),
            'optimizer': optim.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(states, os.path.join(args.output_dir, 'model.ckpt'))

    if args.do_eval:
        _eval_epoch()

    if args.do_test:
        _test_epoch()
Example #3
def main(_):
    """
    Builds the model and runs.
    """

    if FLAGS.distributed:
        import horovod.tensorflow as hvd
        hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    tx.utils.maybe_create_dir(FLAGS.output_dir)

    bert_pretrain_dir = ('bert_pretrained_models'
                         '/%s') % FLAGS.config_bert_pretrain
    # Loads BERT model configuration
    if FLAGS.config_format_bert == "json":
        bert_config = model_utils.transform_bert_to_texar_config(
            os.path.join(bert_pretrain_dir, 'bert_config.json'))
    elif FLAGS.config_format_bert == 'texar':
        bert_config = importlib.import_module(
            ('bert_config_lib.'
             'config_model_%s') % FLAGS.config_bert_pretrain)
    else:
        raise ValueError('Unknown config_format_bert.')

    # Loads data

    num_classes = config_data.num_classes
    num_train_data = config_data.num_train_data

    # Configures distributed mode
    if FLAGS.distributed:
        config_data.train_hparam["dataset"]["num_shards"] = hvd.size()
        config_data.train_hparam["dataset"]["shard_id"] = hvd.rank()
        config_data.train_hparam["batch_size"] //= hvd.size()

    train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam)
    eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam)
    test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam)

    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset,
        'eval': eval_dataset,
        'test': test_dataset
    })
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
    input_length = tf.reduce_sum(1 - tf.cast(tf.equal(input_ids, 0), tf.int32),
                                 axis=1)
    # Builds BERT
    with tf.variable_scope('bert'):
        # Word embedding
        embedder = tx.modules.WordEmbedder(vocab_size=bert_config.vocab_size,
                                           hparams=bert_config.embed)
        word_embeds = embedder(input_ids)

        # Segment embedding for each token type
        segment_embedder = tx.modules.WordEmbedder(
            vocab_size=bert_config.type_vocab_size,
            hparams=bert_config.segment_embed)
        segment_embeds = segment_embedder(segment_ids)

        # Position embedding
        position_embedder = tx.modules.PositionEmbedder(
            position_size=bert_config.position_size,
            hparams=bert_config.position_embed)
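        # Position ids run over the full (padded) sequence length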
        seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1]
        pos_embeds = position_embedder(sequence_length=seq_length)

        # Aggregates embeddings
        input_embeds = word_embeds + segment_embeds + pos_embeds

        # The BERT model (a TransformerEncoder)
        encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
        output = encoder(input_embeds, input_length)

        # Builds layers for downstream classification, which are also
        # initialized with the BERT pre-trained checkpoint.
        with tf.variable_scope("pooler"):
            # Uses the projection of the hidden state of the first token
            # ([CLS]) of the BERT output as the sentence representation
            bert_sent_hidden = tf.squeeze(output[:, 0:1, :], axis=1)
            bert_sent_output = tf.layers.dense(bert_sent_hidden,
                                               config_downstream.hidden_dim,
                                               activation=tf.tanh)
            output = tf.layers.dropout(bert_sent_output,
                                       rate=0.1,
                                       training=tx.global_mode_train())

    # Adds the final classification layer
    logits = tf.layers.dense(
        output,
        num_classes,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
    preds = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization

    loss = tf.losses.sparse_softmax_cross_entropy(labels=batch["label_ids"],
                                                  logits=logits)
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler
    static_lr = config_downstream.lr['static_lr']
    num_train_steps = int(num_train_data / config_data.train_batch_size *
                          config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(  # lr is a Tensor
        global_step,
        num_train_steps,
        num_warmup_steps,
        static_lr)

    opt = tx.core.get_optimizer(global_step=global_step,
                                learning_rate=lr,
                                hparams=config_downstream.opt)

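    # Averages gradients across workers before applying updates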
    if FLAGS.distributed:
        opt = hvd.DistributedOptimizer(opt)

    train_op = tf.contrib.layers.optimize_loss(loss=loss,
                                               global_step=global_step,
                                               learning_rate=None,
                                               optimizer=opt)

    # Train/eval/test routine

    def _is_head():
        if not FLAGS.distributed:
            return True
        else:
            return hvd.rank() == 0

    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': loss,
            'batch_size': batch_size,
            'step': global_step
        }

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_data.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f' % (step, rets['loss']))

                eval_steps = config_data.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _eval_epoch(sess)

            except tf.errors.OutOfRangeError:
                break

    def _eval_epoch(sess):
        """Evaluates on the dev set.
        """
        iterator.restart_dataset(sess, 'eval')

        cum_acc = 0.0
        cum_loss = 0.0
        nsamples = 0
        fetches = {
            'accu': accu,
            'loss': loss,
            'batch_size': batch_size,
        }
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                rets = sess.run(fetches, feed_dict)

                cum_acc += rets['accu'] * rets['batch_size']
                cum_loss += rets['loss'] * rets['batch_size']
                nsamples += rets['batch_size']
            except tf.errors.OutOfRangeError:
                break

        tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format(
            cum_acc / nsamples, cum_loss / nsamples, nsamples))

    def _test_epoch(sess):
        """Does predictions on the test set.
        """
        iterator.restart_dataset(sess, 'test')

        _all_preds = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                _preds = sess.run(preds, feed_dict=feed_dict)
                _all_preds.extend(_preds.tolist())
            except tf.errors.OutOfRangeError:
                break

        output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_file, "w") as writer:
            writer.write('\n'.join(str(p) for p in _all_preds))

    # Loads pretrained BERT model parameters
    init_checkpoint = os.path.join(bert_pretrain_dir, 'bert_model.ckpt')
    model_utils.init_bert_checkpoint(init_checkpoint)

    # Broadcasts global variables from rank-0 process
    if FLAGS.distributed:
        bcast = hvd.broadcast_global_variables(0)

    session_config = tf.ConfigProto()
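    # Pin each worker to its own local GPU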
    if FLAGS.distributed:
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        if FLAGS.distributed:
            bcast.run()

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            for i in range(config_data.max_train_epoch):
                _train_epoch(sess)
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            _eval_epoch(sess)

        if FLAGS.do_test:
            _test_epoch(sess)