class Train(object):
    """Training / validation driver for a DeeplabV3Plus segmentation model.

    The constructor wires together the tensorboard writer, the train/val
    data loaders, the model, an SGD optimizer, the loss, the evaluator and
    the learning-rate scheduler, all configured from ``args``.
    """

    def __init__(self, args):
        """Build every training component from the parsed arguments.

        :param args: namespace providing ``save_path``, ``workers``,
            ``crop_size``, ``batch_size``, ``backbone``, ``out_stride``,
            ``batch_norm``, ``num_classes``, ``momentum``, ``nesterov``,
            ``weight_decay``, ``lr``, ``epochs``, ``cuda`` and ``gpus``.
        """
        self.args = args

        # Initialise the tensorboard summary writer.
        self.summary = TensorboardSummary(directory=args.save_path)
        self.writer = self.summary.create_summary()

        # Initialise the data loaders.
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_dataset = Apolloscapes('train_dataset.csv',
                                          '/home/aistudio/data/data1919/Image_Data',
                                          '/home/aistudio/data/data1919/Gray_Label',
                                          args.crop_size, type='train')
        self.dataloader = DataLoader(self.train_dataset, batch_size=args.batch_size,
                                     shuffle=True, drop_last=True, **kwargs)
        self.val_dataset = Apolloscapes('val_dataset.csv',
                                        '/home/aistudio/data/data1919/Image_Data',
                                        '/home/aistudio/data/data1919/Gray_Label',
                                        args.crop_size, type='val')
        self.val_loader = DataLoader(self.val_dataset, batch_size=args.batch_size,
                                     shuffle=False, drop_last=False, **kwargs)

        # Initialise the model.
        self.model = DeeplabV3Plus(backbone=args.backbone,
                                   output_stride=args.out_stride,
                                   batch_norm=args.batch_norm,
                                   num_classes=args.num_classes,
                                   pretrain=True)

        # Initialise the optimizer.
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         momentum=args.momentum,
                                         nesterov=args.nesterov,
                                         weight_decay=args.weight_decay,
                                         lr=args.lr)

        # Define the loss function.
        self.loss = CELoss(num_class=args.num_classes, cuda=args.cuda)

        # Define the evaluator.
        self.evaluator = Evaluator(args.num_classes)

        # Define the learning-rate schedule.
        self.scheduler = LR_Scheduler('poly', args.lr, args.epochs, len(self.dataloader))

        # Move the model to the GPU(s) when CUDA is requested.
        if args.cuda:
            self.model = self.model.cuda(device=args.gpus[0])
            self.model = torch.nn.DataParallel(self.model, device_ids=args.gpus)

    @staticmethod
    def _visualization_interval(length):
        """Return how many iterations separate two image visualizations.

        Aims at roughly ten visualizations per epoch, but never returns less
        than 1 so that ``i % interval`` cannot divide by zero when the loader
        has fewer than ten batches.
        """
        return max(length // 10, 1)

    def _to_device(self, *tensors):
        """Move ``tensors`` to the primary GPU when CUDA is enabled.

        The model is placed on ``args.gpus[0]`` in ``__init__``; the inputs
        are sent to that same device instead of the process-default one.
        Returns the tensors unchanged (as a tuple) when CUDA is disabled.
        """
        if not self.args.cuda:
            return tensors
        device = self.args.gpus[0]
        return tuple(t.cuda(device=device) for t in tensors)

    def train(self, epoch):
        """Run one training epoch, log to tensorboard and save a checkpoint.

        :param epoch: index of the epoch being trained
        """
        loss = 0.0
        num_images = 0
        self.model.train()
        data = tqdm(self.dataloader)
        length = len(self.dataloader)
        vis_interval = self._visualization_interval(length)

        for i, sample in enumerate(data):
            image, label = self._to_device(sample['image'], sample['label'])

            self.scheduler(self.optimizer, i, epoch, 0.0)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss_function = self.loss(output, label)
            loss_function.backward()
            self.optimizer.step()

            batch_loss = loss_function.item()
            loss += batch_loss
            num_images += image.shape[0]

            data.set_description('Train loss: %.3f' % (loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', batch_loss, i + length * epoch)

            # Show about 10 inference results each epoch.
            if i % vis_interval == 0:
                global_step = i + length * epoch
                self.summary.visualize_image(self.writer, image, label, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', loss, epoch)
        # num_images is accumulated per batch so the report also works when
        # the loader yields no batch at all.
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % loss)

        torch.save({'state_dict': self.model.state_dict()},
                   os.path.join(os.getcwd(), self.args.save_path,
                                "laneNet{}.pth.tar".format(epoch)))

    def val(self, epoch):
        """Evaluate the model on the validation loader and log the metrics.

        :param epoch: index of the epoch being validated
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0

        for i, sample in enumerate(tbar):
            image, target = self._to_device(sample['image'], sample['label'])
            with torch.no_grad():
                output = self.model(image)
            loss = self.loss(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Class prediction is the arg-max over axis 1 of the model output.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)

        print('Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
class Trainer(object):
    """Settings-driven training / validation loop.

    All components (data loaders, model, optimizer, LR scheduler, loss,
    evaluator and logger) are built from the ``settings`` dictionary.
    """

    def __init__(self, settings: dict, settings_to_log: list):
        """Build every component from ``settings``.

        :param settings: configuration dictionary read throughout the class
        :param settings_to_log: forwarded to ``MainLogger``
        """
        self.settings = settings
        self.settings_to_log = settings_to_log

        self.threshold = self.settings['threshold']
        self.start_epoch = self.settings['start_epoch']
        self.dataset = self.settings['dataset']
        self.batch_size = self.settings['batch_size']
        self.workers = self.settings['workers']
        self.cuda = self.settings['cuda']
        self.fp16 = self.settings['fp16']
        self.epochs = self.settings['epochs']
        self.ignore_index = self.settings['ignore_index']
        self.loss_reduction = self.settings['loss_reduction']

        # -------------------- Define Data loader ------------------------------
        self.loaders, self.nclass, self.plotter = make_data_loader(settings)
        self.train_loader, self.val_loader, self.test_loader = [self.loaders[key] for key in ['train', 'val', 'test']]

        # -------------------- Define model ------------------------------------
        self.model = get_model(self.settings)

        # -------------------- Define optimizer and its options ----------------
        self.optimizer = define_optimizer(self.model, self.settings['optimizer'],
                                          self.settings['optimizer_params'])
        if self.settings['lr_scheduler']:
            self.lr_scheduler = LRScheduler(self.settings['lr_scheduler'], self.optimizer, self.batch_size)

        # -------------------- Define loss -------------------------------------
        input_size = (self.batch_size, self.nclass, *self.settings['target_size'])
        self.criterion = CustomLoss(input_size=input_size, ignore_index=self.ignore_index,
                                    reduction=self.loss_reduction)

        self.evaluator = Evaluator(metrics=self.settings['metrics'], num_class=self.nclass,
                                   threshold=self.settings['threshold'])

        self.logger = MainLogger(loggers=self.settings['loggers'], settings=settings,
                                 settings_to_log=settings_to_log)

        # Set the default BEFORE resuming: resume_checkpoint() stores the
        # checkpoint's metric in this attribute and it must not be reset.
        self.metric_to_watch = 0.0
        if self.settings['resume']:
            self.resume_checkpoint(self.settings['resume'])

    def activation(self, output):
        """Apply sigmoid for a single class, otherwise softmax over dim 1.

        :param output: raw model output tensor
        :return: activated tensor
        """
        if self.nclass == 1:
            output = torch.sigmoid(output)
        else:
            output = torch.softmax(output, dim=1)
        return output

    def prepare_inputs(self, *inputs):
        """Move inputs to CUDA and/or cast them to half precision.

        Which conversions run is controlled by ``settings['cuda']`` and
        ``settings['fp16']``.

        :param inputs: tensors to convert
        :return: the converted inputs, in the same order
        """
        if self.settings['cuda']:
            inputs = [i.cuda() for i in inputs]
        if self.settings['fp16']:
            inputs = [i.half() for i in inputs]
        return inputs

    def training(self, epoch: int):
        """
        Training loop for a certain epoch

        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.train()
        tbar = tqdm(self.train_loader, desc='train', file=sys.stdout)
        train_loss = 0.0
        # In training mode the model returns three outputs, stored by name so
        # they can be passed to the loss as keyword arguments.
        output = {}
        for i, sample in enumerate(tbar):
            img, target = self.prepare_inputs(sample['image'], sample['label'])
            img, target, perm_target, gamma = random_joint_mix(img, target,
                                                               self.settings['CutMix'],
                                                               self.settings['MixUp'],
                                                               p=self.settings['MixP'])
            self.optimizer.zero_grad()
            output['pred'], output['pred8'], output['pred16'] = self.model(img)

            if self.settings['MixUp'] or self.settings['CutMix']:
                loss = mix_criterion(self.criterion.train_loss, output,
                                     tgt_a=target, tgt_b=perm_target, gamma=gamma)
            else:
                loss = self.criterion.train_loss(**output, target=target)

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            if self.settings['lr_scheduler']:
                self.lr_scheduler(i, epoch, self.metric_to_watch)

            out = self.activation(output['pred'])
            self.evaluator.add_batch(out, target)
            tbar.set_description('Train loss: %.4f, Epoch: %d' % (train_loss / float(i + 1), epoch))

        # NOTE(review): relies on the loader yielding at least one batch
        # (``i`` is the last loop index).
        self.logger.log_metric(metric_tuple=('TRAIN_LOSS', (train_loss / float(i + 1))))
        _ = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)

    def validation(self, epoch: int):
        """
        Validation loop for a certain epoch

        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.eval()

        # 'validation_only' doubles as the key of the loader to evaluate.
        if self.settings['validation_only']:
            loader = self.loaders[self.settings['validation_only']]
        else:
            loader = self.val_loader

        tbar = tqdm(loader, desc='valid', file=sys.stdout)
        test_loss = 0.0
        with torch.no_grad():
            for i, sample in enumerate(tbar):
                img, target = self.prepare_inputs(sample['image'], sample['label'])
                output = self.model(img)
                loss = self.criterion.val_loss(pred=output, target=target)
                test_loss += loss.item()
                output = self.activation(output)
                self.evaluator.add_batch(output, target)
                tbar.set_description('Validation loss: %.3f, Epoch: %d' % (test_loss / (i + 1), epoch))
                if self.settings['log_artifacts']:
                    self.log_artifacts(epoch=epoch, sample=sample, output=output)

        # NOTE(review): relies on the loader yielding at least one batch
        # (``i`` is the last loop index).
        self.logger.log_metric(metric_tuple=('VAL_LOSS', test_loss / (i + 1)))
        metrics_dict = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)
        metrics_dict['val_loss'] = test_loss / (i + 1)
        self.metric_to_watch = metrics_dict[self.settings['metric_to_watch']].mean()

        if not self.settings['validation_only']:
            self.save_checkpoint(epoch=epoch, metrics_dict=metrics_dict)

    def save_checkpoint(self, epoch, metrics_dict):
        """Assemble the checkpoint state and hand it to the logger.

        :param epoch: epoch that has just finished (stored as ``epoch + 1``)
        :param metrics_dict: metrics computed during validation
        """
        state = {
            'epoch': epoch + 1,
            # With CUDA the weights are read through ``model.module``.
            'state_dict': self.model.module.state_dict() if self.cuda else self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'metrics': metrics_dict,
            'scheduler': self.lr_scheduler.state_dict() if self.settings['lr_scheduler'] else None,
        }
        self.logger.log_metrics(self.settings['metrics'], metrics_dict, epoch=epoch)
        self.logger.log_checkpoint(state, key_metric=self.metric_to_watch,
                                   filename=self.settings['check_suffix'])

    def log_artifacts(self, sample, output, epoch):
        """Plot and log predictions for the watched inputs of a batch.

        Runs only every ``settings['log_dilate']`` epochs and on the last
        epoch. Note that ``sample['image']`` is replaced in place by its
        denormalized version.

        :param sample: batch dictionary with 'image', 'label' and 'id'
        :param output: activated model output for the batch
        :param epoch: current epoch id
        """
        last_epoch = epoch == (self.settings['epochs'] - 1)
        if epoch % self.settings['log_dilate'] == 0 or last_epoch:
            sample['image'] = denormalize_image(sample['image'], **self.settings['normalize_params'])
            image, target, output = tensors_to_numpy(sample['image'], sample['label'], output)
            for ind, value in enumerate(sample['id']):
                if value in self.settings['inputs_to_watch']:
                    fig = self.plotter(image[ind], output[ind], target[ind], alpha=0.4,
                                       threshold=self.threshold,
                                       show=self.settings['show_results'])
                    self.logger.log_artifact(artifact=fig, epoch=epoch,
                                             name=value.replace('_leftImg8bit', ''))
                    plt.close()

    def resume_checkpoint(self, resume):
        """Restore model (and optionally optimizer/scheduler) state from disk.

        :param resume: path of the checkpoint file
        :raises RuntimeError: if ``resume`` is not an existing file
        """
        if not os.path.isfile(resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(resume))

        # On a CPU-only run, remap storages so a checkpoint written on a GPU
        # machine can still be loaded.
        checkpoint = torch.load(resume, map_location=None if self.cuda else 'cpu')
        self.start_epoch = checkpoint['epoch']

        if self.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'], strict=True)
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=True)

        if not self.settings['fine_tuning']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # self.lr_scheduler only exists when a scheduler is configured, so
            # scheduler state is restored only in that case.
            if checkpoint.get('scheduler') and self.settings['lr_scheduler']:
                self.lr_scheduler.load_state_dict(checkpoint['scheduler'])

        # The state dict built in save_checkpoint() has no 'best_pred' entry,
        # so fall back to the fresh-run default when the key is absent.
        self.metric_to_watch = checkpoint.get('best_pred', 0.0)
        print("=> loaded checkpoint '{}' (epoch: {}, best_metric: {:.4f})"
              .format(resume, checkpoint['epoch'], self.metric_to_watch))

    def close(self):
        """Log the normalized confusion matrix and close the logger."""
        fig = plot_confusion_matrix(self.evaluator.confusion_matrix, normalize=True,
                                    title=None, cmap=plt.cm.Blues, show=False)
        self.logger.log_artifact(fig, epoch=-1, name='confusion_matrix.png')
        self.logger.close()