Esempio n. 1
0
class Trainer:
    """
    trainer class

    Fine-tunes a two-label BertForSequenceClassification on the
    ``comment_text`` field of a dataset. It evaluates on a held-out split
    after every epoch and saves the state dict whenever the validation
    F-score improves. It stops early after ``cfg.patience`` epochs without
    improvement.
    """
    def __init__(self, cfg: Namespace, data: Dataset):
        """
        Args:
            cfg:  configuration; reads batch_size, bert_path, learning_rate,
                  epoch, patience and model_path
            data:  train dataset; split 0.8 / 0.2 into train and validation sets
        """
        self.cfg = cfg
        self.train, self.valid = data.split(0.8)
        RATING_FIELD.build_vocab(self.train)

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')  # pylint: disable=no-member
        # cfg.batch_size is treated as a per-GPU size: scale it by the GPU count
        self.batch_size = cfg.batch_size
        if torch.cuda.is_available():
            self.batch_size *= torch.cuda.device_count()

        # a negated length as sort key orders each batch longest-first
        self.trn_itr = BucketIterator(
            self.train,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            train=True,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.vld_itr = BucketIterator(
            self.valid,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            train=False,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        # refresh the progress-bar loss every log_step batches; smaller for
        # small validation sets so the bar still updates
        self.log_step = 1000
        if len(self.vld_itr) < 100:
            self.log_step = 10
        elif len(self.vld_itr) < 1000:
            self.log_step = 100

        bert_path = cfg.bert_path if cfg.bert_path else 'bert-base-cased'
        self.model = BertForSequenceClassification.from_pretrained(
            bert_path, num_labels=2)
        # class 1 (positive) is up-weighted by the negative/positive ratio
        pos_weight = self._positive_weight(
            exam.target for exam in self.train.examples)
        pos_wgt_tensor = torch.tensor([1.0, pos_weight], device=self.device)  # pylint: disable=not-callable
        self.criterion = nn.CrossEntropyLoss(weight=pos_wgt_tensor)
        if torch.cuda.is_available():
            self.model = DataParallelModel(self.model.cuda())
            self.criterion = DataParallelCriterion(self.criterion)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.learning_rate)

    def run(self):
        """
        do train

        Trains for up to cfg.epoch epochs and evaluates after each one. When
        the validation F-score improves, saves the model state dict to
        cfg.model_path. Stops once cfg.patience epochs pass without
        improvement.
        """
        max_f_score = -9e10
        max_epoch = -1
        for epoch in range(self.cfg.epoch):
            train_loss = self._train_epoch(epoch)
            metrics = self._evaluate(epoch)
            max_f_score_str = f' < {max_f_score:.2f}'
            if metrics['f_score'] > max_f_score:
                max_f_score_str = ' is max'
                max_f_score = metrics['f_score']
                max_epoch = epoch
                torch.save(self.model.state_dict(), self.cfg.model_path)
            logging.info('EPOCH[%d]: train loss: %.6f, valid loss: %.6f, acc: %.2f,'
                         ' F: %.2f%s', epoch, train_loss, metrics['loss'],
                         metrics['accuracy'], metrics['f_score'], max_f_score_str)
            if (epoch - max_epoch) >= self.cfg.patience:
                logging.info('early stopping...')
                break
        logging.info('epoch: %d, f-score: %.2f', max_epoch, max_f_score)

    def _train_epoch(self, epoch: int) -> float:
        """
        train single epoch
        Args:
            epoch:  epoch number
        Returns:
            average loss (NaN if the train iterator yields no batches)
        """
        self.model.train()
        progress = tqdm(self.trn_itr,
                        f'EPOCH[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        for step, batch in enumerate(progress, start=1):
            outputs = self.model(batch.comment_text)
            # output of model wrapped with DataParallelModel is a list of outputs from each GPU
            # make input of DataParallelCriterion as a list of tuples
            if isinstance(self.model, DataParallelModel):
                loss = self.criterion([(output, ) for output in outputs],
                                      batch.target)
            else:
                loss = self.criterion(outputs, batch.target)
            losses.append(loss.item())
            if step % self.log_step == 0:
                progress.set_description(
                    f'EPOCH[{epoch}] ({self._mean(losses):.6f})')
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return self._mean(losses)

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """
        evaluate on validation data
        Args:
            epoch:  epoch number
        Returns:
            metrics from _get_metrics plus 'loss' (mean validation loss,
            NaN if the validation iterator yields no batches)
        """
        self.model.eval()
        progress = tqdm(self.vld_itr,
                        f' EVAL[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        preds = []
        golds = []
        with torch.no_grad():
            for step, batch in enumerate(progress, start=1):
                outputs = self.model(batch.comment_text)
                if isinstance(self.model, DataParallelModel):
                    loss = self.criterion([(output, ) for output in outputs],
                                          batch.target)
                    # one logits chunk per GPU; assumes the chunks come back
                    # in batch order so preds stay aligned with batch.target
                    for output in outputs:
                        preds.extend(self._logits_to_preds(output))
                else:
                    loss = self.criterion(outputs, batch.target)
                    preds.extend(self._logits_to_preds(outputs))
                losses.append(loss.item())
                golds.extend([gold.item() for gold in batch.target])
                if step % self.log_step == 0:
                    progress.set_description(
                        f' EVAL[{epoch}] ({self._mean(losses):.6f})')
        metrics = self._get_metrics(preds, golds)
        metrics['loss'] = self._mean(losses)
        return metrics

    @staticmethod
    def _logits_to_preds(logits) -> List[int]:
        """
        convert rows of two-class scores into 0/1 predictions
        Args:
            logits:  iterable of rows; row[0] is the class-0 (negative) score,
                     row[1] the class-1 (positive) score
        Returns:
            1 where the positive score is strictly greater, otherwise 0
        """
        return [1 if row[1] > row[0] else 0 for row in logits]

    @staticmethod
    def _positive_weight(targets) -> float:
        """
        weight for the positive class: #negatives / #positives
        Args:
            targets:  iterable of target values; >= 0.5 is positive,
                      < 0.5 is negative
        Returns:
            the negative/positive ratio, or 1.0 if there are no positives
        """
        num_pos = 0
        num_neg = 0
        for target in targets:
            if target >= 0.5:
                num_pos += 1
            elif target < 0.5:
                num_neg += 1
        if num_pos == 0:
            logging.warning('no positive examples in train data; '
                            'using positive class weight 1.0')
            return 1.0
        return num_neg / num_pos

    @staticmethod
    def _mean(values: List[float]) -> float:
        """
        arithmetic mean; NaN for an empty list instead of ZeroDivisionError
        """
        return sum(values) / len(values) if values else float('nan')

    @classmethod
    def _get_metrics(cls, preds: List[float],
                     golds: List[float]) -> Dict[str, float]:
        """
        get metric values
        Args:
            preds:  predictions (>= 0.5 counts as positive)
            golds:  gold standards (>= 0.5 counts as positive)
        Returns:
            accuracy, precision, recall and f_score as percentages; each is
            0.0 when its denominator is zero (including empty input)
        """
        assert len(preds) == len(golds)
        true_pos = 0
        false_pos = 0
        false_neg = 0
        true_neg = 0
        for pred, gold in zip(preds, golds):
            if pred >= 0.5:
                if gold >= 0.5:
                    true_pos += 1
                else:
                    false_pos += 1
            else:
                if gold >= 0.5:
                    false_neg += 1
                else:
                    true_neg += 1
        total = true_pos + false_pos + false_neg + true_neg
        accuracy = 0.0
        if total > 0:
            accuracy = (true_pos + true_neg) / total
        precision = 0.0
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        recall = 0.0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        f_score = 0.0
        if (precision + recall) > 0.0:
            f_score = 2.0 * precision * recall / (precision + recall)
        return {
            'accuracy': 100.0 * accuracy,
            'precision': 100.0 * precision,
            'recall': 100.0 * recall,
            'f_score': 100.0 * f_score,
        }
Esempio n. 2
0
def _make_train_transform(data, width, height, crop_size):
    """
    Build one training augmentation pipeline.

    Args:
        data: dict from ``processData``; only its 'mean' and 'std' entries are read
        width, height: target size passed to ``myTransforms.Scale``
        crop_size: argument passed to ``myTransforms.RandomCropResize``

    Returns:
        ``myTransforms.Compose`` of normalize -> scale -> random crop/resize
        -> random flip -> to-tensor.
    """
    return myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(width, height),
        myTransforms.RandomCropResize(crop_size),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])


def _make_train_loader(data, transform, args):
    """
    Wrap the training images/annotations in a shuffled DataLoader that drops
    the trailing partial batch.
    """
    return torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)


def main_tr(args, crossVal):
    """
    Train MiniSeg on one cross-validation split, validating after every epoch.

    Each epoch trains on three rescaled copies of the training set and then
    on the base scale. It validates, saves a checkpoint and a per-epoch model
    file, and appends one line to the training log.

    Args:
        args: parsed options (data_dir, classes, data_name, savedir, max_epochs,
            model_name, gpu, ignore_label, width, height, batch_size,
            num_workers, pretrained, resume, lr). ``args.lr`` is overwritten
            when a checkpoint is resumed.
        crossVal: split index passed to ``processData``; reported as crossVal + 1.

    Returns:
        (maxEpoch, maxmIOU): the 1-based epoch with the highest validation mIOU
        (latest one on ties) and that mIOU; (0, 0) if no epoch runs.
    """
    dataLoad = ld.LoadData(args.data_dir, args.classes)
    data = dataLoad.processData(crossVal, args.data_name)

    # load the model
    model = net.MiniSeg(args.classes, aux=True)

    # <savedir>_mod<max_epochs>/<data_name>/<model_name>, creating any missing
    # level (including parents) in one race-free call
    saveDir = osp.join(args.savedir + '_mod' + str(args.max_epochs),
                       args.data_name, args.model_name)
    os.makedirs(saveDir, exist_ok=True)

    if args.gpu and torch.cuda.device_count() > 1:
        model = DataParallelModel(model)
    if args.gpu:
        model = model.cuda()

    total_parameters = sum(np.prod(p.size()) for p in model.parameters())
    print('Total network parameters: ' + str(total_parameters))

    # define optimization criteria; class weights come from the data loader
    weight = torch.from_numpy(data['classWeights'])
    if args.gpu:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.gpu and torch.cuda.device_count() > 1:
        criteria = DataParallelCriterion(criteria)
    if args.gpu:
        criteria = criteria.cuda()

    # compose the data with transforms; the crop sizes are expressed
    # relative to a 1024-pixel-wide reference image
    trainDataset_main = _make_train_transform(
        data, args.width, args.height, int(32. / 1024. * args.width))
    trainDataset_scale1 = _make_train_transform(
        data, int(args.width * 1.5), int(args.height * 1.5),
        int(100. / 1024. * args.width))
    trainDataset_scale2 = _make_train_transform(
        data, int(args.width * 1.25), int(args.height * 1.25),
        int(100. / 1024. * args.width))
    trainDataset_scale3 = _make_train_transform(
        data, int(args.width * 0.75), int(args.height * 0.75),
        int(32. / 1024. * args.width))

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.ToTensor()
    ])

    # since we training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting
    trainLoader = _make_train_loader(data, trainDataset_main, args)
    scaleLoaders = (
        _make_train_loader(data, trainDataset_scale1, args),
        _make_train_loader(data, trainDataset_scale2, args),
        _make_train_loader(data, trainDataset_scale3, args),
    )

    valLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)
    # total iterations per epoch across all four loaders; train() receives
    # this together with cur_iter
    max_batches = len(trainLoader) + sum(len(loader) for loader in scaleLoaders)

    if args.gpu:
        cudnn.benchmark = True

    start_epoch = 0

    if args.pretrained is not None:
        # load on CPU so GPU-saved weights also work on CPU-only hosts;
        # load_state_dict copies them onto the model's device
        state_dict = torch.load(args.pretrained, map_location='cpu')
        # drop every entry whose key contains 'pred'; the remaining weights
        # are loaded non-strictly
        new_dict = OrderedDict(
            (key, value) for key, value in state_dict.items() if 'pred' not in key)
        model.load_state_dict(new_dict, strict=False)
        print('pretrained model loaded')

    checkpoint = None
    if args.resume is not None:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            start_epoch = checkpoint['epoch']
            args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-4)
    if checkpoint is not None and 'optimizer' in checkpoint:
        # restore Adam's running moments saved below so a resumed run continues
        # instead of restarting the optimizer from scratch
        optimizer.load_state_dict(checkpoint['optimizer'])
    checkpoint = None  # release the loaded checkpoint tensors

    maxmIOU = 0
    maxEpoch = 0
    log_file = osp.join(saveDir, 'trainValLog_' + args.model_name + '.txt')
    is_new_log = not osp.isfile(log_file)
    # context manager guarantees the log is closed even if training raises
    with open(log_file, 'a') as logger:
        if is_new_log:
            logger.write("Parameters: %s" % (str(total_parameters)))
            logger.write("\n%s\t%s\t\t%s\t%s\t%s\t%s\tlr" %
                         ('CrossVal', 'Epoch', 'Loss(Tr)', 'Loss(val)',
                          'mIOU (tr)', 'mIOU (val)'))
        logger.flush()

        print(args.model_name + '-CrossVal: ' + str(crossVal + 1))
        for epoch in range(start_epoch, args.max_epochs):
            # train on the three rescaled copies first, then the base scale;
            # cur_iter tracks the running iteration count across loaders
            cur_iter = 0
            for loader in scaleLoaders:
                train(args, loader, model, criteria, optimizer, epoch,
                      max_batches, cur_iter)
                cur_iter += len(loader)
            lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr, lr = \
                train(args, trainLoader, model, criteria, optimizer, epoch,
                      max_batches, cur_iter)

            # evaluate on validation set
            lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = \
                val(args, valLoader, model, criteria)

            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': str(model),
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossTr': lossTr,
                    'lossVal': lossVal,
                    'iouTr': mIOU_tr,
                    'iouVal': mIOU_val,
                    'lr': lr
                },
                osp.join(
                    saveDir, 'checkpoint_' + args.model_name + '_crossVal' +
                    str(crossVal + 1) + '.pth.tar'))

            # save the model also
            model_file_name = osp.join(
                saveDir, 'model_' + args.model_name + '_crossVal' +
                str(crossVal + 1) + '_' + str(epoch + 1) + '.pth')
            torch.save(model.state_dict(), model_file_name)

            logger.write(
                "\n%d\t\t%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                (crossVal + 1, epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
            logger.flush()
            print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\n"
                  % (epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val))

            if mIOU_val >= maxmIOU:
                maxmIOU = mIOU_val
                maxEpoch = epoch + 1
            torch.cuda.empty_cache()
    return maxEpoch, maxmIOU