Example #1
def get_model(ctx):
    """Build the BERT-base model, its vocabulary and the two pre-training losses."""
    # model: load the 12-layer, 768-hidden BERT and, optionally, its pre-trained weights
    pretrained = args.pretrained
    dataset = args.dataset_name
    model, vocabulary = bert_12_768_12(dataset_name=dataset,
                                       pretrained=pretrained,
                                       ctx=ctx)
    if not pretrained:
        model.initialize(init=mx.init.Normal(0.02), ctx=ctx)

    if args.ckpt_dir and args.start_step:
        param_path = os.path.join(args.ckpt_dir,
                                  '%07d.params' % args.start_step)
        model.load_parameters(param_path, ctx=ctx)
        logging.info('Loading step %d checkpoints from %s.', args.start_step,
                     param_path)

    model.cast(args.dtype)
    model.hybridize(static_alloc=True)

    # losses: next-sentence prediction and masked language modelling, both softmax cross-entropy
    nsp_loss = gluon.loss.SoftmaxCELoss()
    mlm_loss = gluon.loss.SoftmaxCELoss()
    nsp_loss.hybridize(static_alloc=True)
    mlm_loss.hybridize(static_alloc=True)

    return model, nsp_loss, mlm_loss, vocabulary
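
A minimal sketch of how get_model might be called (a usage sketch, not shown in the original excerpt):

ctx = mx.cpu()  # or mx.gpu(0) when a GPU is available
model, nsp_loss, mlm_loss, vocabulary = get_model(ctx)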
Example #2
accumulate = args.accumulate
log_interval = args.log_interval * accumulate if accumulate else args.log_interval
if accumulate:
    logging.info('Using gradient accumulation. Effective batch size = %d', accumulate*batch_size)
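# For example, with accumulate=4 and batch_size=8, gradients from 4 consecutive
# batches are summed before a single optimizer step, giving an effective batch
# size of 32; log_interval is scaled by the same factor above so that logging
# stays aligned with parameter updates.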

# random seed
np.random.seed(args.seed)
random.seed(args.seed)
mx.random.seed(args.seed)

ctx = mx.cpu() if not args.gpu else mx.gpu()

# model and loss
dataset = 'book_corpus_wiki_en_uncased'
bert, vocabulary = bert_12_768_12(dataset_name=dataset,
                                  pretrained=True, ctx=ctx, use_pooler=True,
                                  use_decoder=False, use_classifier=False)
model = BERTClassifier(bert, dropout=0.1)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)

loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)
metric = mx.metric.Accuracy()

# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = FullTokenizer(vocabulary, do_lower_case=do_lower_case)

def preprocess_data(tokenizer, batch_size, dev_batch_size, max_len):
    """Data preparation function."""
Example #3
parser.add_argument('--max_len',  # argument name assumed; the original call is truncated
                    default=128,
                    help='Maximum length of the sentence pairs')
parser.add_argument('--gpu',
                    action='store_true',
                    help='whether to use gpu for finetuning')
args = parser.parse_args()
logging.info(args)
batch_size = args.batch_size
test_batch_size = args.test_batch_size
lr = args.lr

ctx = mx.gpu() if args.gpu else mx.cpu()

bert, vocabulary = bert_12_768_12(dataset_name='book_corpus_wiki_en_uncased',
                                  pretrained=True,
                                  ctx=ctx,
                                  use_pooler=True,
                                  use_decoder=False,
                                  use_classifier=False)
tokenizer = FullTokenizer(vocabulary, do_lower_case=True)

model = BERTClassifier(bert, dropout=0.1)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)
logging.info(model)

loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)

metric = mx.metric.Accuracy()

trans = ClassificationTransform(tokenizer, MRPCDataset.get_labels(),