def build_model(self, args, model_args, ctx, dataset=None, vocab=None):
    # the pretrained-weights dataset name (e.g. 'book_corpus_wiki_en_uncased')
    # is taken from model_args.model_name
    dataset = model_args.model_name
    if model_args.model_type == 'bert':
        model_name = 'bert_12_768_12'
    elif model_args.model_type == 'bertl':
        model_name = 'bert_24_1024_16'
    elif model_args.model_type == 'roberta':
        model_name = 'roberta_12_768_12'
    elif model_args.model_type == 'robertal':
        model_name = 'roberta_24_1024_16'
    else:
        raise NotImplementedError
    self.is_roberta = model_args.model_type.startswith('roberta')

    # download pretrained weights unless a local parameter file is given
    pretrained = args.model_params is None
    bert, vocabulary = nlp.model.get_model(
        name=model_name,
        dataset_name=dataset,
        pretrained=pretrained,
        ctx=ctx,
        use_pooler=False if self.is_roberta else True,
        use_decoder=False,
        use_classifier=False)
    if args.model_params:
        bert.load_parameters(args.model_params, ctx=ctx, cast_dtype=True,
                             ignore_extra=True)
    if args.fix_bert_weights:
        # freeze the encoder: no gradients for any weight or bias
        bert.collect_params('.*weight|.*bias').setattr('grad_req', 'null')
    if vocab:
        vocabulary = vocab

    do_lower_case = 'uncased' in dataset
    task_name = args.task_name
    num_classes = self.task.num_classes()
    if self.is_roberta:
        model = RoBERTaClassifier(bert, dropout=0.0, num_classes=num_classes)
        self.tokenizer = nlp.data.GPT2BPETokenizer()
    else:
        model = BERTClassifier(bert, num_classes=num_classes,
                               dropout=model_args.dropout)
        self.tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
    return model, vocabulary
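# A minimal, self-contained sketch of what the non-RoBERTa branch of
# build_model sets up, assuming GluonNLP 0.9.x and the public
# book_corpus_wiki_en_uncased weights; the example sentence and the special
# tokens below are illustrative only, not taken from the code above.
import mxnet as mx
import gluonnlp as nlp

ctx = mx.cpu()
bert, vocabulary = nlp.model.get_model(
    name='bert_12_768_12',
    dataset_name='book_corpus_wiki_en_uncased',
    pretrained=True,
    ctx=ctx,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False)

tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)
tokens = ['[CLS]'] + tokenizer('a quick example sentence') + ['[SEP]']
input_ids = mx.nd.array([vocabulary[tokens]], ctx=ctx)
token_types = mx.nd.zeros_like(input_ids)
valid_length = mx.nd.array([len(tokens)], ctx=ctx)

# with use_pooler=True the encoder returns the per-token encodings plus the
# pooled [CLS] vector that a classifier head would consume
seq_encoding, pooled_out = bert(input_ids, token_types, valid_length)
print(pooled_out.shape)  # (1, 768)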
bert, vocabulary = nlp.model.get_model(**get_model_params)

# initialize the rest of the parameters
initializer = mx.init.Normal(0.02)

# STS-B is a regression task.
# STSBTask().class_labels returns None
do_regression = not task.class_labels
if do_regression:
    num_classes = 1
    loss_function = gluon.loss.L2Loss()
else:
    num_classes = len(task.class_labels)
    loss_function = gluon.loss.SoftmaxCELoss()

# reuse the BERTClassifier class with num_classes=1 for regression
if use_roberta:
    model = RoBERTaClassifier(bert, dropout=0.0, num_classes=num_classes)
else:
    model = BERTClassifier(bert, dropout=0.1, num_classes=num_classes)

# initialize classifier
if not model_parameters:
    model.classifier.initialize(init=initializer, ctx=ctx)

# load checkpointing
output_dir = args.output_dir
if pretrained_bert_parameters:
    logging.info('loading bert params from %s', pretrained_bert_parameters)
    nlp.utils.load_parameters(model.bert, pretrained_bert_parameters,
                              ctx=ctx, ignore_extra=True, cast_dtype=True)
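# A minimal sketch of how the pieces above are usually wired into a single
# training step, assuming the BERTClassifier branch, the `model`,
# `loss_function` and `ctx` from the snippet above, and MXNet Gluon's
# standard Adam optimizer; the learning rate and the batch tensors
# (input_ids, segment_ids, valid_length, label) are illustrative stand-ins
# for whatever the data pipeline actually produces.
from mxnet import autograd, gluon

trainer = gluon.Trainer(model.collect_params(), 'adam',
                        {'learning_rate': 2e-5})

def train_step(input_ids, segment_ids, valid_length, label):
    # forward pass under autograd, then backprop and a single optimizer step
    with autograd.record():
        out = model(input_ids.as_in_context(ctx),
                    segment_ids.as_in_context(ctx),
                    valid_length.astype('float32').as_in_context(ctx))
        ls = loss_function(out, label.as_in_context(ctx)).mean()
    ls.backward()
    trainer.step(1)
    return ls.asscalar()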