def main():
    """Train SFNet and validate it after every epoch.

    Parses the command line into the module-level ``args``, builds the
    training and validation data loaders, wraps the model in
    ``DataParallel`` on CUDA, and then alternates one training epoch with
    one validation pass.  A checkpoint is handed to ``save_checkpoint``
    after every epoch, and the module-level ``best_record`` is updated
    whenever the validation mean IoU improves.
    """
    global args, best_record
    args = parser.parse_args()

    if args.augment:
        transform_train = joint_transforms.Compose([
            joint_transforms.RandomCrop(256),
            joint_transforms.Normalize(),
            joint_transforms.ToTensor(),
        ])
    else:
        transform_train = None

    dataset_train = Data.WData(args.data_root, transform_train)
    dataloader_train = data.DataLoader(dataset_train,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=16)

    # NOTE(review): validation reuses the training transform, including the
    # random crop, so validation metrics are not deterministic when
    # --augment is set. Confirm whether a separate eval transform is wanted.
    dataset_val = Data.WData(args.val_root, transform_train)
    dataloader_val = data.DataLoader(dataset_val,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=16)

    model = SFNet(input_channels=37, dilations=[2, 4, 8], num_class=2)

    # multi gpu
    model = torch.nn.DataParallel(model)

    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    model = model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda()
    # Two parameter groups: the second one trains at 10x the base rate.
    optimizer = torch.optim.SGD(
        [
            {'params': get_1x_lr_params(model)},
            {
                'params': get_10x_lr_params(model),
                'lr': 10 * args.learning_rate,
            },
        ],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    separator = '------------------------------------------------------------------------------------------------------'

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(dataloader_train, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc, mean_iou, val_loss = validate(dataloader_val, model, criterion,
                                           epoch)

        is_best = mean_iou > best_record['miou']
        if is_best:
            best_record['epoch'] = epoch
            best_record['val_loss'] = val_loss.avg
            best_record['acc'] = acc
            best_record['miou'] = mean_iou

        # NOTE(review): this stores the whole DataParallel-wrapped module
        # rather than a state_dict; kept as-is because checkpoint consumers
        # are outside this view.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'val_loss': val_loss.avg,
                'accuracy': acc,
                'miou': mean_iou,
                'model': model,
            }, is_best)

        print(separator)
        # '%.5f' (not '%5f') so val_loss is printed with the same
        # five-decimal precision as the other metrics.
        print('[epoch: %d], [val_loss: %.5f], [acc: %.5f], [miou: %.5f]' %
              (epoch, val_loss.avg, acc, mean_iou))
        print(
            'best record: [epoch: {epoch}], [val_loss: {val_loss:.5f}], [acc: {acc:.5f}], [miou: {miou:.5f}]'
            .format(**best_record))
        print(separator)