Exemplo n.º 1
0
def train(gpu, ngpus_per_node, argss):
    """Run iteration-based training of the segmentation model on one worker.

    The function builds the model and SGD optimizer, optionally wraps them
    for (distributed) data-parallel execution, restores weights/checkpoints,
    builds the train/val loaders and then iterates from ``args.start_iter``
    to ``args.max_iter``.  At the end of every pass over ``train_loader``
    (and at the final iteration) it logs epoch statistics, optionally runs
    validation and writes a checkpoint from the main process.

    Args:
        gpu: CUDA device index for this worker; used for
            ``torch.cuda.set_device`` and to place the loss modules.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split batch sizes / loader workers per process.
        argss: Configuration namespace.  It is stored in the module-level
            ``args`` global and mutated here (``rank``, ``save_path``,
            ``index_split``, ``batch_size``, ``batch_size_val``,
            ``workers`` and, on resume, ``start_iter``).

    Raises:
        ValueError: If ``args.arch`` is not a supported architecture.
    """
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the per-node rank into a global rank across all processes.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # makedirs(exist_ok=True) is safe when several worker processes reach
        # this point at the same time and when parent directories are missing.
        os.makedirs(args.save_path, exist_ok=True)

    if args.arch == 'dct':
        model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
        modules_ori = [model.cp, model.sp, model.head]
        modules_new = []
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError("unsupported architecture: {!r}".format(args.arch))
    # Parameter groups at index >= index_split are trained with 10x the LR.
    args.index_split = len(modules_ori)

    params_list = [
        dict(params=module.parameters(), lr=args.base_lr)
        for module in modules_ori
    ]
    params_list += [
        dict(params=module.parameters(), lr=args.base_lr * 10)
        for module in modules_new
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(gpu)
        # The configured sizes are per node; split them across its processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(
            args.teacher_model_path,
            map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(
            args.teacher_model_path))

    if args.use_ohem:
        criterion = OhemCELoss(thresh=0.7,
                               ignore_index=args.ignore_label).cuda(gpu)
    else:
        criterion = nn.CrossEntropyLoss(
            ignore_index=args.ignore_label).cuda(gpu)

    # NOTE: the teacher model and kd_criterion are prepared here, but the
    # training loop below only optimizes `criterion` on the student output.
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(gpu)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_iter = checkpoint['iteration']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (iteration {})".format(
                    args.resume, checkpoint['iteration']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # RGB mean & std, scaled from [0, 1] by value_scale.
    value_scale = 255
    rgb_mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    rgb_std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=rgb_mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=rgb_mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    train_data = dataset.SemData(split='train',
                                 img_type='rgb',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=rgb_mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=rgb_mean, std=rgb_std)
        ])
        val_data = dataset.SemData(split='val',
                                   img_type='rgb',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    # Training loop bookkeeping
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    # switch to train mode (the teacher, if any, stays in eval mode)
    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    end = time.time()
    max_iter = args.max_iter

    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            images, target = next(data_iter)
            if target.size(0) != args.batch_size:
                raise StopIteration
        except StopIteration:
            # The loader is exhausted: start a new epoch.
            epoch += 1
            if args.distributed:
                train_sampler.set_epoch(epoch)
                if main_process():
                    logger.info('train_sampler.set_epoch({})'.format(epoch))
            data_iter = iter(train_loader)
            images, target = next(data_iter)
            # need to update the AverageMeter for new epoch
            main_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        main_out = model(images)
        main_loss = criterion(main_out, target)

        if not args.multiprocessing_distributed:
            main_loss = torch.mean(main_loss)
        loss = main_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = target.size(0)
        if args.multiprocessing_distributed:
            # Average the losses over all processes, weighted by batch size
            # (not considering ignore pixels).
            main_loss, loss = main_loss.detach() * n, loss.detach() * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss)
            dist.all_reduce(loss)
            dist.all_reduce(count)
            n = count.item()
            main_loss, loss = main_loss / n, loss / n

        prediction = main_out.detach().max(1)[1]
        intersection, union, target_area = intersectionAndUnionGPU(
            prediction, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection)
            dist.all_reduce(union)
            dist.all_reduce(target_area)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        target_area = target_area.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target_area)

        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Using Poly strategy to change the learning rate
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        iter_log = current_iter + 1
        if iter_log % args.print_freq == 0 and main_process():
            logger.info('Iter [{}/{}] '
                        'LR: {lr:.3e}, '
                        'ETA: {remain_time}, '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f}), '
                        'Batch: {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                        'MainLoss: {main_loss_meter.val:.4f}, '
                        'Loss: {loss_meter.val:.4f}, '
                        'Accuracy: {accuracy:.4f}.'.format(
                            iter_log,
                            args.max_iter,
                            lr=current_lr,
                            remain_time=remain_time,
                            data_time=data_time,
                            batch_time=batch_time,
                            main_loss_meter=main_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              iter_log)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              iter_log)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (target_area + 1e-10)),
                              iter_log)
            writer.add_scalar('allAcc_train_batch', accuracy, iter_log)

        # End of an epoch, or the very last iteration: summarize, validate
        # and checkpoint.
        if iter_log % len(train_loader) == 0 or iter_log == max_iter:
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            accuracy_class = intersection_meter.sum / (target_meter.sum +
                                                       1e-10)
            mIoU_train = np.mean(iou_class)
            mAcc_train = np.mean(accuracy_class)
            allAcc_train = sum(
                intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
            loss_train = main_loss_meter.avg
            if main_process():
                logger.info(
                    'Train result at iteration [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
                    .format(iter_log, max_iter, mIoU_train, mAcc_train,
                            allAcc_train))
                writer.add_scalar('loss_train', loss_train, iter_log)
                writer.add_scalar('mIoU_train', mIoU_train, iter_log)
                writer.add_scalar('mAcc_train', mAcc_train, iter_log)
                writer.add_scalar('allAcc_train', allAcc_train, iter_log)

            is_best = False
            # Stays NaN when validation is disabled, so the checkpoint log
            # below never hits an unbound name.
            mIoU_val = float('nan')
            if args.evaluate:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                    val_loader, model, criterion)
                model.train()  # the mode change from eval() to train()
                if main_process():
                    writer.add_scalar('loss_val', loss_val, iter_log)
                    writer.add_scalar('mIoU_val', mIoU_val, iter_log)
                    writer.add_scalar('mAcc_val', mAcc_val, iter_log)
                    writer.add_scalar('allAcc_val', allAcc_val, iter_log)

                    if best_mIoU_val < mIoU_val:
                        is_best = True
                        best_mIoU_val = mIoU_val
                        logger.info('==>The best val mIoU: %.3f' %
                                    (best_mIoU_val))

            if main_process():
                save_checkpoint(
                    {
                        'iteration': iter_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_mIoU_val': best_mIoU_val
                    }, is_best, args.save_path)
                logger.info(
                    'Saving checkpoint to:{}/iter_{}.pth or last.pth with mIoU:{:.3f}'
                    .format(args.save_path, iter_log, mIoU_val))
                if is_best:
                    logger.info(
                        'Saving checkpoint to:{}/best.pth with mIoU:{:.3f}'
                        .format(args.save_path, best_mIoU_val))

    if main_process():
        # it must close the writer, otherwise it will appear the EOFError!
        writer.close()
        logger.info(
            '==>Training done! The best val mIoU during training: %.3f' %
            (best_mIoU_val))
Exemplo n.º 2
0
    plt.show()


if __name__ == "__main__":
    # Per-channel statistics used for padding and normalization.  They look
    # like they are already on the 0-255 scale, so value_scale is not applied
    # to them (the ImageNet defaults scaled by value_scale are not used here).
    value_scale = 255
    mean = [74.8559803, 79.1187336, 80.7307415]
    std = [19.19655189, 19.56021428, 24.39020428]

    # Loader / split settings.
    batchSize = 16
    validation_split = 2
    shuffle_dataset = True
    random_seed = 42

    # Augmentation pipeline: fixed resize, random rotation/blur/flip, then a
    # random 256x256 crop before tensor conversion and normalization.
    train_transform = transform.Compose(
        [
            transform.Resize((512, 512)),
            transform.RandRotate([0, 45], padding=mean, ignore_label=0),
            transform.RandomGaussianBlur(),
            transform.RandomHorizontalFlip(),
            transform.Crop(
                [256, 256], crop_type='rand', padding=mean, ignore_label=0
            ),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std),
        ]
    )
    dataset = MyDataset("dataset/train", transform=train_transform)
    # 80% of the samples are counted as the training portion.
    train_size = int(len(dataset) * 0.8)

if __name__ == '__main__':
    print(Path.db_root_dir('mf'))

    # Dataset location and label count.
    data_dir = './dataset'
    n_class = 9

    # Four-channel statistics used for padding and normalization.
    mean = [58.6573, 65.9755, 56.4990, 100.8296]
    std = [60.2980, 59.0457, 58.0989, 47.1224]

    # Augmentation ranges and training crop size.
    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    train_h = 256  # default: 480
    train_w = 512  # default: 640

    # Training pipeline: random scale/rotate/blur/flip, random crop to
    # (train_h, train_w), then tensor conversion and normalization.
    train_transform = transform.Compose(
        [
            transform.RandScale([scale_min, scale_max]),
            transform.RandRotate(
                [rotate_min, rotate_max], padding=mean, ignore_label=255
            ),
            transform.RandomGaussianBlur(),
            transform.RandomHorizontalFlip(),
            transform.Crop(
                [train_h, train_w],
                crop_type='rand',
                padding=mean,
                ignore_label=255,
            ),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std),
        ]
    )

    # Validation only converts to tensor and normalizes.
    val_transform = transform.Compose(
        [
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std),
        ]
    )
Exemplo n.º 4
0
def main_worker(local_rank, ngpus_per_node, argss):
    """Set up and run epoch-based training on one worker process.

    Builds the student model (and optional teacher), the SGD optimizer with
    a 10x learning rate for the newly added heads, wraps the models for
    (distributed) data-parallel execution, restores weights/checkpoints,
    builds the data loaders and then loops over epochs calling ``train`` and
    optionally ``validate``.  The main process writes TensorBoard scalars
    and checkpoints.

    Args:
        local_rank: CUDA device index for this process; used for
            ``torch.cuda.set_device`` and to place the loss modules.
        ngpus_per_node: Number of GPUs per node; used to split batch sizes
            and loader workers per process.
        argss: Configuration namespace.  It is stored in the module-level
            ``args`` global and mutated here (``save_path``,
            ``index_split``, ``batch_size``, ``batch_size_val``,
            ``workers`` and, on resume, ``start_epoch``).

    Raises:
        ValueError: If ``args.arch`` is not a supported architecture.
    """
    global args
    args = argss

    dist.init_process_group(backend=args.dist_backend)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # makedirs(exist_ok=True) is safe when several worker processes reach
        # this point at the same time and when parent directories are missing.
        os.makedirs(args.save_path, exist_ok=True)

    if args.arch == 'psp':
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [
            model.ffm, model.conv_out, model.conv_out16, model.conv_out32
        ]
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError("unsupported architecture: {!r}".format(args.arch))

    params_list = [
        dict(params=module.parameters(), lr=args.base_lr)
        for module in modules_ori
    ]
    params_list += [
        dict(params=module.parameters(), lr=args.base_lr * 10)
        for module in modules_new
    ]
    # Parameter groups at index >= index_split are the 10x-LR groups.  Derive
    # the split from the actual number of base-LR groups (5 for 'psp', 2 for
    # 'bise_v1') rather than hard-coding 5, so it matches params_list.
    args.index_split = len(modules_ori)
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(local_rank)
        # The configured sizes are per node; split them across its processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[local_rank])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(
            args.teacher_model_path,
            map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(
            args.teacher_model_path))

    criterion = nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                # Checkpoints are written with an 'epoch' key (see the save
                # call below); there is no 'point' key.
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # RGB mean & std, scaled from [0, 1] by value_scale.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Use .set_epoch() to reshuffle the dataset partition every epoch
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            local_rank, train_loader, model, teacher_model, criterion,
            kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        is_best = False
        # Stays NaN when validation is disabled, so the checkpoint log below
        # never hits an unbound name.
        mIoU_val = float('nan')
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                    logger.info('==>The best val mIoU: %.3f' %
                                (best_mIoU_val))

        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val
                },
                is_best,
                args.save_path
            )
            if is_best:
                logger.info('Saving checkpoint to:' + args.save_path +
                            '/best.pth with mIoU: ' + str(best_mIoU_val))
            else:
                logger.info('Saving checkpoint to:' + args.save_path +
                            '/last.pth with mIoU: ' + str(mIoU_val))

    if main_process():
        # it must close the writer, otherwise it will appear the EOFError!
        writer.close()
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))
Exemplo n.º 5
0
def train(gpu, ngpus_per_node, argss):
    """Train a TriSeNet segmentation model for a fixed number of iterations.

    The function builds the model and an SGD optimizer with two tiers of
    parameter groups (``layer0``..``layer4`` at ``base_lr``, every other
    direct child module at ``base_lr * 10``), wraps the model in
    ``torch.nn.DataParallel``, and then iterates from ``args.start_iter`` to
    ``args.max_iter``, restarting the data loader whenever it is exhausted.
    The learning rate is updated after every step with
    ``poly_learning_rate``; progress is printed every ``args.print_freq``
    iterations and one checkpoint is written when the loop finishes.

    Args:
        gpu: Device passed to ``.cuda(gpu)`` when placing the loss criterion.
        ngpus_per_node: Unused here; kept for interface compatibility.
        argss: Configuration namespace. It is stored in the module-level
            ``args`` global and ``args.index_split`` is set on it.

    Raises:
        ValueError: If ``args.arch`` is not ``'triple'`` (previously this
            surfaced as a NameError on the undefined model).
    """
    global args
    args = argss

    if args.arch != 'triple':
        # Only the 'triple' architecture is constructed here; fail with a
        # clear message instead of a NameError on `modules_ori` below.
        raise ValueError(
            "Unsupported architecture {!r}: this training routine only "
            "handles 'triple'.".format(args.arch))

    model = TriSeNet(layers=args.layers, classes=args.classes)
    modules_ori = [
        model.layer0, model.layer1, model.layer2, model.layer3,
        model.layer4
    ]
    # Every direct child whose attribute name does not contain "layer" is
    # trained with the 10x learning rate.
    modules_new = [
        child for name, child in model.named_children()
        if "layer" not in name
    ]
    # Parameter groups at index >= index_split get 10x the learning rate.
    args.index_split = len(modules_ori)

    params_list = [
        dict(params=module.parameters(), lr=args.base_lr)
        for module in modules_ori
    ]
    params_list.extend(
        dict(params=module.parameters(), lr=args.base_lr * 10)
        for module in modules_new)
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    print("=> creating model ...")
    print("Classes: {}".format(args.classes))
    print(model)
    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(gpu)

    # RGB mean & std, rescaled from [0, 1] to the 0-255 value range.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = SemData(split='train',
                         data_root=args.data_root,
                         data_list=args.train_list,
                         transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=True)

    # Training loop bookkeeping.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.max_iter

    data_iter = iter(train_loader)
    epoch = 0
    # Bound before the loop so the final checkpoint can still be written
    # when start_iter >= max_iter and the loop body never executes.
    iter_log = args.start_iter
    # There is no auxiliary head output here; the auxiliary loss is a
    # constant zero, so build it once instead of on every iteration.
    aux_loss = torch.tensor(0).cuda()
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            images, target = next(data_iter)
            if not images.size(0) == args.batch_size:
                raise StopIteration
        except StopIteration:
            # Loader exhausted (or short batch): start a new epoch.
            epoch += 1
            data_iter = iter(train_loader)
            images, target = next(data_iter)
            # The per-epoch meters start fresh for the new epoch.
            main_loss_meter = AverageMeter()
            aux_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # Measure data loading time.
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        main_out = model(images)
        main_loss = criterion(main_out, target)
        loss = main_loss + args.aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = images.size(0)
        # Class prediction = argmax over dim 1 of the model output.
        prediction = main_out.detach().max(1)[1]
        intersection, union, target_area = intersectionAndUnionGPU(
            prediction, target, args.classes, args.ignore_label)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        target_area = target_area.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target_area)

        # Accuracy of the current batch; the epsilon avoids division by zero.
        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Poly learning-rate schedule: the first index_split groups (the
        # backbone layers) use current_lr, the rest use 10x that value.
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index, group in enumerate(optimizer.param_groups):
            if index < args.index_split:
                group['lr'] = current_lr
            else:
                group['lr'] = current_lr * 10

        # Estimated time remaining, formatted as HH:MM:SS.
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        iter_log = current_iter + 1

        if iter_log % args.print_freq == 0:
            print('Iteration: [{}/{}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'ETA {remain_time} '
                  'MainLoss {main_loss_meter.val:.4f} '
                  'AuxLoss {aux_loss_meter.val:.4f} '
                  'Loss {loss_meter.val:.4f} '
                  'Accuracy {accuracy:.4f}.'.format(
                      iter_log,
                      args.max_iter,
                      data_time=data_time,
                      batch_time=batch_time,
                      remain_time=remain_time,
                      main_loss_meter=main_loss_meter,
                      aux_loss_meter=aux_loss_meter,
                      loss_meter=loss_meter,
                      accuracy=accuracy))
    # Single checkpoint at the end of training (never flagged as best).
    save_checkpoint(
        {
            'iteration': iter_log,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, False, args.save_path)