def __init__(self, args: dict, vocab: Vocabulary, model: Module,
             optimizer: optim.Optimizer, output_path: str) -> None:
    super().__init__()
    if vocab.get_vocab_size() != args['vocab_size']:
        raise ValueError('vocabulary size does not match args["vocab_size"]; '
                         'the vocabulary has not been built correctly.')
    self.args = args
    self.vocab = vocab
    self.model = model.to(get_device_setting())
    self.optimizer = optimizer
    self.output_path = output_path
    self.train_loss = []
    # Cross-entropy loss that skips padding positions in the target.
    self.loss_fn = CrossEntropyLoss(ignore_index=self.args['pad_idx']).to(
        get_device_setting())
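# Sketch of how ignore_index behaves (values below are illustrative, not part
# of this repo): padding targets contribute nothing to the averaged loss.
#
#   import torch as T
#   from torch.nn import CrossEntropyLoss
#
#   loss_fn = CrossEntropyLoss(ignore_index=0)   # assume pad_idx == 0
#   logits = T.randn(4, 10)                      # [positions x vocab_size]
#   target = T.tensor([3, 7, 0, 0])              # last two are <pad>
#   loss = loss_fn(logits, target)               # averages over the 2 real tokens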
def forward(self, question: T.Tensor, answer: T.Tensor) -> T.Tensor:
    """
    :param question: [bs x seq_len]
    :param answer: [bs x seq_len]
    :return: logits over the vocabulary, [bs x seq_len x vocab_size]
    """
    device = get_device_setting()
    q_embed = self.embedding(question.long().to(device))
    a_embed = self.embedding(answer.long().to(device))

    # Padding masks: True marks positions the attention should ignore.
    q_mask = question == self.args['pad_idx']
    a_mask = answer == self.args['pad_idx']
    mem_q_mask = q_mask.clone()

    # Causal mask so each target position only attends to earlier positions.
    tgt_mask = self.transformer.generate_square_subsequent_mask(answer.size(1))

    # nn.Transformer expects [seq x bs x dim]: [bs x seq x dim] -> [seq x bs x dim]
    q_embed = q_embed.transpose(0, 1)
    a_embed = a_embed.transpose(0, 1)

    attn = self.transformer(
        src=q_embed,
        tgt=a_embed,
        src_key_padding_mask=q_mask.to(device),
        tgt_key_padding_mask=a_mask.to(device),
        memory_key_padding_mask=mem_q_mask.to(device),
        tgt_mask=tgt_mask.to(device))

    # Back to [bs x seq x dim] before projecting onto the vocabulary.
    attn = attn.transpose(0, 1)
    logits = self.projection(attn)
    return logits
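# Usage sketch for forward() (sizes are illustrative; Net, get_base_config and
# get_device_setting come from this repo, everything else is made up):
#
#   model_args = get_base_config()
#   model = Net(model_args).to(get_device_setting())
#   question = T.randint(0, model_args['vocab_size'], (8, 32))  # [bs x seq_len]
#   answer = T.randint(0, model_args['vocab_size'], (8, 32))    # [bs x seq_len]
#   logits = model(question, answer)  # -> [8, 32, vocab_size]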
def evaluate(self, iteration: int, valid_loader: DataLoader):
    def decode_sequences(question, prediction, answer):
        question_ids = question.tolist()
        # Greedy decoding: take the highest-scoring token at each position.
        pred_ids = prediction.argmax(dim=-1).tolist()
        answer_ids = answer.tolist()

        decoded_question = []
        decoded_prediction = []
        decoded_answer = []

        for questions in question_ids:
            seq = ' '.join(self.vocab.idx2token[question_id]
                           for question_id in questions)
            decoded_question.append(seq.replace('<pad>', '').strip())

        for preds in pred_ids:
            seq = ' '.join(self.vocab.idx2token[pred_id] for pred_id in preds)
            decoded_prediction.append(
                seq.replace('<eos>', '').replace('<pad>', '').strip())

        for answers in answer_ids:
            seq = ' '.join(self.vocab.idx2token[answer_id]
                           for answer_id in answers)
            decoded_answer.append(
                seq.replace('<eos>', '').replace('<pad>', '').strip())

        for q, p, a in zip(decoded_question, decoded_prediction, decoded_answer):
            print('********** decoded result **********')
            print(q + '\n')
            print(p + '\n')
            print(a + '\n')

    print('********** evaluating start **********')
    self.model.eval()
    # No gradients are needed during evaluation.
    with T.no_grad():
        for i, (question, answer) in tqdm(enumerate(valid_loader)):
            output = self.model(question, answer)
            decode_sequences(question, output, answer)
            output = output.view(-1, output.size(-1))
            answer = answer.view(-1).long()
            loss = self.loss_fn(output, answer.to(get_device_setting()))
            print(f'********** evaluating loss: {loss.item()} **********')
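# Both evaluate() and train() flatten the batch and sequence dimensions before
# computing the loss (shapes below are illustrative):
#
#   output: [bs, seq_len, vocab_size] -> view(-1, vocab_size) -> [bs * seq_len, vocab_size]
#   answer: [bs, seq_len]             -> view(-1)             -> [bs * seq_len]
#
# CrossEntropyLoss then scores each flattened position independently and
# skips every position whose target equals pad_idx.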
def train(self, epoch: int, train_loader: DataLoader,
          valid_loader: DataLoader) -> None:
    # Set the model to train mode.
    self.model.train()
    for ep_iter in tqdm(range(1, epoch + 1)):
        print(f'********** epoch number: {ep_iter} **********')
        for i, (question, answer) in tqdm(enumerate(train_loader)):
            self.optimizer.zero_grad()
            output = self.model(question, answer)
            output = output.view(-1, output.size(-1))
            answer = answer.view(-1).long()
            loss = self.loss_fn(output, answer.to(get_device_setting()))
            print(f'********** training loss: {loss.item()} **********')
            loss.backward()
            self.optimizer.step()

            # Evaluate and checkpoint every 100 steps from the third epoch on.
            if (i + 1) % 100 == 0 and ep_iter > 2:
                self.evaluate(i, valid_loader)
                # evaluate() puts the model in eval mode; switch back.
                self.model.train()
                T.save(self.model.state_dict(),
                       self.output_path + f'model-{ep_iter}-{i}.pt')
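# Sketch: reloading one of the checkpoints saved above (the file name is
# illustrative, following the model-{epoch}-{step}.pt pattern):
#
#   model = Net(get_base_config())
#   state = T.load(output_path + 'model-3-99.pt',
#                  map_location=get_device_setting())
#   model.load_state_dict(state)
#   model.eval()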
parser.add_argument('--output_path', type=str, required=True)
parser.add_argument('--max_len', type=int, required=True)
parser.add_argument('--batch_size', type=int, required=True)

args = parser.parse_args()
data_path = args.data_path
output_path = args.output_path
max_len = args.max_len
bs = args.batch_size

train, valid, train_y, valid_y, corpus = load_data(data_path)
vocab = Vocabulary(corpus)
vocab.build_vocab()

model_args = get_base_config()
model_args['max_len'] = max_len

train_loader = get_loader(train, train_y, vocab, max_len, bs, True)
valid_loader = get_loader(valid, valid_y, vocab, max_len, bs, True)

model = Net(model_args).to(get_device_setting())
optimizer = optim.Adam(params=model.parameters(), lr=model_args['lr'])
trainer = Trainer(model_args, vocab, model, optimizer, output_path)

print('********** trainer object has been initiated **********')
print(model)
print('********** trainer object has been initiated **********')

trainer.train(10, train_loader, valid_loader)
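# Example invocation (the script name and paths are illustrative; --data_path
# is assumed to be added to the parser earlier in this script):
#
#   python train.py --data_path ./data/corpus.csv --output_path ./checkpoints/ \
#       --max_len 32 --batch_size 64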