    def train(self, train_loader: DataLoader,
              valid_loader: DataLoader) -> None:
        self.model.train()
        global_step = 0
        max_grad_norm = 2.0

        for ep in tqdm(range(1, self.args['epoch'] + 1)):
            for i, (input_ids, segment_ids, attn_masks,
                    labels) in tqdm(enumerate(train_loader)):
                global_step += 1
                self.optim.zero_grad()
                # The leading boolean selects the model's forward path; this
                # train-style call returns raw logits for self.loss_fn.
                logits = self.model(False, input_ids.to(get_device_setting()),
                                    segment_ids.to(get_device_setting()),
                                    attn_masks.to(get_device_setting()))

                labels = labels.squeeze(-1).float().to(get_device_setting())

                loss = self.loss_fn(logits, labels)
                loss = loss.mean()

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         max_grad_norm)
                self.optim.step()

                if global_step % 100 == 0:
                    # print(f'************** Training Total Loss : {loss.item()} ******************')
                    self.writer.add_scalar('Train/Total_Loss', loss.item(),
                                           global_step)

                if global_step % 10000 == 0:
                    self.evaluate(valid_loader, 'Valid', global_step)
                    torch.save(self.model.state_dict(),
                               f'./rsc/output/bert-cls-v3-{global_step}.pth')
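Every snippet in this listing moves tensors with get_device_setting(), a project helper that is not shown here. A minimal sketch of what such a helper usually looks like, as an assumption rather than the repo's actual implementation:

import torch

def get_device_setting() -> torch.device:
    # Assumed helper: prefer the GPU when one is visible, otherwise use the CPU.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')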
Example #2
    def evaluate(self, valid_loader: DataLoader, mode: str, gs: int):
        self.model.eval()
        total_count = 0
        recall_result = [0, 0, 0, 0, 0]  # running hit counts for Recall@1/2/3/5/7
        k = [1, 2, 3, 5, 7]

        with torch.no_grad():
            for idx, (input_ids, segment_ids, attn_masks,
                      labels) in tqdm(enumerate(valid_loader)):
                loss, preds, labels = self.model(
                    True, input_ids.to(get_device_setting()),
                    segment_ids.to(get_device_setting()),
                    attn_masks.to(get_device_setting()),
                    labels.to(get_device_setting()))

                bs = input_ids.size(0)

                total_count += bs
                recall = [self.evaluate_recall(preds, labels, k=i) for i in k]

                for i, r_k in enumerate(recall):
                    recall_result[i] += r_k

                if idx % 10 == 0:
                    for i, r_k in zip(k, recall_result):
                        print(f'Recall @ {i} : {r_k / float(total_count)}')

            for i, r_k in zip(k, recall_result):
                self.writer.add_scalar(f'{mode}/Recall@{i}',
                                       (r_k / float(total_count)), gs)

        self.model.train()
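The evaluate() above calls evaluate_recall(preds, labels, k=i), which is not part of this listing. A hedged sketch of one plausible contract, assuming 10 candidates per context with the ground-truth response at index 0 (matching the eval labels built in __getitem__ below):

import torch

def evaluate_recall(preds: torch.Tensor, labels: torch.Tensor, k: int = 1) -> int:
    # Assumed contract: preds holds one score per candidate, 10 candidates per
    # context; labels is accepted to mirror the call site, but this sketch takes
    # index 0 as the ground truth. Returns how many contexts rank the ground
    # truth inside the top-k scores.
    hits = 0
    for row in preds.view(-1, 10):
        topk = torch.topk(row, k).indices
        hits += int(0 in topk.tolist())
    return hits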
Example #3
    def __getitem__(self, idx) -> Any:
        if not self.eval:
            # Truncate the context when it exceeds the maximum context length,
            # leaving room for the special tokens.
            ctx = (self.ctx[idx] if len(self.ctx[idx]) < self.ctx_max_len
                   else self.ctx[idx][:self.ctx_max_len - 2])
            composition = self.tokenizer.encode_plus(
                ctx,
                self.utter[idx],
                add_special_tokens=True,
                pad_to_max_length=True,
                max_length=self.ctx_max_len + self.utter_max_len
            )

            input_ids = torch.LongTensor(composition['input_ids'])
            segment_ids = torch.LongTensor(composition['token_type_ids'])
            attn_masks = torch.LongTensor(composition['attention_mask'])
            # The stored label is inverted (1 becomes 0 and 0 becomes 1),
            # matching the convention of the eval-mode labels below.
            labels = torch.LongTensor([int(not self.label[idx])])

            return (input_ids.to(get_device_setting()),
                    segment_ids.to(get_device_setting()),
                    attn_masks.to(get_device_setting()),
                    labels.to(get_device_setting()))
        else:
            # Eval mode: encode the context against each candidate response,
            # truncating the context the same way as above.
            ctx = (self.ctx[idx] if len(self.ctx[idx]) < self.ctx_max_len
                   else self.ctx[idx][:self.ctx_max_len - 2])
            composition = [
                self.tokenizer.encode_plus(
                    ctx,
                    u,
                    add_special_tokens=True,
                    pad_to_max_length=True,
                    max_length=self.ctx_max_len + self.utter_max_len)
                for u in self.utter[idx]
            ]

            input_ids = torch.LongTensor([c['input_ids'] for c in composition])
            segment_ids = torch.LongTensor([c['token_type_ids'] for c in composition])
            attn_masks = torch.LongTensor([c['attention_mask'] for c in composition])
            # One label per candidate: index 0 gets 0 and the other nine get 1,
            # mirroring the inverted training labels above.
            labels = [0] + [1] * 9

            return (input_ids.to(get_device_setting()),
                    segment_ids.to(get_device_setting()),
                    attn_masks.to(get_device_setting()),
                    torch.LongTensor(labels).to(get_device_setting()))
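In eval mode each item pairs the context with 10 candidate responses, so a single item is already a complete ranking problem. A hypothetical usage sketch (the dataset class name and path below are placeholders, not names taken from the source):

dataset = ResponseSelectionDataset(valid_path, eval=True)   # assumed constructor
input_ids, segment_ids, attn_masks, labels = dataset[0]
print(input_ids.shape)    # expected: torch.Size([10, ctx_max_len + utter_max_len])
print(labels.tolist())    # [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]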
    def __init__(self, model: BertBaseCLS, args: dict) -> None:
        """
        Initialize Trainer Class
        :param model: BertBaseCLS
        :param args: {lr, eps, epoch, log_dir, ...}
        """

        self.model = model
        self.model = nn.DataParallel(self.model)
        self.model = self.model.to(get_device_setting())
        self.args = args
        self.writer = SummaryWriter(log_dir=args['log_dir'])
        self.optim = AdamW(self.model.parameters(),
                           lr=args['lr'],
                           eps=args['eps'])
        self.loss_fn = nn.BCEWithLogitsLoss().to(get_device_setting())
    def train(self, train_loader: DataLoader,
              valid_loader: DataLoader) -> None:
        self.model.train()
        global_step = 0
        max_grad_norm = 2.0

        correct, bs = 0, 0

        for ep in tqdm(range(1, self.args['epoch'] + 1)):
            for i, (input_ids, segment_ids, attn_masks,
                    labels) in tqdm(enumerate(train_loader)):
                global_step += 1
                self.optim.zero_grad()
                loss, preds = self.model(False,
                                         input_ids.to(get_device_setting()),
                                         segment_ids.to(get_device_setting()),
                                         attn_masks.to(get_device_setting()),
                                         labels.to(get_device_setting()))
                # preds is taken to be the number of correct predictions in
                # the batch, so acc is a running training accuracy.
                correct += preds
                bs += input_ids.size(0)
                acc = correct / bs

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         max_grad_norm)
                self.optim.step()

                if global_step % 10 == 0:
                    self.evaluate(valid_loader, 'Valid', global_step)

                if global_step % 100 == 0:
                    print(
                        f'************** Training Total Loss : {loss.item()} ******************'
                    )
                    print(
                        f'************** Training Accuracy  : {acc} ******************'
                    )
                    self.writer.add_scalar('Train/Total_Loss', loss.item(),
                                           global_step)
                    self.writer.add_scalar('Train/Accuracy', acc, global_step)

                if global_step % 1000 == 0:
                    self.evaluate(valid_loader, 'Valid', global_step)
                    torch.save(self.model.state_dict(),
                               f'./rsc/output/bert-nsp-{global_step}.pth')
    def evaluate(self, valid_loader: DataLoader, mode: str, gs: int):
        self.model.eval()
        k = [1, 2, 3, 5, 7]
        total_examples, total_correct = 0, 0

        with torch.no_grad():
            for idx, (input_ids, segment_ids, attn_masks,
                      labels) in tqdm(enumerate(valid_loader)):
                logits = self.model(True, input_ids.to(get_device_setting()),
                                    segment_ids.to(get_device_setting()),
                                    attn_masks.to(get_device_setting()))
                pred = torch.sigmoid(logits)
                pred = pred.cpu().detach().tolist()
                labels = labels.view(len(pred))
                labels = labels.cpu().detach().tolist()

                rank_by_pred = calculate_candidates_ranking(
                    np.array(pred), np.array(labels))

                num_correct, pos_index = logits_recall_at_k(rank_by_pred, k)
                total_correct = np.add(total_correct, num_correct)
                total_examples += rank_by_pred.shape[0]

                if (idx + 1) % 1000 == 0:
                    recall_result = ""
                    for i in range(len(k)):
                        recall_result += "Recall@%s : %.2f%% | " % (
                            k[i],
                            float(total_correct[i]) / float(total_examples) * 100)
                    print(recall_result)

            for i in range(len(k)):
                print(
                    f'Recall@{k[i]} -> {float(total_correct[i]) / float(total_examples)}'
                )
                self.writer.add_scalar(
                    f'{mode}/Recall@{k[i]}',
                    (float(total_correct[i]) / float(total_examples)), gs)

        self.model.train()
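This evaluate() relies on calculate_candidates_ranking and logits_recall_at_k, neither of which appears in the listing. A sketch of contracts consistent with how they are used above, again assuming 10 candidates per context with the ground truth at index 0:

import numpy as np

def calculate_candidates_ranking(pred: np.ndarray, labels: np.ndarray,
                                 num_candidates: int = 10) -> np.ndarray:
    # Assumed contract: reshape the flat candidate scores into one row per
    # context and rank the candidate indices by descending score.
    scores = pred.reshape(-1, num_candidates)
    return np.argsort(-scores, axis=1)

def logits_recall_at_k(rank_by_pred: np.ndarray, k=(1, 2, 3, 5, 7)):
    # Assumed contract: for each cutoff in k, count the contexts whose
    # ground-truth candidate (index 0) lands inside the top cutoff, and also
    # return the ground truth's rank position for every context.
    pos_index = np.argmax(rank_by_pred == 0, axis=1)
    num_correct = np.array([int((pos_index < cutoff).sum()) for cutoff in k])
    return num_correct, pos_index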
Example #7
    def __init__(self, model: nn.Module, args: dict) -> None:
        """
        Initialize Trainer Class
        :param model: BertForNextSentencePrediction
        :param args: {lr, eps, epoch, log_dir, ...}
        """

        self.model = model.to(get_device_setting())
        self.args = args
        self.writer = SummaryWriter(log_dir=args['log_dir'])
        self.optim = AdamW(self.model.parameters(),
                           lr=args['lr'],
                           eps=args['eps'])
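None of the examples include a training entry point. A minimal, hypothetical driver using the argument keys named in the __init__ docstrings (the dataset objects and the model below are placeholders, not names taken from the source):

from torch.utils.data import DataLoader

args = {'lr': 3e-5, 'eps': 1e-8, 'epoch': 3, 'log_dir': './rsc/logs'}    # assumed values
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)    # train split, eval=False
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)    # valid split, eval=True
trainer = Trainer(model, args)
trainer.train(train_loader, valid_loader)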