def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput:
    """Run the encoder on a batch and wrap its outputs in an EncoderOutput.

    The encoder's architecture (ModelArchitectures.from_encoder) selects which
    architecture-specific helper produces the (pooled, unpooled, other) triple.

    Args:
        encoder: bare model outputting raw hidden-states without any specific head.
        input_ids: token indices (see huggingface.co/transformers/glossary.html#input-ids).
        segment_ids: token type ids (see huggingface.co/transformers/glossary.html#token-type-ids).
            Only forwarded for BERT, RoBERTa, ALBERT, XLM-RoBERTa and ELECTRA;
            not passed to the BART/mBART or DistilBERT helpers.
        input_mask: attention mask (see huggingface.co/transformers/glossary.html#attention-mask).

    Raises:
        KeyError: if the encoder's architecture has no output helper here.
        Errors raised by the architecture-specific helpers propagate unchanged.
        NOTE(review): an earlier docstring mentioned a RuntimeError when the encoder
        output has fewer than 2 elements; presumably raised inside those helpers — confirm.

    Returns:
        EncoderOutput holding the pooled and unpooled outputs; ``other`` is only
        set when the helper returned a truthy value for it.
    """
    model_arch = ModelArchitectures.from_encoder(encoder)

    # Arguments shared by every helper; segment_ids is added only where used.
    call_kwargs = dict(encoder=encoder, input_ids=input_ids, input_mask=input_mask)

    if model_arch in (
        ModelArchitectures.BERT,
        ModelArchitectures.ROBERTA,
        ModelArchitectures.ALBERT,
        ModelArchitectures.XLM_ROBERTA,
    ):
        output_fn = get_output_from_standard_transformer_models
        call_kwargs["segment_ids"] = segment_ids
    elif model_arch == ModelArchitectures.ELECTRA:
        output_fn = get_output_from_electra
        call_kwargs["segment_ids"] = segment_ids
    elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART):
        output_fn = get_output_from_bart_models
    elif model_arch == ModelArchitectures.DISTILBERT:
        output_fn = get_output_from_distilbert
    else:
        raise KeyError(model_arch)

    pooled, unpooled, other = output_fn(**call_kwargs)

    # Extend later with attention, hidden_acts, etc
    if not other:
        return EncoderOutput(pooled=pooled, unpooled=unpooled)
    return EncoderOutput(pooled=pooled, unpooled=unpooled, other=other)
def load_encoder_from_transformers_weights(encoder: nn.Module, weights_dict: dict, return_remainder=False):
    """Load the encoder's share of a full-model weights dict into ``encoder``.

    Keys beginning with ``MODEL_PREFIX[<encoder architecture>] + "."`` are treated
    as encoder weights: the prefix is stripped and they are passed to
    ``encoder.load_state_dict``. Every other key is set aside as a leftover.

    TODO: clarify how we know the encoder weights will be prefixed by transformer name.

    Args:
        encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
        weights_dict (Dict): model weights, keyed by parameter name.
        return_remainder (bool): If True, return the weights that were not loaded.

    Returns:
        Dict of the leftover (non-encoder) weights when ``return_remainder`` is True,
        otherwise None.
    """
    model_arch = ModelArchitectures.from_encoder(encoder=encoder)
    prefix = MODEL_PREFIX[model_arch] + "."

    encoder_weights = {}
    leftover_weights = {}
    for name, tensor in weights_dict.items():
        if not name.startswith(prefix):
            leftover_weights[name] = tensor
            continue
        encoder_weights[strings.remove_prefix(name, prefix)] = tensor

    encoder.load_state_dict(encoder_weights)
    return leftover_weights if return_remainder else None
def setup_jiant_model(
    model_type: str,
    model_config_path: str,
    tokenizer_path: str,
    task_dict: Dict[str, Task],
    taskmodels_config: container_setup.TaskmodelsConfig,
):
    """Build the tokenizer, shared encoder and task models, and return a JiantModel.

    Args:
        model_type (str): model shortcut name.
        model_config_path (str): Path to the JSON file containing the configuration parameters.
        tokenizer_path (str): path to tokenizer directory.
        task_dict (Dict[str, tasks.Task]): map from task name to task instance.
        taskmodels_config: maps tasks to task models and supplies per-taskmodel kwargs.

    Returns:
        JiantModel nn.Module whose task models all share a single encoder.
    """
    model_arch = ModelArchitectures.from_model_type(model_type)
    class_spec = TRANSFORMERS_CLASS_SPEC_DICT[model_arch]

    tokenizer = model_setup.get_tokenizer(model_type=model_type, tokenizer_path=tokenizer_path)
    ancestor = get_ancestor_model(
        transformers_class_spec=class_spec,
        model_config_path=model_config_path,
    )
    encoder = get_encoder(model_arch=model_arch, ancestor_model=ancestor)

    taskmodels_dict = {}
    grouped_task_names = get_taskmodel_and_task_names(taskmodels_config.task_to_taskmodel_map)
    for taskmodel_name, task_names in grouped_task_names.items():
        # A task model is constructed from the first task mapped to it.
        first_task = task_dict[task_names[0]]
        taskmodels_dict[taskmodel_name] = create_taskmodel(
            task=first_task,
            model_arch=model_arch,
            encoder=encoder,
            taskmodel_kwargs=taskmodels_config.get_taskmodel_kwargs(taskmodel_name),
        )

    return primary.JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=taskmodels_config.task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitectures:
    """Return the ModelArchitectures value of a JiantModel's shared encoder.

    Args:
        jiant_model: model exposing an ``encoder`` attribute.

    Returns:
        ModelArchitectures member inferred from ``jiant_model.encoder``.
    """
    shared_encoder = jiant_model.encoder
    return ModelArchitectures.from_encoder(encoder=shared_encoder)
def main(args: RunConfiguration):
    """Build a small multi-task RoBERTa JiantModel and run one batch per task.

    MNLI and QNLI share a single "nli" classification task model, and RTE gets its
    own. All task models share one roberta-base encoder. For each task, one training
    batch is drawn from its chunked train cache, run forward under ``torch.no_grad``
    with ``compute_loss=True``, and the output is printed.

    Args:
        args: run configuration providing ``model_config_path``,
            ``model_tokenizer_path``, ``task_config_base_path`` and
            ``task_cache_base_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Shared model components setup === #
    model_type = "roberta-base"
    model_arch = ModelArchitectures.from_model_type(model_type=model_type)
    transformers_class_spec = model_setup.TRANSFORMERS_CLASS_SPEC_DICT[model_arch]
    ancestor_model = model_setup.get_ancestor_model(
        transformers_class_spec=transformers_class_spec,
        model_config_path=args.model_config_path,
    )
    encoder = model_setup.get_encoder(
        model_arch=model_arch,
        ancestor_model=ancestor_model,
    )
    tokenizer = shared_model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_path=args.model_tokenizer_path,
    )

    # === Taskmodels setup === #
    # Each task's config is "<task_config_base_path>/<task_name>.json". Deriving the
    # file name from the task name keeps every task paired with its own config.
    task_dict = {
        task_name: tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, task_name + ".json")
        )
        for task_name in ("mnli", "qnli", "rte")
    }

    taskmodels_dict = {
        # NOTE(review): this head is sized from MNLI's label count but is also used
        # for QNLI via task_to_taskmodel_map — confirm that is intended.
        "nli": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["mnli"].LABELS),
            ),
        ),
        "rte": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["rte"].LABELS),
            ),
        ),
    }
    task_to_taskmodel_map = {
        "mnli": "nli",
        "qnli": "nli",
        "rte": "rte",
    }

    # === Final === #
    jiant_model = JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
    jiant_model = jiant_model.to(device)

    # === Run === #
    task_dataloader_dict = {}
    for task_name, task in task_dict.items():
        train_cache = caching.ChunkedFilesDataCache(
            cache_fol_path=os.path.join(args.task_cache_base_path, task_name, "train"),
        )
        train_dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=train_dataset,
            batch_size=4,
            collate_fn=task.collate_fn,
        )
        task_dataloader_dict[task_name] = train_dataloader

    for task_name, task in task_dict.items():
        # Only the first batch of each task is used for this smoke run.
        batch, batch_metadata = next(iter(task_dataloader_dict[task_name]))
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        print(task_name)
        print(model_output)
        print()