def get_model(ctx):
    """Get the pretraining model along with its NSP and MLM losses."""
    # model
    pretrained = args.pretrained
    dataset = args.dataset_name
    model, vocabulary = bert_12_768_12(dataset_name=dataset,
                                       pretrained=pretrained, ctx=ctx)
    if not pretrained:
        model.initialize(init=mx.init.Normal(0.02), ctx=ctx)

    if args.ckpt_dir and args.start_step:
        # resume from a zero-padded per-step checkpoint, e.g. 0010000.params
        param_path = os.path.join(args.ckpt_dir, '%07d.params' % args.start_step)
        model.load_parameters(param_path, ctx=ctx)
        logging.info('Loading step %d checkpoints from %s.',
                     args.start_step, param_path)

    model.cast(args.dtype)
    model.hybridize(static_alloc=True)

    # losses for next-sentence prediction (NSP) and masked language modeling (MLM)
    nsp_loss = gluon.loss.SoftmaxCELoss()
    mlm_loss = gluon.loss.SoftmaxCELoss()
    nsp_loss.hybridize(static_alloc=True)
    mlm_loss.hybridize(static_alloc=True)

    return model, nsp_loss, mlm_loss, vocabulary
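# A minimal usage sketch (an assumption, not part of the original script) of
# how the returned objects are typically combined in the pretraining loop.
# The forward signature and the variable names (input_ids, segment_ids,
# valid_length, masked_positions, masked_ids, masked_weights, ns_labels)
# are illustrative:
#
#   model, nsp_loss, mlm_loss, vocabulary = get_model(ctx)
#   classified, decoded = model(input_ids, segment_ids, valid_length,
#                               masked_positions)
#   loss = nsp_loss(classified, ns_labels).mean() + \
#          mlm_loss(decoded, masked_ids, masked_weights).mean()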
accumulate = args.accumulate
log_interval = args.log_interval * accumulate if accumulate else args.log_interval
if accumulate:
    logging.info('Using gradient accumulation. Effective batch size = %d',
                 accumulate * batch_size)

# random seed
np.random.seed(args.seed)
random.seed(args.seed)
mx.random.seed(args.seed)

ctx = mx.cpu() if not args.gpu else mx.gpu()

# model and loss
dataset = 'book_corpus_wiki_en_uncased'
bert, vocabulary = bert_12_768_12(dataset_name=dataset,
                                  pretrained=True, ctx=ctx, use_pooler=True,
                                  use_decoder=False, use_classifier=False)
model = BERTClassifier(bert, dropout=0.1)
# only the classifier head is newly initialized; the encoder weights are pretrained
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)
loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)
metric = mx.metric.Accuracy()

# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = FullTokenizer(vocabulary, do_lower_case=do_lower_case)

def preprocess_data(tokenizer, batch_size, dev_batch_size, max_len):
    """Data preparation function."""
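    # A minimal sketch (an assumption, not the verbatim body) of the data
    # preparation this function performs in the GluonNLP MRPC example:
    # build a sentence-pair transform, apply it to the train/dev splits,
    # and wrap them in DataLoaders. Names and argument values below are
    # illustrative:
    #
    #   trans = ClassificationTransform(tokenizer, MRPCDataset.get_labels(),
    #                                   max_len, pair=True)
    #   data_train = MRPCDataset('train').transform(trans)
    #   train_loader = mx.gluon.data.DataLoader(data_train,
    #                                           batch_size=batch_size, shuffle=True)
    #   data_dev = MRPCDataset('dev').transform(trans)
    #   dev_loader = mx.gluon.data.DataLoader(data_dev,
    #                                         batch_size=dev_batch_size, shuffle=False)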
                    default=128,
                    help='Maximum length of the sentence pairs')
parser.add_argument('--gpu', action='store_true',
                    help='whether to use gpu for finetuning')
args = parser.parse_args()
logging.info(args)

batch_size = args.batch_size
test_batch_size = args.test_batch_size
lr = args.lr

# use GPU only when the --gpu flag is set (args.gpu is a boolean)
ctx = mx.cpu() if not args.gpu else mx.gpu()

bert, vocabulary = bert_12_768_12(dataset_name='book_corpus_wiki_en_uncased',
                                  pretrained=True, ctx=ctx, use_pooler=True,
                                  use_decoder=False, use_classifier=False)
tokenizer = FullTokenizer(vocabulary, do_lower_case=True)
model = BERTClassifier(bert, dropout=0.1)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)
logging.info(model)

loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)
metric = mx.metric.Accuracy()

trans = ClassificationTransform(tokenizer, MRPCDataset.get_labels(),