import logging

# NOTE: the import paths below assume GluonNLP's numpy-based model zoo layout;
# adjust them if these helpers live elsewhere in the repository.
from gluonnlp.initializer import TruncNorm
from gluonnlp.models import get_backbone
from gluonnlp.models.electra import (ElectraModel, ElectraForPretrain,
                                     get_pretrained_electra)
from gluonnlp.utils.parameter import count_parameters
from models import ModelForQAConditionalV1  # local QA head module (assumed)


def get_pretraining_model(model_name, ctx_l, max_seq_length=128,
                          hidden_dropout_prob=0.1,
                          attention_dropout_prob=0.1,
                          generator_units_scale=None,
                          generator_layers_scale=None):
    """An ELECTRA pretraining model is built with a generator and a
    discriminator, in which the generator shares its embeddings with the
    discriminator but uses a different backbone.
    """
    cfg, tokenizer, _, _ = get_pretrained_electra(model_name, load_backbone=False)
    cfg = ElectraModel.get_cfg().clone_merge(cfg)
    cfg.defrost()
    cfg.MODEL.hidden_dropout_prob = hidden_dropout_prob
    cfg.MODEL.attention_dropout_prob = attention_dropout_prob
    cfg.MODEL.max_length = max_seq_length
    # Keep the original generator size unless it is explicitly overridden
    if generator_layers_scale:
        cfg.MODEL.generator_layers_scale = generator_layers_scale
    if generator_units_scale:
        cfg.MODEL.generator_units_scale = generator_units_scale
    cfg.freeze()

    model = ElectraForPretrain(cfg,
                               uniform_generator=False,
                               tied_generator=False,
                               tied_embeddings=True,
                               disallow_correct=False,
                               weight_initializer=TruncNorm(stdev=0.02))
    model.initialize(ctx=ctx_l)
    model.hybridize()
    return cfg, tokenizer, model
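
# Usage sketch (illustrative): the model name 'google_electra_small' and the
# CPU context below are assumptions for demonstration, not values taken from
# this file.
#
#   import mxnet as mx
#   ctx_l = [mx.cpu()]
#   cfg, tokenizer, model = get_pretraining_model('google_electra_small', ctx_l,
#                                                 max_seq_length=128)
#   # `model` is ready for the replaced-token-detection pretraining objective.
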
def get_network(model_name, ctx_l, dropout=0.1, checkpoint_path=None,
                backbone_path=None, dtype='float32'):
    """Get the network for fine-tuning on the question answering task.

    Parameters
    ----------
    model_name : str
        The model name of the backbone model.
    ctx_l : list
        List of training devices, e.g. [mx.gpu(0), mx.gpu(1)].
    dropout : float
        Dropout probability of the task-specific layer.
    checkpoint_path : str
        Path to a fine-tuned checkpoint.
    backbone_path : str
        Path to the backbone model to be loaded into qa_net.
    dtype : str
        Data type of the network parameters, e.g. 'float32'.

    Returns
    -------
    cfg
    tokenizer
    qa_net
    use_segmentation
    """
    # Create the network. RoBERTa/XLM-R style models do not use segment ids.
    use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
    Model, cfg, tokenizer, download_params_path, _ = \
        get_backbone(model_name, load_backbone=not backbone_path)
    backbone = Model.from_cfg(cfg, use_pooler=False, dtype=dtype)
    # Load local backbone parameters if backbone_path is provided.
    # Otherwise, download the backbone parameters from the gluon model zoo.
    backbone_params_path = backbone_path if backbone_path else download_params_path
    if checkpoint_path is None:
        backbone.load_parameters(backbone_params_path, ignore_extra=True,
                                 ctx=ctx_l, cast_dtype=True)
        num_params, num_fixed_params = count_parameters(backbone.collect_params())
        logging.info(
            'Loading Backbone Model from {}, with total/fixed parameters={}/{}'.format(
                backbone_params_path, num_params, num_fixed_params))
    qa_net = ModelForQAConditionalV1(backbone=backbone,
                                     dropout_prob=dropout,
                                     use_segmentation=use_segmentation,
                                     weight_initializer=TruncNorm(stdev=0.02))
    if checkpoint_path is None:
        # Ignore the UserWarning during initialization;
        # there is no need to re-initialize the parameters of the backbone.
        qa_net.initialize(ctx=ctx_l)
    else:
        qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
    qa_net.hybridize()
    return cfg, tokenizer, qa_net, use_segmentation
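
# Usage sketch (illustrative): 'google_albert_base_v2' is an assumed backbone
# name; any name accepted by get_backbone works the same way.
#
#   import mxnet as mx
#   ctx_l = [mx.cpu()]
#   cfg, tokenizer, qa_net, use_segmentation = get_network(
#       'google_albert_base_v2', ctx_l, dropout=0.1)
#   # use_segmentation is False for RoBERTa/XLM-R backbones, which do not
#   # consume token-type (segment) ids.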