def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    enc_dec_model.to(torch_device)

    # BERT does not have a bos token id, so use pad_token_id instead.
    generated_output = enc_dec_model.generate(
        input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
    )
    self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertForMaskedLM(decoder_config)
    # EncoderDecoderModel's first positional parameter is `config`, so the
    # encoder and decoder must be passed as keyword arguments.
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

    # BERT does not have a bos token id, so use pad_token_id instead.
    # The pad token id lives on the decoder config, not on the top-level EncoderDecoderConfig.
    generated_output = enc_dec_model.generate(
        input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
    )
    self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
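# A minimal, self-contained sketch (not from the original tests) of the pattern the
# two helpers above exercise: BERT has no bos token, so generation from a Bert2Bert
# model needs an explicit decoder_start_token_id. Model sizes here are illustrative
# assumptions, small enough to run on CPU.
import torch
from transformers import BertConfig, BertLMHeadModel, BertModel, EncoderDecoderModel

enc_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2)
dec_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                     is_decoder=True, add_cross_attention=True)
model = EncoderDecoderModel(encoder=BertModel(enc_cfg), decoder=BertLMHeadModel(dec_cfg))

input_ids = torch.randint(0, 100, (1, 8))
# BertConfig's pad_token_id defaults to 0; it stands in for the missing bos token.
out = model.generate(input_ids, decoder_start_token_id=dec_cfg.pad_token_id, max_length=5)
print(out.shape)  # (1, 5): batch size x max_length, start token included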
def inference():
    step = sys.argv[1]

    encoder_config = BertConfig.from_pretrained("monologg/kobert")
    decoder_config = BertConfig.from_pretrained("monologg/kobert")
    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    tokenizer = KoBertTokenizer()
    model = EncoderDecoderModel(config=config)

    ckpt = "model.pt"
    device = "cuda"
    model.load_state_dict(
        torch.load(f"saved/{ckpt}.{step}", map_location=device),
        strict=True,
    )
    model = model.half().eval().to(device)

    test_data = open("dataset/abstractive_test_v2.jsonl", "r").read().splitlines()
    submission = open(f"submission_{step}.csv", "w")

    test_set = []
    for data in test_data:
        data = json.loads(data)
        article_original = " ".join(data["article_original"])
        news_id = data["id"]
        test_set.append((news_id, article_original))

    for news_id, text in tqdm(test_set):
        tokens = tokenizer.encode_batch([text], max_length=512)
        generated = model.generate(
            input_ids=tokens["input_ids"].to(device),
            attention_mask=tokens["attention_mask"].to(device),
            use_cache=True,
            bos_token_id=tokenizer.token2idx["[CLS]"],
            eos_token_id=tokenizer.token2idx["[SEP]"],
            pad_token_id=tokenizer.token2idx["[PAD]"],
            num_beams=12,
            do_sample=False,
            temperature=1.0,
            no_repeat_ngram_size=3,
            bad_words_ids=[[tokenizer.token2idx["[UNK]"]]],
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=512,
        )
        output = tokenizer.decode_batch(generated.tolist())[0]
        submission.write(f"{news_id},{output}\n")
        print(news_id, output)
class BERT2BERTTrainer(pl.LightningModule):
    def __init__(self, lr, **args):
        super(BERT2BERTTrainer, self).__init__()
        self.save_hyperparameters()
        encoder = BertGenerationEncoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            bos_token_id=101,
            eos_token_id=102,
            # force_download=True
        )
        decoder = BertGenerationDecoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            add_cross_attention=True,
            is_decoder=True,
            bos_token_id=101,
            eos_token_id=102)
        self.bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        if args['with_keywords_loss']:
            self.loss_fct2 = KeywordsLoss(alpha=args['keywords_loss_alpha'],
                                          loss_fct=args['keywords_loss_fct'])

    def generate(self, inputs_ids, attention_mask=None, **kwargs):
        inputs_ids = inputs_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            return self.bert2bert.generate(input_ids=inputs_ids,
                                           attention_mask=attention_mask,
                                           bos_token_id=101,
                                           min_length=100,
                                           eos_token_id=102,
                                           pad_token_id=0,
                                           **kwargs).detach().cpu().tolist()

    def forward(self, inputs):
        with torch.no_grad():
            return self.bert2bert(**inputs)

    def training_step(self, inputs, batch_idx):
        title, body = inputs
        title['input_ids'] = title['input_ids'].squeeze(1)
        title['attention_mask'] = title['attention_mask'].squeeze(1)
        body['input_ids'] = body['input_ids'].squeeze(1)
        body['attention_mask'] = body['attention_mask'].squeeze(1)
        ret = self.bert2bert(input_ids=title['input_ids'],
                             attention_mask=title['attention_mask'],
                             decoder_input_ids=body['input_ids'],
                             decoder_attention_mask=body['attention_mask'],
                             labels=body['input_ids'])
        loss2 = self.loss_fct2(ret.logits, title['input_ids']) \
            if self.hparams['with_keywords_loss'] else 0.
        self.log('keyword_loss', loss2, prog_bar=True)
        self.log('clm_loss', ret.loss, prog_bar=True)
        return {'loss': ret.loss + loss2, 'keyword_loss': loss2}

    def training_epoch_end(self, outputs):
        mean_loss = torch.stack([x['loss'] for x in outputs]).reshape(-1).mean()
        self.log('mean_loss', mean_loss)

    def configure_optimizers(self):
        return optim.AdamW(self.bert2bert.parameters(), lr=self.hparams['lr'])

    @staticmethod
    def add_parser_args(parser):
        # parser.add_argument('--lr', type=float)
        parser.add_argument('--with_keywords_loss', action='store_true')
        parser.add_argument('--keywords_loss_alpha', type=float, default=0.7,
                            help='float > 0.5')
        parser.add_argument('--keywords_loss_fct', type=str, default='kldiv',
                            help='kldiv or mse')
        return parser
def encoder_decoder_example():
    import torch
    from transformers import EncoderDecoderConfig, EncoderDecoderModel
    from transformers import BertConfig, GPT2Config
    from transformers import BertTokenizer, GPT2Tokenizer

    pretrained_model_name = 'bert-base-uncased'
    #pretrained_model_name = 'gpt2'

    if 'bert' in pretrained_model_name:
        # Initialize a BERT bert-base-uncased style configuration.
        config_encoder, config_decoder = BertConfig(), BertConfig()
    elif 'gpt2' in pretrained_model_name:
        config_encoder, config_decoder = GPT2Config(), GPT2Config()
    else:
        print('Invalid model, {}.'.format(pretrained_model_name))
        return

    config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

    if 'bert' in pretrained_model_name:
        # Initialize a Bert2Bert model from the bert-base-uncased style configurations.
        model = EncoderDecoderModel(config=config)
        #model = EncoderDecoderModel.from_encoder_decoder_pretrained(pretrained_model_name, pretrained_model_name)  # Initialize Bert2Bert from pre-trained checkpoints.
        tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    elif 'gpt2' in pretrained_model_name:
        model = EncoderDecoderModel(config=config)
        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name)

    #print('Configuration of the encoder & decoder:\n{}.\n{}.'.format(model.config.encoder, model.config.decoder))
    #print('Encoder type = {}, decoder type = {}.'.format(type(model.encoder), type(model.decoder)))

    if False:
        # Access the model configuration.
        config_encoder = model.config.encoder
        config_decoder = model.config.decoder
        # Set decoder config to causal LM.
        config_decoder.is_decoder = True
        config_decoder.add_cross_attention = True

    #--------------------
    input_ids = torch.tensor(tokenizer.encode('Hello, my dog is cute', add_special_tokens=True)).unsqueeze(0)  # Batch size 1.

    if False:
        # Forward.
        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)

        # Train.
        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
        loss, logits = outputs.loss, outputs.logits

        # Save the model, including its configuration.
        model.save_pretrained('my-model')

        # Load model and config from pretrained folder.
        encoder_decoder_config = EncoderDecoderConfig.from_pretrained('my-model')
        model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)

    #--------------------
    # Generate.
    # REF [site] >>
    #   https://huggingface.co/transformers/internal/generation_utils.html
    #   https://huggingface.co/blog/how-to-generate
    generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
    #generated = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2, num_return_sequences=5, do_sample=True, top_k=0, temperature=0.7, early_stopping=True, decoder_start_token_id=model.config.decoder.pad_token_id)
    print('Generated = {}.'.format(tokenizer.decode(generated[0], skip_special_tokens=True)))
    is_decoder=True, bos_token_id=101, eos_token_id=102)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# Encode the context the generation is conditioned on.
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='pt')

# Activate beam search and early_stopping.
# A simple remedy is to introduce n-gram (a.k.a. word sequences of n words) penalties,
# as introduced by Paulus et al. (2017) and Klein et al. (2017).
# The most common n-gram penalty makes sure that no n-gram appears twice by
# manually setting the probability of next words that could create an already seen n-gram to 0.
beam_output = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    early_stopping=True,
    no_repeat_ngram_size=2,
    num_return_sequences=5,
)
print(f"Output ({beam_output.shape}): {beam_output}")
print(f"Detokenized[0]: `{tokenizer.decode(beam_output[0], skip_special_tokens=False)}`")
print(f"Detokenized[0] without special tokens: `{tokenizer.decode(beam_output[0], skip_special_tokens=True)}`")
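# A minimal sketch (an assumption, not part of the original snippet) of how the
# n-gram penalty described in the comments above works under the hood:
# transformers' NoRepeatNGramLogitsProcessor assigns -inf to the score of any
# token that would complete an already-generated n-gram.
import torch
from transformers import NoRepeatNGramLogitsProcessor

processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
# The prefix already contains the bigram (5, 6); the current last token is 5,
# so emitting 6 next would repeat that bigram.
input_ids = torch.tensor([[5, 6, 7, 5]])
scores = torch.zeros((1, 10))  # uniform next-token scores over a toy vocab of 10
filtered = processor(input_ids, scores)
print(filtered[0, 6])  # -inf: completing the seen bigram (5, 6) is banned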
        # Update model.
        optimizer.step()

    print("*" * 50, "Sanity check at the end of Epoch", epoch, "*" * 50)

    # Sample.
    command = "Separate the given stack to form blue, red and yellow blocks stack."
    orig_plan = "approach_obj(yellow_block),grasp_obj_on_red_block(yellow_block),lift_obj_from_red_block(yellow_block),place_on_center(yellow_block),approach_obj(red_block),grasp_obj(red_block),lift_obj_from_tabletop(red_block),align_red_block_with(blue_block),stack_red_block_on(blue_block),approach_obj(green_block),grasp_obj(green_block),lift_obj_from_far(green_block),place_on_center(green_block),approach_obj(yellow_block),grasp_obj(yellow_block),lift_obj_from_tabletop(yellow_block),align_yellow_block_with(red_block),stack_yellow_block_on(red_block),go_home(robot)"
    plan = SierraDataset.process_plan(orig_plan, return_string=True)
    #action = SierraDataset.process_plans(plan, return_string=True)
    print("Command: ", command)
    print("Target: ", plan)

    # Tokenize inputs and labels.
    inputs = encoder_tokenizer(command, add_special_tokens=True, return_tensors="pt")
    print("Inputs tokenized: ", inputs)

    plan_tokenized = decoder_tokenizer(plan, add_special_tokens=True, return_tensors="pt")
    print("Target tokenized: ", plan_tokenized)
    print(f"\nTarget: `{decoder_tokenizer.decode(plan_tokenized.input_ids[0], skip_special_tokens=False)}`\n")

    # Move inputs to GPU.
    for key, item in inputs.items():
        if type(item).__name__ == "Tensor":
            inputs[key] = item.cuda()

    # Generate output.
    greedy_output = bert2bert.generate(inputs.input_ids, max_length=200)
    #print(f"Output ({greedy_output.shape}): {greedy_output}")
    print(f"\nModel prediction: `{decoder_tokenizer.decode(greedy_output[0], skip_special_tokens=False)}`\n")
class BERT2BERT(ConditionalGenerator):
    r"""The BertGeneration model is a BERT model that can be leveraged for
    sequence-to-sequence tasks using EncoderDecoderModel.
    """

    def __init__(self, config, dataset):
        super(BERT2BERT, self).__init__(config, dataset)

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.encoder_configure = BertConfig.from_pretrained('bert-base-cased')
        self.decoder_configure = BertConfig.from_pretrained('bert-base-cased')
        self.encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config=self.encoder_configure,
            decoder_config=self.decoder_configure)

        self.encoder = BertGenerationEncoder.from_pretrained(
            'bert-base-cased', bos_token_id=101, eos_token_id=102)
        self.decoder = BertGenerationDecoder.from_pretrained(
            'bert-base-cased',
            add_cross_attention=True,
            is_decoder=True,
            bos_token_id=101,
            eos_token_id=102)
        self.encoder_decoder = EncoderDecoderModel(
            encoder=self.encoder, decoder=self.decoder,
            config=self.encoder_decoder_config)

        self.sos_token = dataset.sos_token
        self.eos_token = dataset.eos_token
        self.padding_token_idx = self.tokenizer.pad_token_id
        self.max_source_length = config['source_max_seq_length']
        self.max_target_length = config['target_max_seq_length']

        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx,
                                        reduction='none')

    def generate(self, eval_dataloader):
        generate_corpus = []
        with torch.no_grad():
            for batch_data in eval_dataloader:
                source_text = batch_data["source_text"]
                for text in source_text:
                    sentence = ' '.join(text)
                    encoding_dict = self.tokenizer(sentence, return_tensors="pt",
                                                   add_special_tokens=False)
                    input_ids = encoding_dict['input_ids'].to(self.device)
                    sample_outputs = self.encoder_decoder.generate(
                        input_ids,
                        num_beams=4,
                        max_length=self.max_target_length,
                        early_stopping=True,
                        bos_token_id=101,
                        eos_token_id=102)
                    generated_text = [
                        self.tokenizer.decode(sample, skip_special_tokens=True)
                        for sample in sample_outputs
                    ]
                    generated_text = [text.lower().split() for text in generated_text]
                    generate_corpus.extend(generated_text)
        return generate_corpus

    def calculate_loss(self, corpus, epoch_idx=-1):
        source_text = corpus['source_text']
        target_text = corpus['target_text']

        input_ids = []
        attn_masks = []
        for text in source_text:
            sentence = ' '.join(text)
            encoding_dict = self.tokenizer(sentence,
                                           max_length=self.max_source_length,
                                           padding="max_length",
                                           truncation=True,
                                           return_tensors="pt",
                                           add_special_tokens=False)
            input_ids.append(encoding_dict['input_ids'])
            attn_masks.append(encoding_dict['attention_mask'])
        input_ids = torch.cat(input_ids, dim=0).to(self.device)
        attn_masks = torch.cat(attn_masks, dim=0).to(self.device)

        target_ids = []
        for text in target_text:
            sentence = ' '.join(text)
            encoding_dict = self.tokenizer(sentence,
                                           max_length=self.max_target_length,
                                           padding="max_length",
                                           truncation=True,
                                           return_tensors="pt")
            target_ids.append(encoding_dict['input_ids'])
        target_ids = torch.cat(target_ids, dim=0).to(self.device).contiguous()

        outputs = self.encoder_decoder(input_ids,
                                       attention_mask=attn_masks,
                                       decoder_input_ids=target_ids,
                                       labels=target_ids)

        token_logits = outputs.logits
        loss = self.loss(token_logits.view(-1, token_logits.size(-1)),
                         target_ids.view(-1))
        loss = loss.reshape_as(target_ids)

        length = (target_ids != self.padding_token_idx).sum(dim=1).float()
        loss = loss.sum(dim=1) / length.float()
        return loss.mean()
class CrossVLGenerator(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.build()

    @classmethod
    def config_path(cls):
        return "configs/models/cvlg/defaults.yaml"

    def build(self):
        # to be further set
        self.image_feature_module = build_image_encoder(
            self.config.image_feature_processor, direct_features=True
        )
        if self.config.concate_trace:
            self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

        if self.config.base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif self.config.base_model_name == "2layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.max_position_embeddings = 1090
            config_encoder.num_hidden_layers = 2
            config_decoder.num_hidden_layers = 2
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        elif self.config.base_model_name == "3layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.num_hidden_layers = 3
            config_decoder.num_hidden_layers = 3
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)

        if self.config.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                self.config.tc_contrastive_aggregate_method
            )
        if hasattr(self.config, "pretrans_attention") and self.config.pretrans_attention:
            tempconf = self.encoderdecoder.config.encoder
            num_heads = tempconf.num_attention_heads
            num_layers = tempconf.num_hidden_layers
            self.attention_trans = AttentionTransform(num_layers, num_heads, 100)
        self.BOS_ID = 101

        self.vae = OpenAIDiscreteVAE()
        image_code_dim = 768
        image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = image_fmap_size ** 2
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
        )

    def forward(self, sample_list, *args, **kwargs):
        visual_code = self.vae.get_codebook_indices(sample_list["image"])
        visual_emb = self.image_emb(visual_code)
        visual_emb += self.image_pos_emb(visual_emb)
        decoder_input_ids = sample_list["input_ids"][:, :-1]
        # using default mask
        # target_mask = sample_list["input_mask"]
        # segment_ids = sample_list["segment_ids"]
        # token_attends = sample_list["token_attends"]
        other_kwargs = {}
        # if self.config.image_feature_processor.type == "spatial":
        #     bbox_feature = sample_list["image_feature_0"]
        #     spatial_feature = sample_list["image_info_0"]["bbox"]
        #     inputs_embeds = self.image_feature_module(bbox_feature, spatial_feature)
        # else:
        #     bbox_feature = sample_list["image_feature_0"]
        #     inputs_embeds = self.image_feature_module(bbox_feature)
        # if hasattr(self.config, "no_vision") and self.config.no_vision:
        #     inputs_embeds = inputs_embeds * 0
        inputs_embeds = visual_emb
        batch_size = inputs_embeds.shape[0]

        if self.config.concate_trace:
            trace_boxes = sample_list["trace_boxes"]
            trace_boxes_mask = sample_list["trace_boxes_mask"]
            trace_feature = self.trace_feature_module(trace_boxes)
            trace_seg_id = sample_list["trace_boxes_seg_id"]
            inputs_embeds = torch.cat((inputs_embeds, trace_feature), dim=1)
            image_feats_mask = trace_boxes_mask.new_ones(
                (batch_size, visual_code.shape[1])
            )
            image_feats_seg_id = trace_seg_id.new_zeros(
                (batch_size, visual_code.shape[1])
            )
            attention_mask = torch.cat((image_feats_mask, trace_boxes_mask), dim=1)
            token_type_ids = torch.cat((image_feats_seg_id, trace_seg_id), dim=1)
            position_ids = trace_seg_id.new_zeros((batch_size, attention_mask.shape[1]))
            other_kwargs.update(
                {
                    "attention_mask": attention_mask,
                    "token_type_ids": token_type_ids,
                    "position_ids": position_ids,
                }
            )

        if self.training:
            decoder_output = self.encoderdecoder(
                decoder_input_ids=decoder_input_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
                **other_kwargs
            )

            logits = decoder_output["logits"]
            cross_attentions = []
            for cross_attention in decoder_output["cross_attentions"]:
                if self.config.concate_trace:
                    cross_attention = cross_attention[:, :, :, :100]
                # cross_attentions.append(cross_attention.mean(dim=1))
                cross_attentions.append(cross_attention)
            if hasattr(self.config, "pretrans_attention") and self.config.pretrans_attention:
                cross_attentions = self.attention_trans(cross_attentions)
            else:
                cross_attentions = [crs.mean(dim=1) for crs in cross_attentions]

            model_output = {}
            model_output["captions"] = torch.max(logits, dim=-1)[1]
            model_output["scores"] = logits
            model_output["cross_attentions"] = cross_attentions
            sample_list["targets"] = sample_list["input_ids"][:, 1:]

            if self.config.loop_contrastive:
                cap_feat, vision_trace_feat = self.trace_caption_contrastive(
                    decoder_output["encoder_hidden_states"][-1],
                    sample_list["trace_boxes_loop_contrastive_seg_id"],
                    decoder_output["decoder_hidden_states"][-1],
                    sample_list["segment_ids"],
                )
                model_output["contrastive_a"] = cap_feat
                model_output["contrastive_b"] = vision_trace_feat
        else:
            if self.config.inference.type == "beam_search":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **self.config.inference.args,
                    **other_kwargs
                )
            elif self.config.inference.type == "greedy":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    max_length=self.config.max_gen_length,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **other_kwargs
                )
            elif self.config.inference.type == "nucleus_sampling":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **self.config.inference.args,
                    **other_kwargs
                )

            model_output = {}
            if (
                "return_attention" in self.config.inference
                and self.config.inference.return_attention
            ):
                with torch.no_grad():
                    attention_temp_output = self.encoderdecoder(
                        decoder_input_ids=generate_output,
                        inputs_embeds=inputs_embeds,
                        output_attentions=True,
                        return_dict=True,
                    )
                    cross_attentions = []
                    for cross_attention in attention_temp_output["cross_attentions"]:
                        if self.config.concate_trace:
                            cross_attention = cross_attention[:, :, :, :100]
                        cross_attentions.append(cross_attention.mean(dim=1))
                    cross_attentions = (
                        torch.stack(cross_attentions).max(dim=0)[0].max(dim=-1)[1]
                    )
                    model_output["cross_attention"] = cross_attentions

            model_output["captions"] = generate_output
            model_output["losses"] = {}
            loss_key = "{}/{}".format(sample_list.dataset_name, sample_list.dataset_type)
            # Add a dummy loss so that loss calculation is not required.
            model_output["losses"][loss_key + "/dummy_loss"] = torch.zeros(
                1, device=sample_list.image_feature_0.device
            )
        return model_output
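# A rough sketch (values assumed, not from the original config) of the three
# inference types dispatched above, expressed as plain `generate` keyword
# arguments; in the class they arrive via self.config.inference.args.
beam_search_kwargs = dict(num_beams=5, max_length=30)             # "beam_search"
greedy_kwargs = dict(max_length=30)                               # "greedy" (default decoding)
nucleus_kwargs = dict(do_sample=True, top_p=0.9, max_length=30)   # "nucleus_sampling"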
class BERT2BERT(Seq2SeqGenerator):
    r"""The BertGeneration model is a BERT model that can be leveraged for
    sequence-to-sequence tasks using EncoderDecoderModel.
    """

    def __init__(self, config, dataset):
        super(BERT2BERT, self).__init__(config, dataset)

        self.sos_token_idx = 101
        self.eos_token_idx = 102
        self.max_source_length = dataset.max_source_length
        self.max_target_length = dataset.max_target_length
        self.pretrained_model_path = config['pretrained_model_path']

        self.tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_path)
        self.encoder_configure = BertConfig.from_pretrained(self.pretrained_model_path)
        self.decoder_configure = BertConfig.from_pretrained(self.pretrained_model_path)
        self.encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config=self.encoder_configure,
            decoder_config=self.decoder_configure)

        self.encoder = BertGenerationEncoder.from_pretrained(
            self.pretrained_model_path,
            bos_token_id=self.sos_token_idx,
            eos_token_id=self.eos_token_idx)
        self.decoder = BertGenerationDecoder.from_pretrained(
            self.pretrained_model_path,
            bos_token_id=self.sos_token_idx,
            eos_token_id=self.eos_token_idx,
            add_cross_attention=True,
            is_decoder=True)
        self.encoder_decoder = EncoderDecoderModel(
            encoder=self.encoder, decoder=self.decoder,
            config=self.encoder_decoder_config)

        self.padding_token_idx = self.tokenizer.pad_token_id
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx,
                                        reduction='none')

    def generate(self, batch_data, eval_data):
        generate_corpus = []
        source_text = batch_data["source_text"]
        for text in source_text:
            sentence = ' '.join(text)
            encoding_dict = self.tokenizer(sentence, return_tensors="pt",
                                           add_special_tokens=False)
            input_ids = encoding_dict['input_ids'].to(self.device)
            sample_outputs = self.encoder_decoder.generate(
                input_ids,
                num_beams=5,
                max_length=self.max_target_length,
                early_stopping=True,
                bos_token_id=self.sos_token_idx,
                eos_token_id=self.eos_token_idx)
            generated_text = [
                self.tokenizer.decode(sample, skip_special_tokens=True)
                for sample in sample_outputs
            ]
            generated_text = [text.lower().split() for text in generated_text]
            generate_corpus.extend(generated_text)
        return generate_corpus

    def forward(self, corpus, epoch_idx=-1):
        source_text = corpus['source_text']
        target_text = corpus['target_text']

        input_ids = []
        encoder_attn_masks = []
        for text in source_text:
            sentence = ' '.join(text)
            encoding_dict = self.tokenizer(sentence,
                                           max_length=self.max_source_length,
                                           padding="max_length",
                                           truncation=True,
                                           return_tensors="pt",
                                           add_special_tokens=False)
            input_ids.append(encoding_dict['input_ids'])
            encoder_attn_masks.append(encoding_dict['attention_mask'])
        input_ids = torch.cat(input_ids, dim=0).to(self.device)
        encoder_attn_masks = torch.cat(encoder_attn_masks, dim=0).to(self.device)

        target_ids = []
        decoder_attn_masks = []
        for text in target_text:
            sentence = ' '.join(text)
            decoding_dict = self.tokenizer(sentence,
                                           max_length=self.max_target_length,
                                           padding="max_length",
                                           truncation=True,
                                           return_tensors="pt")
            target_ids.append(decoding_dict['input_ids'])
            decoder_attn_masks.append(decoding_dict['attention_mask'])
        target_ids = torch.cat(target_ids, dim=0).to(self.device)
        decoder_attn_masks = torch.cat(decoder_attn_masks, dim=0).to(self.device)

        # Shift the targets: the decoder reads tokens [:-1] and is trained to predict tokens [1:].
        decoder_input_ids = target_ids[:, :-1].contiguous()
        decoder_attn_masks = decoder_attn_masks[:, :-1].contiguous()
        decoder_target_ids = target_ids[:, 1:].contiguous()

        outputs = self.encoder_decoder(
            input_ids,
            attention_mask=encoder_attn_masks,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attn_masks,
            use_cache=False)

        token_logits = outputs.logits
        loss = self.loss(token_logits.view(-1, token_logits.size(-1)),
                         decoder_target_ids.view(-1))
        loss = loss.reshape_as(decoder_target_ids)

        length = (decoder_target_ids != self.padding_token_idx).sum(dim=1).float()
        loss = loss.sum(dim=1) / length.float()
        return loss.mean()
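# A minimal sketch (toy tensors, not part of the original class) of the loss
# computation used in forward() above: token-level cross entropy with
# ignore_index on padding, then averaged over each sequence's true length.
import torch
import torch.nn as nn

pad_idx = 0
logits = torch.randn(2, 4, 10)              # (batch, seq_len, vocab)
targets = torch.tensor([[5, 6, 7, pad_idx],
                        [3, pad_idx, pad_idx, pad_idx]])
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_idx, reduction='none')
loss = loss_fct(logits.view(-1, logits.size(-1)), targets.view(-1)).reshape_as(targets)
length = (targets != pad_idx).sum(dim=1).float()  # true lengths: [3, 1]
per_seq = loss.sum(dim=1) / length                # padding positions contribute 0 loss
print(per_seq.mean())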
    mean_train_loss = sum(train_loss_list) / (len(train_loss_list) + 1e-4)
    mean_train_acc = sum(train_acc_list) / (len(train_acc_list) + 1e-4)
    print("epoch: {} train_loss: {:.3f}, train_acc: {:.3f}".format(epoch, mean_train_loss, mean_train_acc))

    if epoch % 5 == 0:
        model.eval()
        valid_all_match = []
        for tasks, plans in tqdm(valid_seen_dataloader):
            try:
                tokenized_text = encoder_tokenizer(tasks, padding=True, truncation=True,
                                                   max_length=100, return_tensors="pt").input_ids
                if args.gpu and torch.cuda.is_available():
                    tokenized_text = tokenized_text.to("cuda")
                output_labels = model.generate(tokenized_text, decoder_start_token_id=1)
                output_array = output_labels.cpu().numpy()[0]
                targets = decoder_tokenizer.tokenize(plans)
                all_match = np.all(output_array[:len(targets[0])] == targets[0])
                valid_all_match.append(1 if all_match else 0)
            except Exception:
                valid_all_match.append(0)
        print("epoch: {} valid_all_match: {:.3f}".format(epoch, np.mean(valid_all_match)))

end_time = datetime.now()
print("Finished Training:", end_time)
print("Time elapsed (in hour): {:.2f}".format((end_time - begin_time).total_seconds() / 3600))
print("Command: ", command) print("Target: ", goals) # Tokenize inputs and labels. inputs = encoder_tokenizer(command, add_special_tokens=True, return_tensors="pt") print("Inputs tokenized: ", inputs) goals_tokenized = decoder_tokenizer(goals, add_special_tokens=add_special_tokens, return_tensors="pt") print("Target tokenized: ", goals_tokenized) print( f"\nTarget: `{decoder_tokenizer.decode(goals_tokenized.input_ids[0], skip_special_tokens=False)}`\n" ) # Move inputs to GPU. for key, item in inputs.items(): if type(item).__name__ == "Tensor": inputs[key] = item.cuda() # Generate output: greedy_output = bert2bert.generate(inputs.input_ids, max_length=(sierra_ds.max_goals_length + 2)) #print(f"Output ({greedy_output.shape}): {greedy_output}") print( f"\nModel prediction: `{decoder_tokenizer.decode(greedy_output[0], skip_special_tokens=False)}`\n" )
class PhonetizerModel:
    # Phoneme vocabulary (ids 0-6 are reserved for special tokens).
    phon_tokenizer = {
        'e': 7, 'i': 8, 'R': 9, 'a': 10, 'o': 11, 't': 12, 's': 13, 'l': 14,
        'k': 15, 'p': 16, 'm': 17, 'n': 18, 'd': 19, 'y': 20, '@': 21, 'f': 22,
        'z': 23, 'b': 24, '§': 25, 'v': 26, '2': 27, '1': 28, 'Z': 29, 'g': 30,
        'u': 31, 'S': 32
    }
    phon_untokenizer = {v: k for k, v in phon_tokenizer.items()}
    # Character vocabulary (ids 0-6 are reserved for special tokens).
    char_tokenizer = {
        'e': 7, 'i': 8, 'a': 9, 'r': 10, 'o': 11, 's': 12, 't': 13, 'n': 14,
        'l': 15, 'é': 16, 'c': 17, 'p': 18, 'u': 19, 'm': 20, 'd': 21, '-': 22,
        'h': 23, 'g': 24, 'b': 25, 'v': 26, 'f': 27, 'k': 28, 'y': 29, 'x': 30,
        'è': 31, 'ï': 32, 'j': 33, 'z': 34, 'w': 35, 'q': 36
    }

    def __init__(self, device='cpu', model=None):
        # Character-level encoder.
        vocabsize = 37
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        # Phoneme-level decoder with cross-attention.
        vocabsize = 33
        max_length = 50
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attention=True,
                                    is_decoder=True)
        decoder = BertLMHeadModel(config=decoder_config)

        # Define the encoder-decoder model.
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model))

    def phonetize(self, word):
        # Normalize accented characters and digraphs before tokenization.
        word = word.replace('à', 'a')
        word = word.replace('û', 'u')
        word = word.replace('ù', 'u')
        word = word.replace('î', 'i')
        word = word.replace('ç', 'ss')
        word = word.replace('ô', 'o')
        word = word.replace('â', 'a')
        word = word.replace('qu', 'k')
        word = word.replace('ê', 'e')
        assert set(word).issubset(set(PhonetizerModel.char_tokenizer.keys()))
        encoded = torch.tensor([0] + [PhonetizerModel.char_tokenizer[p] for p in word] + [2])
        output = self.model.generate(
            encoded.unsqueeze(0).to(self.device),
            max_length=50,
            decoder_start_token_id=0,
            eos_token_id=2,
            pad_token_id=1,
        ).detach().cpu().numpy()[0]
        # Cut at the eos token (id 2) and drop all special tokens (ids <= 6).
        bound = np.where(output == 2)[0][0] if 2 in output else 1000
        phon_pred = ''.join([PhonetizerModel.phon_untokenizer[c] for c in output[:bound] if c > 6])
        return phon_pred

    def check_phonetization_error(self, word, phon):
        prediction = self.phonetize(word)[:5]
        score = pairwise2.align.globalms(list(phon[:5]), list(prediction),
                                         2, -1, -1, -.5,
                                         score_only=True,
                                         gap_char=['-']) / len(phon[:5])
        return score
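# A hypothetical usage sketch of the class above (the checkpoint filename is an
# assumption; with randomly initialized weights the predicted phoneme string is
# meaningless until the model has been trained).
phonetizer = PhonetizerModel(device='cpu')   # or PhonetizerModel(model='phonetizer.pt')
print(phonetizer.phonetize('bonjour'))       # prints the predicted phoneme string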