コード例 #1
0
ファイル: train.py プロジェクト: Fabriceli/auto-car
class Train(object):
    """Run training and validation epochs for a DeeplabV3Plus model.

    The constructor wires together the Apolloscapes train/val data loaders,
    the model, an SGD optimizer, the loss, the evaluator, a 'poly' learning
    rate scheduler and a tensorboard writer, all configured from ``args``.
    """

    # Dataset locations shared by the training and the validation split.
    IMAGE_DIR = '/home/aistudio/data/data1919/Image_Data'
    LABEL_DIR = '/home/aistudio/data/data1919/Gray_Label'

    def __init__(self, args):
        """Build every component needed for training.

        :param args: configuration namespace; the attributes read here are
            save_path, workers, crop_size, batch_size, backbone, out_stride,
            batch_norm, num_classes, momentum, nesterov, weight_decay, lr,
            epochs, cuda and gpus.
        """
        self.args = args
        # Initialise the tensorboard summary writer.
        self.summary = TensorboardSummary(directory=args.save_path)
        self.writer = self.summary.create_summary()

        # Initialise the data loaders.
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_dataset = Apolloscapes('train_dataset.csv', self.IMAGE_DIR, self.LABEL_DIR,
                                          args.crop_size, type='train')
        self.dataloader = DataLoader(self.train_dataset, batch_size=args.batch_size,
                                     shuffle=True, drop_last=True, **kwargs)

        self.val_dataset = Apolloscapes('val_dataset.csv', self.IMAGE_DIR, self.LABEL_DIR,
                                        args.crop_size, type='val')
        self.val_loader = DataLoader(self.val_dataset, batch_size=args.batch_size,
                                     shuffle=False, drop_last=False, **kwargs)

        # Initialise the model.
        self.model = DeeplabV3Plus(backbone=args.backbone,
                                   output_stride=args.out_stride,
                                   batch_norm=args.batch_norm,
                                   num_classes=args.num_classes,
                                   pretrain=True)

        # Initialise the optimizer.
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         momentum=args.momentum,
                                         nesterov=args.nesterov,
                                         weight_decay=args.weight_decay,
                                         lr=args.lr)

        # Loss function.
        self.loss = CELoss(num_class=args.num_classes, cuda=args.cuda)

        # Evaluator used during validation.
        self.evaluator = Evaluator(args.num_classes)

        # Learning-rate schedule.
        self.scheduler = LR_Scheduler('poly', args.lr, args.epochs, len(self.dataloader))

        # Move the model to the first configured GPU and replicate it.
        if args.cuda:
            self.model = self.model.cuda(device=args.gpus[0])
            self.model = torch.nn.DataParallel(self.model, device_ids=args.gpus)

    @staticmethod
    def _visualization_interval(num_batches):
        """Return how many iterations lie between two image visualizations.

        Aims at roughly ten visualizations per epoch and never returns less
        than 1, so the result is always a valid modulus even for loaders with
        fewer than ten batches.
        """
        return max(1, num_batches // 10)

    def _to_device(self, *tensors):
        """Move tensors to the first configured GPU when CUDA is enabled.

        The model lives on ``args.gpus[0]``, so the batch is sent to that same
        device instead of whatever the current default CUDA device is.
        """
        if not self.args.cuda:
            return tensors
        device = self.args.gpus[0]
        return tuple(t.cuda(device=device) for t in tensors)

    def train(self, epoch):
        """Train for one epoch, log to tensorboard and save a checkpoint.

        :param epoch: zero-based index of the epoch, used for the global step,
            the scheduler and the checkpoint file name.
        """
        running_loss = 0.0
        num_images = 0
        self.model.train()
        data = tqdm(self.dataloader)
        length = len(self.dataloader)
        vis_interval = self._visualization_interval(length)
        for i, sample in enumerate(data):
            image, label = self._to_device(sample['image'], sample['label'])
            self.scheduler(self.optimizer, i, epoch, 0.0)
            self.optimizer.zero_grad()
            output = self.model(image)
            batch_loss = self.loss(output, label)
            batch_loss.backward()
            self.optimizer.step()

            loss_value = batch_loss.item()
            running_loss += loss_value
            num_images += image.shape[0]
            global_step = i + length * epoch

            data.set_description('Train loss: %.3f' % (running_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value, global_step)

            # Visualize inference results about ten times per epoch.
            if i % vis_interval == 0:
                self.summary.visualize_image(self.writer, image, label, output, global_step)

        # Per-epoch summary: written once, after all batches were processed.
        self.writer.add_scalar('train/total_loss_epoch', running_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % running_loss)

        # NOTE(review): with DataParallel the saved keys carry a 'module.'
        # prefix; kept as-is so existing checkpoint consumers keep working.
        checkpoint_dir = os.path.join(os.getcwd(), self.args.save_path)
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({'state_dict': self.model.state_dict()},
                   os.path.join(checkpoint_dir, "laneNet{}.pth.tar".format(epoch)))

    def val(self, epoch):
        """Evaluate on the validation loader and log the metrics.

        :param epoch: index of the epoch the metrics are logged under.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = self._to_device(sample['image'], sample['label'])
            # Neither the forward pass nor the loss needs a graph here.
            with torch.no_grad():
                output = self.model(image)
                loss = self.loss(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction = index of the largest score along axis 1.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
コード例 #2
0
class Trainer(object):
    """Settings-driven training / validation loop with logging and checkpoints.

    All components (data loaders, model, optimizer, optional LR scheduler,
    loss, evaluator and logger) are created from the ``settings`` dict.
    """

    def __init__(self, settings: dict, settings_to_log: list):
        """
        Build all training components from the settings.

        :param settings: configuration dict; keys read by this class include
            'threshold', 'start_epoch', 'dataset', 'batch_size', 'workers',
            'cuda', 'fp16', 'epochs', 'ignore_index', 'loss_reduction',
            'optimizer', 'optimizer_params', 'lr_scheduler', 'target_size',
            'metrics', 'loggers' and 'resume'.
        :param settings_to_log: passed through to the logger unchanged.
        """
        self.settings = settings
        self.settings_to_log = settings_to_log

        self.threshold = self.settings['threshold']
        self.start_epoch = self.settings['start_epoch']
        self.dataset = self.settings['dataset']
        self.batch_size = self.settings['batch_size']
        self.workers = self.settings['workers']
        self.cuda = self.settings['cuda']
        self.fp16 = self.settings['fp16']
        self.epochs = self.settings['epochs']
        self.ignore_index = self.settings['ignore_index']
        self.loss_reduction = self.settings['loss_reduction']

        # -------------------- Define Data loader ------------------------------
        self.loaders, self.nclass, self.plotter = make_data_loader(settings)
        self.train_loader, self.val_loader, self.test_loader = [self.loaders[key] for key in ['train', 'val', 'test']]

        # -------------------- Define model ------------------------------------
        self.model = get_model(self.settings)

        # -------------------- Define optimizer and its options ----------------
        self.optimizer = define_optimizer(self.model, self.settings['optimizer'], self.settings['optimizer_params'])
        # Always define the attribute so later code can test it safely.
        self.lr_scheduler = None
        if self.settings['lr_scheduler']:
            self.lr_scheduler = LRScheduler(self.settings['lr_scheduler'], self.optimizer, self.batch_size)

        # -------------------- Define loss -------------------------------------
        input_size = (self.batch_size, self.nclass, *self.settings['target_size'])
        self.criterion = CustomLoss(input_size=input_size, ignore_index=self.ignore_index, reduction=self.loss_reduction)

        self.evaluator = Evaluator(metrics=self.settings['metrics'], num_class=self.nclass, threshold=self.settings['threshold'])

        self.logger = MainLogger(loggers=self.settings['loggers'], settings=settings, settings_to_log=settings_to_log)

        # The default must be set BEFORE resuming; otherwise the value restored
        # from the checkpoint would be overwritten right away.
        self.metric_to_watch = 0.0
        if self.settings['resume']:
            self.resume_checkpoint(self.settings['resume'])

    def _unwrapped_model(self):
        """Return the underlying model, stripping a wrapper that exposes ``.module``.

        Wrappers such as ``torch.nn.DataParallel`` keep the real network in
        ``.module``; a plain model is returned unchanged.
        """
        return getattr(self.model, 'module', self.model)

    def activation(self, output):
        """
        Turn raw model outputs into probabilities.

        :param output: raw model output tensor
        :return: sigmoid of the output for a single class, otherwise softmax over dim 1
        """
        if self.nclass == 1:
            return torch.sigmoid(output)
        return torch.softmax(output, dim=1)

    def prepare_inputs(self, *inputs):
        """
        Move the inputs to the GPU and/or cast them to half precision.

        :param inputs: tensors to prepare
        :return: the prepared tensors, in the same order
        """
        if self.settings['cuda']:
            inputs = [i.cuda() for i in inputs]
        # NOTE(review): the cast is applied to every input, labels included -
        # confirm the loss expects half-precision targets.
        if self.settings['fp16']:
            inputs = [i.half() for i in inputs]
        return inputs

    def training(self, epoch: int):
        """
        Training loop for a certain epoch
        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.train()
        tbar = tqdm(self.train_loader, desc='train', file=sys.stdout)
        train_loss = 0.0
        output = {}
        for i, sample in enumerate(tbar):
            img, target = self.prepare_inputs(sample['image'], sample['label'])
            img, target, perm_target, gamma = random_joint_mix(img, target, self.settings['CutMix'], self.settings['MixUp'], p=self.settings['MixP'])

            self.optimizer.zero_grad()
            output['pred'], output['pred8'], output['pred16'] = self.model(img)

            if self.settings['MixUp'] or self.settings['CutMix']:
                loss = mix_criterion(self.criterion.train_loss, output, tgt_a=target, tgt_b=perm_target, gamma=gamma)
            else:
                loss = self.criterion.train_loss(**output, target=target)
            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()

            if self.settings['lr_scheduler']:
                self.lr_scheduler(i, epoch, self.metric_to_watch)

            # Metrics need no gradients; avoid extending the autograd graph.
            with torch.no_grad():
                out = self.activation(output['pred'])
            self.evaluator.add_batch(out, target)
            tbar.set_description('Train loss: %.4f, Epoch: %d' % (train_loss / float(i + 1), epoch))

            self.logger.log_metric(metric_tuple=('TRAIN_LOSS', (train_loss / float(i + 1))))
        _ = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)

    def validation(self, epoch: int):
        """
        Validation loop for a certain epoch
        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.eval()
        if self.settings['validation_only']:
            loader = self.loaders[self.settings['validation_only']]
        else:
            loader = self.val_loader
        tbar = tqdm(loader, desc='valid', file=sys.stdout)
        test_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for i, sample in enumerate(tbar):
                img, target = self.prepare_inputs(sample['image'], sample['label'])

                output = self.model(img)

                loss = self.criterion.val_loss(pred=output, target=target)
                test_loss += loss.item()
                num_batches = i + 1

                output = self.activation(output)
                self.evaluator.add_batch(output, target)
                tbar.set_description('Validation loss: %.3f, Epoch: %d' % (test_loss / num_batches, epoch))

                if self.settings['log_artifacts']:
                    self.log_artifacts(epoch=epoch, sample=sample, output=output)

                self.logger.log_metric(metric_tuple=('VAL_LOSS', test_loss / num_batches))
        metrics_dict = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)
        # max(..., 1) keeps an empty loader from failing on the mean loss.
        metrics_dict['val_loss'] = test_loss / max(num_batches, 1)
        self.metric_to_watch = metrics_dict[self.settings['metric_to_watch']].mean()
        if not self.settings['validation_only']:
            self.save_checkpoint(epoch=epoch, metrics_dict=metrics_dict)

    def save_checkpoint(self, epoch, metrics_dict):
        """
        Log the epoch metrics and hand the training state to the logger.

        :param epoch: id of the epoch that just finished
        :param metrics_dict: metrics computed for that epoch
        :return:
        """
        state = {
            # Stored as the epoch to start from when resuming.
            'epoch': epoch + 1,
            'state_dict': self._unwrapped_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'metrics': metrics_dict,
            'scheduler': self.lr_scheduler.state_dict() if self.settings['lr_scheduler'] else None,
        }
        self.logger.log_metrics(self.settings['metrics'], metrics_dict, epoch=epoch)
        self.logger.log_checkpoint(state, key_metric=self.metric_to_watch, filename=self.settings['check_suffix'])

    def log_artifacts(self, sample, output, epoch):
        """
        Plot and log predictions for the watched inputs of one batch.

        Runs every 'log_dilate' epochs and on the last epoch.

        :param sample: batch dict with 'image', 'label' and 'id' entries
        :param output: activated model output for the batch
        :param epoch: epoch id
        :return:
        """
        last_epoch = epoch == (self.settings['epochs'] - 1)
        if not (epoch % self.settings['log_dilate'] == 0 or last_epoch):
            return
        # Keep the denormalized image local instead of writing it back into
        # the caller's sample dict.
        image = denormalize_image(sample['image'], **self.settings['normalize_params'])
        image, target, output = tensors_to_numpy(image, sample['label'], output)
        for ind, value in enumerate(sample['id']):
            if value in self.settings['inputs_to_watch']:
                fig = self.plotter(image[ind], output[ind], target[ind],
                                   alpha=0.4, threshold=self.threshold, show=self.settings['show_results'])
                self.logger.log_artifact(artifact=fig, epoch=epoch, name=value.replace('_leftImg8bit', ''))
                plt.close()

    def resume_checkpoint(self, resume):
        """
        Restore model (and, unless fine-tuning, optimizer/scheduler) state.

        :param resume: path to the checkpoint file
        :raises RuntimeError: if no file exists at ``resume``
        :return:
        """
        if not os.path.isfile(resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(resume))
        # Without CUDA, remap GPU-saved tensors to the CPU so loading succeeds.
        # NOTE(review): torch.load unpickles data - only load trusted files.
        checkpoint = torch.load(resume, map_location=None if self.cuda else 'cpu')
        self.start_epoch = checkpoint['epoch']
        self._unwrapped_model().load_state_dict(checkpoint['state_dict'], strict=True)
        if not self.settings['fine_tuning']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler_state = checkpoint.get('scheduler')
            scheduler = getattr(self, 'lr_scheduler', None)
            # Only restore when this run actually has a scheduler.
            if scheduler_state and scheduler is not None:
                scheduler.load_state_dict(scheduler_state)
        # The state assembled in save_checkpoint has no 'best_pred' entry
        # (presumably the logger adds it - verify); fall back to the current
        # value instead of failing when it is absent.
        self.metric_to_watch = checkpoint.get('best_pred', getattr(self, 'metric_to_watch', 0.0))
        print("=> loaded checkpoint '{}' (epoch: {}, best_metric: {:.4f})"
              .format(resume, checkpoint['epoch'], self.metric_to_watch))

    def close(self):
        """
        Log the normalized confusion matrix and close the logger.

        :return:
        """
        fig = plot_confusion_matrix(self.evaluator.confusion_matrix, normalize=True, title=None, cmap=plt.cm.Blues, show=False)
        self.logger.log_artifact(fig, epoch=-1, name='confusion_matrix.png')
        self.logger.close()