Example 1
def inference(metric):
    """inference function."""

    logging.info('|----- Now we are doing BERT inference at {} !'.format(ctx))
    model = BERTClassifier(bert, dropout=0.1, num_classes=len(task.get_labels()))
    para_name = 'MRPC_valid_best_{}.params'.format(args.max_len)
    model.load_parameters(os.path.join(args.save_dir, para_name), ctx=ctx)

    is_profiler_on = os.getenv('GLUONNLP_BERT_PROFILING', False)
    if is_profiler_on:
        mx.profiler.set_config(profile_symbolic=True, profile_imperative=True, profile_memory=False,
                               profile_api=False, filename='profile.json', aggregate_stats=True)
        mx.profiler.set_state('run')

    for epoch_id in range(1):
        metric.reset()
        step_loss = 0
        tic = time.time()

        for batch_id, seqs in enumerate(dev_data):
            input_ids, valid_length, type_ids, label = seqs
            out = model(
                input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
                valid_length.astype('float32').as_in_context(ctx))
            ls = loss_function(out, label.as_in_context(ctx)).mean()

            step_loss += ls.asscalar()
            metric.update([label], [out])

            if (batch_id + 1) % (args.log_interval) == 0:
                metric_nm, metric_val = metric.get()
                if not isinstance(metric_nm, list):
                    metric_nm = [metric_nm]
                    metric_val = [metric_val]
                eval_str = '[Epoch {} Batch {}/{}] loss={:.4f}, metrics=' + \
                    ','.join([i + ':{:.4f}' for i in metric_nm])
                logging.info(eval_str.format(epoch_id + 1, batch_id + 1, len(dev_data),
                                             step_loss / args.log_interval,
                                             *metric_val))
                step_loss = 0

        mx.nd.waitall()

        toc = time.time()
        logging.info('Time cost={:.1f}s'.format(toc - tic))

    if is_profiler_on:
        mx.profiler.set_state('stop')
        print(mx.profiler.dumps())
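
The profiling branch in Example 1 is gated on the GLUONNLP_BERT_PROFILING environment variable (any non-empty value enables it, since os.getenv returns a string). A small usage sketch; only the variable name comes from the snippet above, the rest is illustrative:

import os

# Setting the variable before calling inference() makes os.getenv return the
# truthy string '1', so the mx.profiler branch above is taken.
os.environ['GLUONNLP_BERT_PROFILING'] = '1'
inference(metric)
# profile.json is written in Chrome trace format and can be inspected via
# chrome://tracing; mx.profiler.dumps() prints the aggregated statistics.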
Example 2
def inference(metric):
    """Inference function."""

    logging.info('Now we are doing BERT classification inference on %s!', ctx)
    model = BERTClassifier(bert,
                           dropout=0.1,
                           num_classes=len(task.get_labels()))
    model.hybridize(static_alloc=True)
    model.load_parameters(model_parameters, ctx=ctx)

    metric.reset()
    step_loss = 0
    tic = time.time()
    for batch_id, seqs in enumerate(dev_data):
        input_ids, valid_length, type_ids, label = seqs
        out = model(input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
                    valid_length.astype('float32').as_in_context(ctx))

        ls = loss_function(out, label.as_in_context(ctx)).mean()

        step_loss += ls.asscalar()
        metric.update([label], [out])

        if (batch_id + 1) % (args.log_interval) == 0:
            log_inference(batch_id, len(dev_data), metric, step_loss,
                          args.log_interval)
            step_loss = 0

    mx.nd.waitall()
    toc = time.time()
    total_num = dev_batch_size * len(dev_data)
    logging.info('Time cost=%.2fs, throughput=%.2f samples/s', toc - tic,
                 total_num / (toc - tic))
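
Example 2 calls a log_inference helper that is not shown in the snippet. A minimal sketch of what it could look like, modelled on the inline logging block from Example 1; the name and signature come from the call site above, the body is an assumption:

import logging

def log_inference(batch_id, batch_num, metric, step_loss, log_interval):
    """Log the average loss and current metric values for one logging window (assumed helper)."""
    metric_nm, metric_val = metric.get()
    if not isinstance(metric_nm, list):
        metric_nm, metric_val = [metric_nm], [metric_val]
    eval_str = ('[Batch {}/{}] loss={:.4f}, metrics=' +
                ','.join([name + ':{:.4f}' for name in metric_nm]))
    logging.info(eval_str.format(batch_id + 1, batch_num,
                                 step_loss / log_interval, *metric_val))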
Example 3
bert, vocabulary = get_bert_model(
    model_name=model_name,
    dataset_name=dataset,
    pretrained=get_pretrained,
    ctx=ctx,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False)

if task.task_name in ['STS-B']:
    model = BERTRegression(bert, dropout=0.1)
    if not model_parameters:
        model.regression.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    loss_function = gluon.loss.L2Loss()
else:
    model = BERTClassifier(
        bert, dropout=0.1, num_classes=len(task.get_labels()))
    if not model_parameters:
        model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    loss_function = gluon.loss.SoftmaxCELoss()

# load checkpoints if provided
output_dir = args.output_dir
if pretrained_bert_parameters:
    logging.info('loading bert params from {0}'.format(pretrained_bert_parameters))
    model.bert.load_parameters(pretrained_bert_parameters, ctx=ctx,
                               ignore_extra=True)
if model_parameters:
    logging.info('loading model params from {0}'.format(model_parameters))
    model.load_parameters(model_parameters, ctx=ctx)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
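
Example 3 expects args.output_dir, model_parameters, and pretrained_bert_parameters to be set by the surrounding script. A minimal argparse sketch of how these inputs might be defined; the flag names here are assumptions, not the project's actual CLI:

import argparse

parser = argparse.ArgumentParser(description='BERT sentence-pair classification')
parser.add_argument('--output_dir', default='./output_dir',
                    help='directory for checkpoints and logs')
parser.add_argument('--model_parameters', default=None,
                    help='fine-tuned .params file to load into the full model')
parser.add_argument('--pretrained_bert_parameters', default=None,
                    help='pre-trained BERT .params loaded into model.bert only; '
                         'ignore_extra=True skips weights the classifier head does not use')
args = parser.parse_args()

model_parameters = args.model_parameters
pretrained_bert_parameters = args.pretrained_bert_parameters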
Example 4
if accumulate:
    logging.info('Using gradient accumulation. Effective batch size = %d', accumulate * batch_size)

# random seed
np.random.seed(args.seed)
random.seed(args.seed)
mx.random.seed(args.seed)

ctx = mx.cpu() if not args.gpu else mx.gpu()

# model and loss
dataset = 'book_corpus_wiki_en_uncased'
bert, vocabulary = bert_12_768_12(dataset_name=dataset,
                                  pretrained=True, ctx=ctx, use_pooler=True,
                                  use_decoder=False, use_classifier=False)
model = BERTClassifier(bert, dropout=0.1)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)

loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)
metric = mx.metric.Accuracy()

# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = FullTokenizer(vocabulary, do_lower_case=do_lower_case)

def preprocess_data(tokenizer, batch_size, dev_batch_size, max_len):
    """Data preparation function."""
    # transformation
    train_trans = ClassificationTransform(tokenizer, MRPCDataset.get_labels(),