    def __init__(self, args: dict, vocab: Vocabulary, model: Module,
                 optimizer: optim.Optimizer, output_path: str) -> None:
        super().__init__()

        if vocab.get_vocab_size() != args['vocab_size']:
            raise ValueError(
                f"vocab size {vocab.get_vocab_size()} does not match "
                f"args['vocab_size'] {args['vocab_size']}; "
                'the vocabulary has not been initialized properly.')

        self.args = args
        self.vocab = vocab
        self.model = model.to(get_device_setting())
        self.optimizer = optimizer
        self.output_path = output_path
        # per-step training losses, recorded during train()
        self.train_loss = []

        # cross-entropy loss; ignore_index keeps <pad> positions out of the loss.
        self.loss_fn = CrossEntropyLoss(ignore_index=self.args['pad_idx']).to(
            get_device_setting())

    def forward(self, question: T.Tensor, answer: T.Tensor) -> T.Tensor:
        """
        :param question: [bs x seq_len] source token ids
        :param answer: [bs x seq_len] target token ids
        :return: [bs x seq_len x vocab_size] unnormalized logits
        """
        # embed source and target ids: [bs x seq_len] -> [bs x seq_len x dim]
        q_embed = self.embedding(question.long().to(get_device_setting()))
        a_embed = self.embedding(answer.long().to(get_device_setting()))

        # boolean key-padding masks: True marks <pad> positions that
        # attention should ignore (nn.Transformer convention)
        q_mask = question == self.args['pad_idx']
        a_mask = answer == self.args['pad_idx']
        mem_q_mask = q_mask.clone()
        # causal mask so each target position only attends to itself
        # and earlier positions (teacher forcing)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            answer.size(1))

        # [bs x seq x dim] -> [seq x bs x dim]; nn.Transformer expects
        # sequence-first tensors here (batch_first is not set)
        q_embed = T.einsum('ijk->jik', q_embed)
        a_embed = T.einsum('ijk->jik', a_embed)

        attn = self.transformer(
            src=q_embed,
            tgt=a_embed,
            src_key_padding_mask=q_mask.to(get_device_setting()),
            tgt_key_padding_mask=a_mask.to(get_device_setting()),
            memory_key_padding_mask=mem_q_mask.to(get_device_setting()),
            tgt_mask=tgt_mask.to(get_device_setting()))

        # back to batch-first: [tgt_seq x bs x dim] -> [bs x tgt_seq x dim]
        attn = T.einsum('ijk->jik', attn)
        logits = self.projection(attn)
        print('logits shape:', logits.shape)

        return logits

    def evaluate(self, iteration: int, valid_loader: DataLoader):
        def decode_sequences(question, prediction, answer):
            # greedy decoding: take the argmax token at each position
            pred_ids = prediction.max(dim=-1)[1]

            def decode(batch_ids, drop_eos=True):
                # map id sequences back to whitespace-joined token strings,
                # stripping the special tokens
                decoded = []
                for ids in batch_ids.tolist():
                    seq = ' '.join(self.vocab.idx2token[tok_id]
                                   for tok_id in ids)
                    if drop_eos:
                        seq = seq.replace('<eos>', '')
                    decoded.append(seq.replace('<pad>', '').strip())
                return decoded

            decoded_question = decode(question, drop_eos=False)
            decoded_prediction = decode(pred_ids)
            decoded_answer = decode(answer)

            for q, p, a in zip(decoded_question, decoded_prediction,
                               decoded_answer):
                print('********** decoded result **********')
                print(q + '\n')
                print(p + '\n')
                print(a + '\n')

        print('********** evaluating start **********')

        self.model.eval()

        # no gradients are needed during evaluation
        with T.no_grad():
            for i, (question, answer) in tqdm(enumerate(valid_loader)):
                output = self.model(question, answer)
                decode_sequences(question, output, answer)

                # flatten to [bs * seq_len x vocab] vs. [bs * seq_len]
                output = output.view(-1, output.size(-1))
                answer = answer.view(-1).long()

                loss = self.loss_fn(output, answer.to(get_device_setting()))
                print(f'********** evaluating loss: {loss.item()} **********')

    def train(self, epoch: int, train_loader: DataLoader,
              valid_loader: DataLoader) -> None:

        # set the model to train mode.
        self.model.train()

        for ep_iter in tqdm(range(1, epoch + 1)):
            print(f'********** epoch number: {ep_iter} **********')
            for i, (question, answer) in tqdm(enumerate(train_loader)):
                self.optimizer.zero_grad()
                output = self.model(question, answer)
                print(output.shape)
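                # flatten to [bs * seq_len x vocab] vs. [bs * seq_len]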
                output = output.view(-1, output.size(-1))
                answer = answer.view(-1).long()
                loss = self.loss_fn(output, answer.to(get_device_setting()))
                print(f'********** training loss: {loss.item()} **********')
                loss.backward()
                self.optimizer.step()
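                # track the per-step loss history declared in __init__
                self.train_loss.append(loss.item())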

                # evaluate and checkpoint every 100 steps once past epoch 2;
                # gate on the current epoch (ep_iter), not the total count
                if (i + 1) % 100 == 0 and ep_iter > 2:
                    self.evaluate(i, valid_loader)
                    # evaluate() put the model in eval mode; switch back
                    self.model.train()
                    T.save(self.model.state_dict(),
                           self.output_path + f'model-{ep_iter}-{i}.pt')


def main() -> None:
    # the original fragment began mid-function; this main() wrapper, the
    # parser construction, and the --data_path flag (read as args.data_path
    # below) are a minimal reconstruction of the truncated lines
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--max_len', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)

    args = parser.parse_args()

    data_path = args.data_path
    output_path = args.output_path
    max_len = args.max_len
    bs = args.batch_size

    train, valid, train_y, valid_y, corpus = load_data(data_path)
    vocab = Vocabulary(corpus)
    vocab.build_vocab()

    model_args = get_base_config()
    model_args['max_len'] = max_len
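    # keep the config's vocab size in sync with the freshly built vocabulary;
    # Trainer.__init__ raises if they disagree (assumes the base config
    # exposes a 'vocab_size' key, which Trainer reads as args['vocab_size'])
    model_args['vocab_size'] = vocab.get_vocab_size()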

    train_loader = get_loader(train, train_y, vocab, max_len, bs, True)
    valid_loader = get_loader(valid, valid_y, vocab, max_len, bs, True)

    model = Net(model_args).to(get_device_setting())
    optimizer = optim.Adam(params=model.parameters(), lr=model_args['lr'])
    trainer = Trainer(model_args, vocab, model, optimizer, output_path)

    print('********** trainer object has been initiated **********')
    print(model)
    print('********** trainer object has been initiated **********')

    trainer.train(10, train_loader, valid_loader)
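

if __name__ == '__main__':
    main()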