def run(self):
        """Train the model, validate after every epoch, and checkpoint it.

        The model is created with ``Util.getModel`` and moved to ``DEVICE``.
        For up to ``self.args.EPOCHS`` epochs the trainer runs one training
        pass followed by one validation pass. The lowest validation error
        (err@1) seen so far is remembered together with its epoch, and
        ``Util.save_checkpoint`` is called after every epoch with a flag
        telling it whether that epoch produced the best result.

        When ``self.args.patience`` is greater than zero, the loop stops as
        soon as that many epochs have passed without a new best result.

        All settings are read from ``self.args`` so the method does not rely
        on a module-level ``args`` object being present.
        """
        cfg = self.args
        best_err1 = 100.
        best_epoch = 0

        logger.info('==> creating model "{}"'.format(cfg.model_name))
        model = Util.getModel(**vars(cfg))
        model = model.to(DEVICE)

        # In most cases, setting this flag lets cuDNN's built-in auto-tuner
        # search for the most efficient algorithms for the current
        # configuration, which improves runtime performance.
        cudnn.benchmark = True

        # Define the loss function (criterion) and the optimizer.
        # Label smoothing is used here in place of a plain nn.CrossEntropyLoss.
        criterion = LabelSmoothingLoss(classes=cfg.num_classes, smoothing=0.2)
        optimizer = Util.getOptimizer(model=model, args=cfg)

        trainer = Trainer(dataset=self.dataset, criterion=criterion, optimizer=optimizer, args=cfg, logger=logger)
        logger.info('train: {} test: {}'.format(self.dataset.get_train_length(), self.dataset.get_validation_length()))

        for epoch in range(0, cfg.EPOCHS):
            # Train for one epoch, then evaluate on the validation set.
            model = trainer.train(model=model, epoch=epoch)
            model, val_loss, val_err1 = trainer.test(model=model, epoch=epoch)

            # Remember the best err@1 and save a checkpoint every epoch.
            is_best = val_err1 < best_err1
            if is_best:
                best_err1 = val_err1
                best_epoch = epoch
                logger.info('Best val_err1 {}'.format(best_err1))
            Util.save_checkpoint(model.state_dict(), is_best, cfg.output_models_dir)

            # Early stopping: only active when patience > 0, and only once
            # `patience` epochs have elapsed since the last improvement.
            if not is_best and epoch - best_epoch >= cfg.patience > 0:
                break

        logger.info('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
# Example #2
# 0
from models import ConvModel


# The same value is used as the dataset's `seq_length` and as the single
# positional argument of ConvModel, so it is kept in one place.
SEQ_LENGTH = 1001

# Training data is read from the preprocessed CSV.
dataset = KPIDataset('../data/train_preprocessed.csv', seq_length=SEQ_LENGTH, step_width=1)

model = ConvModel(SEQ_LENGTH)

# Settings forwarded to the Trainer through its `optim_args` parameter.
args = dict(
    lr=0.5e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

trainer = Trainer(
    model,
    dataset,
    batch_size=512,
    epochs=100,
    log_nth=800,
    validation_size=0.2,
    optim_args=args,
    loss_func=CrossEntropyLoss(),
)

# Start training with the configuration assembled above.
trainer.train()