def prepare_config_and_inputs(self):
    # clamp to 3 so that none of the special token ids (pad/bos/eos) appear mid-sequence
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
    input_ids[:, -1] = self.eos_token_id  # Eos Token
    decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    config = M2M100Config(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        eos_token_id=self.eos_token_id,
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
    )
    inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
    return config, inputs_dict
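# For reference, a minimal sketch of the helper used above. The real
# prepare_m2m_100_inputs_dict in the transformers test suite also accepts head
# masks; the essential behavior is that attention masks are derived from the
# pad token:
def prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids,
                                attention_mask=None, decoder_attention_mask=None):
    if attention_mask is None:
        attention_mask = input_ids.ne(config.pad_token_id)
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
    }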
def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
    m2m_100 = torch.load(checkpoint_path, map_location="cpu")
    args = m2m_100["args"]
    state_dict = m2m_100["model"]
    remove_ignore_keys_(state_dict)
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]

    # mirror the fairseq architecture hyperparameters in the HF config
    config = M2M100Config(
        vocab_size=vocab_size,
        max_position_embeddings=1024,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        encoder_attention_heads=args.encoder_attention_heads,
        decoder_attention_heads=args.decoder_attention_heads,
        encoder_ffn_dim=args.encoder_ffn_embed_dim,
        decoder_ffn_dim=args.decoder_ffn_embed_dim,
        d_model=args.encoder_embed_dim,
        encoder_layerdrop=args.encoder_layerdrop,
        decoder_layerdrop=args.decoder_layerdrop,
        dropout=args.dropout,
        attention_dropout=args.attention_dropout,
        activation_dropout=args.activation_dropout,
        activation_function="relu",
    )

    # encoder, decoder and lm_head all share one embedding matrix
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    model = M2M100ForConditionalGeneration(config)
    model.model.load_state_dict(state_dict)
    model.lm_head = make_linear_from_emb(model.model.shared)
    return model
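# The two helpers above are not shown in this snippet. Sketches of what they
# plausibly do, modeled on the equivalent BART conversion script; the exact
# ignore-key list is an assumption:
from torch import nn


def remove_ignore_keys_(state_dict):
    # fairseq bookkeeping entries that have no counterpart in the HF model
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "decoder.output_projection.weight",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    # tie the LM head weights to the shared token embedding
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer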
def get_config(self):
    return M2M100Config(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        encoder_layerdrop=self.encoder_layerdrop,
        decoder_layerdrop=self.decoder_layerdrop,
        max_position_embeddings=self.max_position_embeddings,
        eos_token_id=self.eos_token_id,
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
    )
def prepare_config_and_inputs(self):
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
    input_ids[:, -1] = self.eos_token_id  # Eos Token
    decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    # We need to clamp the input ids here to avoid having pad tokens in between.
    # For M2M100 the position_ids are prepared such that all pad tokens have
    # pos id = 2 and the rest are between 2..seq_length, where seq_length here is
    # seq_length - num_pad_tokens. But when using past, there is no way of knowing
    # whether the past input ids had pad tokens in them, which results in an
    # incorrect seq_length and, in turn, position_ids being off by num_pad_tokens
    # in the past input (see the position-id helper sketched after this function).
    input_ids = input_ids.clamp(self.pad_token_id + 1)
    decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)

    config = M2M100Config(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        encoder_layerdrop=self.encoder_layerdrop,
        decoder_layerdrop=self.decoder_layerdrop,
        max_position_embeddings=self.max_position_embeddings,
        eos_token_id=self.eos_token_id,
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
    )
    inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
    return config, inputs_dict
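# For context, M2M100 derives position ids from the pad mask with a
# RoBERTa-style helper (paraphrased from modeling_m2m_100.py, so treat the
# details as an approximation). Pad tokens all get the same position id, and
# real tokens are numbered by their rank among non-pad tokens, which is why
# unseen pad tokens in the past would shift every subsequent position id:
import torch


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx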
def load(args):
    print("loading model")
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # Load M2M-100 model
    config = M2M100Config.from_pretrained("facebook/m2m100_418M")
    config.method = 1
    m2m = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", config=config)
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    # Build fused model and load parameters from the local checkpoint
    model = FusedM2M(config, None, m2m)
    state_dict = torch.load(args.checkpoint)
    state_dict = {k: v for k, v in state_dict.items() if "fuse" in k}  # load the fuse layers only
    model.load_state_dict(state_dict, strict=False)

    model = model.model  # take the M2M100Model from M2M100ForConditionalGeneration
    model.to(device)
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
    model.eval()
    return model, tokenizer, device
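# A hypothetical invocation of load(); the flag names match the attributes the
# function reads, and the checkpoint path is a placeholder for whatever the
# surrounding script parses from the command line:
import argparse

args = argparse.Namespace(no_cuda=False, num_gpus=1, checkpoint="checkpoints/fused_m2m.pt")
model, tokenizer, device = load(args)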
def filter_none(example):
    return (
        example["translation"][source_lang] is not None
        and example["translation"][target_lang] is not None
    )


# Preprocess data or load from a local file
if args.load_local_dataset:
    tokenized_datasets = load_from_disk("data")
else:
    tokenized_datasets = raw_datasets.filter(filter_none).map(preprocess, batched=True)
    tokenized_datasets.save_to_disk("data")

# Prepare models
config = M2M100Config.from_pretrained("facebook/m2m100_418M")
config.method = fuse_method
m2m = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", config=config)
fused_model = FusedM2M(config, bert, m2m)

# DEBUG: check the weights of individual layers
# shared_weight = m2m.model.shared.weight.data.clone().detach()
# layer_1_weight = m2m.model.encoder.layers[0].fc1.weight.data.clone().detach()
# fuse_12_weight = m2m.model.encoder.layers[-1].fuse_layer.weight.data.clone().detach()

# Load state dict from a local checkpoint
if checkpoint:
    state_dict = torch.load(f"{checkpoint}/pytorch_model.bin")
    state_dict = {k: v for k, v in state_dict.items() if "fuse" in k}  # keep the fuse layers only, as in load()
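# preprocess is referenced above but not shown. A minimal sketch of a typical
# seq2seq preprocessing function for this setup; the max_length value and the
# text_target tokenizer call are assumptions, not the author's exact code:
def preprocess(examples):
    inputs = [ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    # tokenize source and target sides; labels are the target input_ids
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs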