Example #1
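# NOTE (editor's sketch): this example predates `transformers` and relies on
# the pytorch_pretrained_bert API plus several project-local helpers. The
# imports below are an assumption reconstructed from the calls in the code;
# Dataset, optionsLoader, prepare_data, save_check_point, KLDivLoss, LOG,
# DataParallelModel and DataParallelCriterion are project-specific names.
import gc
import time

import torch
from torch.utils import data
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM, BertAdam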
def train(config):
    net = BertForMaskedLM.from_pretrained(config.model)
    lossFunc = KLDivLoss(config)

    if torch.cuda.is_available():
        net = net.cuda()
        lossFunc = lossFunc.cuda()

        if config.dataParallel:
            net = DataParallelModel(net)
            lossFunc = DataParallelCriterion(lossFunc)

    options = optionsLoader(LOG, config.optionFrames, disp=False)
    Tokenizer = BertTokenizer.from_pretrained(config.model)
    prepareFunc = prepare_data

    trainSet = Dataset('train', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'train')
    validSet = Dataset('valid', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'valid')

    print(len(trainSet))

    Q = []  # sliding window of recent training losses
    best_vloss = 1e99  # best validation loss seen so far
    counter = 0  # batches since the last validation improvement
    lRate = config.lRate

    prob_src = config.prob_src
    prob_tgt = config.prob_tgt

    num_train_optimization_steps = (
        len(trainSet) * options['training']['stopConditions']['max_epoch'])

    # Standard BERT fine-tuning practice: exclude biases and LayerNorm
    # parameters from weight decay.
    param_optimizer = list(net.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=lRate,
                         e=1e-9,
                         t_total=num_train_optimization_steps,
                         warmup=0.0)

    for epoch_idx in range(options['training']['stopConditions']['max_epoch']):
        total_seen = 0
        total_similar = 0
        total_unseen = 0
        total_source = 0

        trainSet.setConfig(config, prob_src, prob_tgt)
        # The Dataset yields pre-batched tensors, so the DataLoader keeps
        # batch_size=1 and the dummy dimension is stripped below.
        trainLoader = data.DataLoader(dataset=trainSet,
                                      batch_size=1,
                                      shuffle=True,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        validSet.setConfig(config, 0.0, prob_tgt)  # prob_src pinned to 0.0 for validation
        validLoader = data.DataLoader(dataset=validSet,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        for batch_idx, batch_data in enumerate(trainLoader):
            # Periodic garbage collection to keep host memory in check.
            if (batch_idx + 1) % 10000 == 0:
                gc.collect()
            start_time = time.time()

            net.train()

            inputs, positions, token_types, labels, masks, batch_seen, \
                batch_similar, batch_unseen, batch_source = batch_data

            # Strip the dummy batch dimension added by the DataLoader and
            # move the tensors to the GPU.
            inputs = inputs[0].cuda()
            positions = positions[0].cuda()
            token_types = token_types[0].cuda()
            labels = labels[0].cuda()
            masks = masks[0].cuda()
            total_seen += batch_seen
            total_similar += batch_similar
            total_unseen += batch_unseen
            total_source += batch_source

            # Count non-padding target tokens (BERT's [PAD] id is 0).
            n_token = int((labels != 0).sum())

            predicts = net(inputs, positions, token_types, masks)
            loss = lossFunc(predicts, labels, n_token).sum()

            # Running average of the loss over the last 200 batches.
            Q.append(float(loss))
            if len(Q) > 200:
                Q.pop(0)
            loss_avg = sum(Q) / len(Q)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            LOG.log(
                'Epoch %2d, Batch %6d, Loss %9.6f, Average Loss %9.6f, Time %9.6f'
                % (epoch_idx + 1, batch_idx + 1, loss, loss_avg,
                   time.time() - start_time))

            # Checkpoints
            idx = epoch_idx * len(trainSet) + batch_idx + 1
            if (idx >= options['training']['checkingPoints']['checkMin']) and (
                    idx % options['training']['checkingPoints']['checkFreq']
                    == 0):
                if config.do_eval:
                    vloss = 0
                    total_tokens = 0
                    for bid, batch_data in enumerate(validLoader):
                        inputs, positions, token_types, labels, masks, batch_seen, batch_similar, batch_unseen, batch_source = batch_data

                        inputs = inputs[0].cuda()
                        positions = positions[0].cuda()
                        token_types = token_types[0].cuda()
                        labels = labels[0].cuda()
                        masks = masks[0].cuda()

                        # Count non-padding tokens for loss normalization.
                        n_token = int((labels != config.PAD).sum())

                        with torch.no_grad():
                            net.eval()
                            predicts = net(inputs, positions, token_types,
                                           masks)
                            # The summed (unnormalized) loss; it is divided
                            # by the total token count below.
                            vloss += float(lossFunc(predicts, labels).sum())

                        total_tokens += n_token

                    vloss /= total_tokens
                    is_best = vloss < best_vloss
                    best_vloss = min(vloss, best_vloss)
                    LOG.log(
                        'CheckPoint: Validation Loss %11.8f, Best Loss %11.8f'
                        % (vloss, best_vloss))

                    if is_best:
                        LOG.log('Best Model Updated')
                        save_check_point(
                            {
                                'epoch': epoch_idx + 1,
                                'batch': batch_idx + 1,
                                'options': options,
                                'config': config,
                                'state_dict': net.state_dict(),
                                'best_vloss': best_vloss
                            },
                            is_best,
                            path=config.save_path,
                            fileName='latest.pth.tar')
                        counter = 0
                    else:
                        counter += options['training']['checkingPoints'][
                            'checkFreq']
                        if counter >= options['training']['stopConditions'][
                                'rateReduce_bound']:
                            counter = 0
                            for param_group in optimizer.param_groups:
                                old_lr = param_group['lr']
                                param_group['lr'] *= 0.55
                                new_lr = param_group['lr']
                            LOG.log(
                                'Reduce Learning Rate from %11.8f to %11.8f' %
                                (old_lr, new_lr))
                        LOG.log('Current Counter = %d' % (counter))

                else:
                    save_check_point(
                        {
                            'epoch': epoch_idx + 1,
                            'batch': batch_idx + 1,
                            'options': options,
                            'config': config,
                            'state_dict': net.state_dict(),
                            'best_vloss': 1e99
                        },
                        False,
                        path=config.save_path,
                        fileName=f'checkpoint_Epoch{epoch_idx + 1}_Batch{batch_idx + 1}.pth.tar')
                    LOG.log('CheckPoint Saved!')

        if options['training']['checkingPoints']['everyEpoch']:
            save_check_point(
                {
                    'epoch': epoch_idx + 1,
                    'batch': batch_idx + 1,
                    'options': options,
                    'config': config,
                    'state_dict': net.state_dict(),
                    'best_vloss': 1e99
                },
                False,
                path=config.save_path,
                fileName=f'checkpoint_Epoch{epoch_idx + 1}.pth.tar')

        LOG.log('Epoch Finished.')
        LOG.log(
            'Total Seen: %d, Total Unseen: %d, Total Similar: %d, Total Source: %d.'
            % (total_seen, total_unseen, total_similar, total_source))
        gc.collect()
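
# A hedged usage sketch (not from the source): train() reads everything off a
# flat config object, so a SimpleNamespace carrying the fields referenced
# above is enough to drive it. Field names come from the code; the values are
# illustrative guesses, and LOG plus the option files must exist in scope.
if __name__ == '__main__':
    from types import SimpleNamespace
    config = SimpleNamespace(
        model='bert-base-uncased',   # passed to from_pretrained()
        batch_size=32,
        lRate=1e-4,
        prob_src=0.1,                # masking probabilities (assumed semantics)
        prob_tgt=0.9,
        dataParallel=False,
        dataLoader_workers=2,
        do_eval=True,
        PAD=0,                       # BERT's [PAD] token id
        optionFrames={},             # consumed by optionsLoader()
        save_path='./checkpoints')
    train(config)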
Example #2
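# NOTE (editor's sketch): reconstructed imports, again an assumption; Searcher,
# mapping_tokenize, detokenize, translate, length_norm, bounded_word_reward and
# bounded_adaptive_reward are project-local (the last three implement the
# reranking methods dispatched below).
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM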
def test(config):
    # map_location keeps the load device-agnostic; tensors are moved to the
    # GPU below when one is available.
    Best_Model = torch.load(config.test_model, map_location='cpu')
    Tokenizer = BertTokenizer.from_pretrained(config.model)

    f_in = open(config.inputFile, 'r', encoding='utf-8')

    net = BertForMaskedLM.from_pretrained(config.model)

    # If the checkpoint was saved from a model NOT wrapped in DataParallel,
    # load the state dict before wrapping:
    #net.load_state_dict(Best_Model['state_dict'])
    #net.eval()

    if torch.cuda.is_available():
        net = net.cuda(0)
        if config.dataParallel:
            net = DataParallelModel(net)

    # If the checkpoint was saved from a DataParallel-wrapped model, its keys
    # carry a 'module.' prefix, so load after wrapping:
    net.load_state_dict(Best_Model['state_dict'])
    net.eval()

    mySearcher = Searcher(net, config)

    f_top1 = open('summary' + config.suffix + '.txt', 'w', encoding='utf-8')
    f_topK = open('summary' + config.suffix + '.txt.' +
                  str(config.answer_size),
                  'w',
                  encoding='utf-8')

    ed = '\n------------------------\n'

    for idx, line in enumerate(f_in):
        source_ = line.strip().split()
        source = Tokenizer.tokenize(line.strip())
        mapping = mapping_tokenize(source_, source)

        source = Tokenizer.convert_tokens_to_ids(source)

        print(idx)
        print(detokenize(translate(source, Tokenizer), mapping), end=ed)

        l_pred = mySearcher.length_Predict(source)
        Answers = mySearcher.search(source)
        baseline = sum(Answers[0][0])

        if config.reranking_method == 'none':
            Answers = sorted(Answers, key=lambda x: sum(x[0]))
        elif config.reranking_method == 'length_norm':
            Answers = sorted(Answers, key=lambda x: length_norm(x[0]))
        elif config.reranking_method == 'bounded_word_reward':
            Answers = sorted(
                Answers,
                key=lambda x: bounded_word_reward(x[0], config.reward, l_pred))
        elif config.reranking_method == 'bounded_adaptive_reward':
            Answers = sorted(
                Answers,
                key=lambda x: bounded_adaptive_reward(x[0], x[2], l_pred))

        texts = [
            detokenize(translate(Answers[k][1], Tokenizer), mapping)
            for k in range(len(Answers))
        ]

        if baseline != sum(Answers[0][0]):
            print('Reranked!')

        print(texts[0], end=ed)
        print(texts[0], file=f_top1)
        print(len(texts), file=f_topK)
        for i in range(len(texts)):
            print(Answers[i][0], file=f_topK)
            print(texts[i], file=f_topK)

    f_in.close()
    f_top1.close()
    f_topK.close()
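
# Editor's sketch (not in the source): the two loading paths above exist
# because DataParallelModel, like torch.nn.DataParallel, stores the wrapped
# network under a `module.` attribute, so checkpoints saved from a wrapped
# model prefix every state-dict key with 'module.'. A hypothetical helper
# that loads a checkpoint regardless of how it was saved:
def load_state_dict_flexibly(net, state_dict):
    try:
        net.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: toggle the 'module.' prefix and retry.
        if all(key.startswith('module.') for key in state_dict):
            fixed = {key[len('module.'):]: value
                     for key, value in state_dict.items()}
        else:
            fixed = {'module.' + key: value
                     for key, value in state_dict.items()}
        net.load_state_dict(fixed)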
Example #3
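# NOTE (editor's sketch): reconstructed imports for this example, assuming the
# legacy torchtext API (BucketIterator/Dataset) and pytorch_pretrained_bert;
# RATING_FIELD, DataParallelModel and DataParallelCriterion are project-local.
import logging
from argparse import Namespace
from typing import Dict, List

import torch
from torch import nn, optim
from torchtext.data import BucketIterator, Dataset
from tqdm import tqdm
from pytorch_pretrained_bert import BertForSequenceClassification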
class Trainer:
    """
    trainer class
    """
    def __init__(self, cfg: Namespace, data: Dataset):
        """
        Args:
            cfg:  configuration
            data:  train dataset
        """
        self.cfg = cfg
        self.train, self.valid = data.split(0.8)
        RATING_FIELD.build_vocab(self.train)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')  # pylint: disable=no-member
        self.batch_size = cfg.batch_size
        if torch.cuda.is_available():
            self.batch_size *= torch.cuda.device_count()

        self.trn_itr = BucketIterator(
            self.train,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            train=True,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.vld_itr = BucketIterator(
            self.valid,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            train=False,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.log_step = 1000
        if len(self.vld_itr) < 100:
            self.log_step = 10
        elif len(self.vld_itr) < 1000:
            self.log_step = 100

        bert_path = cfg.bert_path if cfg.bert_path else 'bert-base-cased'
        self.model = BertForSequenceClassification.from_pretrained(
            bert_path, num_labels=2)
        # Weight the positive class by the negative/positive example ratio
        # to counter class imbalance.
        pos_weight = (
            len([exam for exam in self.train.examples if exam.target < 0.5]) /
            len([exam for exam in self.train.examples if exam.target >= 0.5]))
        pos_wgt_tensor = torch.tensor([1.0, pos_weight], device=self.device)  # pylint: disable=not-callable
        self.criterion = nn.CrossEntropyLoss(weight=pos_wgt_tensor)
        if torch.cuda.is_available():
            self.model = DataParallelModel(self.model.cuda())
            self.criterion = DataParallelCriterion(self.criterion)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.learning_rate)

    def run(self):
        """
        do train
        """
        max_f_score = -9e10
        max_epoch = -1
        for epoch in range(self.cfg.epoch):
            train_loss = self._train_epoch(epoch)
            metrics = self._evaluate(epoch)
            max_f_score_str = f' < {max_f_score:.2f}'
            if metrics['f_score'] > max_f_score:
                max_f_score_str = ' is max'
                max_f_score = metrics['f_score']
                max_epoch = epoch
                torch.save(self.model.state_dict(), self.cfg.model_path)
            logging.info('EPOCH[%d]: train loss: %.6f, valid loss: %.6f, acc: %.2f,' \
                         ' F: %.2f%s', epoch, train_loss, metrics['loss'],
                         metrics['accuracy'], metrics['f_score'], max_f_score_str)
            if (epoch - max_epoch) >= self.cfg.patience:
                logging.info('early stopping...')
                break
        logging.info('epoch: %d, f-score: %.2f', max_epoch, max_f_score)

    def _train_epoch(self, epoch: int) -> float:
        """
        train single epoch
        Args:
            epoch:  epoch number
        Returns:
            average loss
        """
        self.model.train()
        progress = tqdm(self.trn_itr,
                        f'EPOCH[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        for step, batch in enumerate(progress, start=1):
            outputs = self.model(batch.comment_text)
            # output of model wrapped with DataParallelModel is a list of outputs from each GPU
            # make input of DataParallelCriterion as a list of tuples
            if isinstance(self.model, DataParallelModel):
                loss = self.criterion([(output, ) for output in outputs],
                                      batch.target)
            else:
                loss = self.criterion(outputs, batch.target)
            losses.append(loss.item())
            if step % self.log_step == 0:
                avg_loss = sum(losses) / len(losses)
                progress.set_description(f'EPOCH[{epoch}] ({avg_loss:.6f})')
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return sum(losses) / len(losses)

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """
        evaluate on validation data
        Args:
            epoch:  epoch number
        Returns:
            metrics
        """
        self.model.eval()
        progress = tqdm(self.vld_itr,
                        f' EVAL[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        preds = []
        golds = []
        for step, batch in enumerate(progress, start=1):
            with torch.no_grad():
                outputs = self.model(batch.comment_text)
                if isinstance(self.model, DataParallelModel):
                    loss = self.criterion([(output, ) for output in outputs],
                                          batch.target)
                    for output in outputs:
                        # argmax over the two logits: predict the positive
                        # class (1) when its logit is the larger one
                        preds.extend([(1 if o[0] < o[1] else 0)
                                      for o in output])
                else:
                    loss = self.criterion(outputs, batch.target)
                    preds.extend([(1 if output[0] < output[1] else 0)
                                  for output in outputs])
                losses.append(loss.item())
                golds.extend([gold.item() for gold in batch.target])
                if step % self.log_step == 0:
                    avg_loss = sum(losses) / len(losses)
                    progress.set_description(
                        f' EVAL[{epoch}] ({avg_loss:.6f})')
        metrics = self._get_metrics(preds, golds)
        metrics['loss'] = sum(losses) / len(losses)
        return metrics

    @classmethod
    def _get_metrics(cls, preds: List[float],
                     golds: List[float]) -> Dict[str, float]:
        """
        get metric values
        Args:
            preds:  predictions
            golds:  gold standards
        Returns:
            metric
        """
        assert len(preds) == len(golds)
        true_pos = 0
        false_pos = 0
        false_neg = 0
        true_neg = 0
        for pred, gold in zip(preds, golds):
            if pred >= 0.5:
                if gold >= 0.5:
                    true_pos += 1
                else:
                    false_pos += 1
            else:
                if gold >= 0.5:
                    false_neg += 1
                else:
                    true_neg += 1
        accuracy = (true_pos + true_neg) / (true_pos + false_pos + false_neg +
                                            true_neg)
        precision = 0.0
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        recall = 0.0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        f_score = 0.0
        if (precision + recall) > 0.0:
            f_score = 2.0 * precision * recall / (precision + recall)
        return {
            'accuracy': 100.0 * accuracy,
            'precision': 100.0 * precision,
            'recall': 100.0 * recall,
            'f_score': 100.0 * f_score,
        }
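
# A quick worked example of the metric arithmetic above (illustrative, not
# from the source): preds = [1, 1, 0, 0] against golds = [1, 0, 0, 1] gives
# TP=1, FP=1, FN=1, TN=1, so accuracy = 2/4, precision = recall = 1/2, and
# F = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5, all reported on a 0-100 scale:
#
#     >>> Trainer._get_metrics([1, 1, 0, 0], [1, 0, 0, 1])
#     {'accuracy': 50.0, 'precision': 50.0, 'recall': 50.0, 'f_score': 50.0}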