Example #1
File: test.py Project: whrws/semseg
def main():
    global args, logger
    args = get_parser()
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    value_scale = 255  # pixel values stay in 0-255, so scale the ImageNet mean/std to match
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split, data_root=args.data_root, data_list=args.test_list, transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    names = [line.rstrip('\n') for line in open(args.names_path)]

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                           shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        test(test_loader, test_data.data_list, model, args.classes, mean, std, args.base_size, args.test_h, args.test_w, args.scales, gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
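The index_start / index_step arguments in the example above shard the test list so several jobs can each process a disjoint slice. A minimal stand-alone sketch of that slicing arithmetic (the shard_list helper is hypothetical, not part of the project):

def shard_list(data_list, index_start, index_step):
    # index_step == 0 means "everything from index_start to the end",
    # mirroring the branch in main() above.
    if index_step == 0:
        index_end = len(data_list)
    else:
        index_end = min(index_start + index_step, len(data_list))
    return data_list[index_start:index_end]

# Two jobs covering a 10-item list in shards of 5:
items = list(range(10))
assert shard_list(items, 0, 5) == [0, 1, 2, 3, 4]
assert shard_list(items, 5, 5) == [5, 6, 7, 8, 9]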
Example #2
def get_train_transform_list(args, split, dataset_name):
    """
        Args:
        -   args:
        -   split

        Return:
        -   List of transforms
    """
    from util.normalization_utils import get_imagenet_mean_std
    from util import transform

    mean, std = get_imagenet_mean_std()
    if split == 'train':
        transform_list = [
            transform.RandScale([args.scale_min, args.scale_max]),
            transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
            transform.RandomGaussianBlur(),
            transform.RandomHorizontalFlip(),
            transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
    elif split == 'val':
        transform_list = [
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
    else:
        print('Unknown split. Quitting ...')
        quit()

    transform_list += [ToFlatLabel(args.tc, dataset_name)]

    return transform.Compose(transform_list)
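A minimal usage sketch for get_train_transform_list, assuming the semseg-style transform.Compose is called as transform(image, label) and that the args fields below are the only ones the function reads (the values and the dataset name are hypothetical):

from types import SimpleNamespace
import numpy as np

args = SimpleNamespace(scale_min=0.5, scale_max=2.0,
                       rotate_min=-10, rotate_max=10,
                       ignore_label=255,
                       train_h=473, train_w=473,
                       tc=None)  # whatever ToFlatLabel expects; placeholder here

train_transform = get_train_transform_list(args, split='train', dataset_name='ade20k')
image = np.random.randint(0, 256, (512, 683, 3), dtype=np.uint8)
label = np.random.randint(0, 150, (512, 683), dtype=np.uint8)
image_t, label_t = train_transform(image, label)  # tensors cropped to train_h x train_w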
Example #3
def main():
    # params parser
    global args, writer, logger
    args = get_parser()

    logger = get_logger()
    logger.info(args)
    logger.info("Classes: {}".format(args.classes))
    # params check
    check(args)
    # params set
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.train_gpu)
    # set random number
    if args.manual_seed is not None:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        torch.cuda.manual_seed_all(args.manual_seed)

    # ----------------- data preprocessing ----------------- #
    value_scale = 255
    mean = args.mean
    mean = [item * value_scale for item in mean]
    std = args.std
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean),
        transform.RandomHorizontalFlip(),
        transform.RandomBilateralFilter(p=0.5),
        transform.RandomElastic(),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    val_transform = transform.Compose([
        # transform.RandomBilateralFilter(p=1),
        # transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    # split train & val
    train_kfolds, val_kfolds = k_fold_split(train_dir=args.train_image_dir,
                                            save_dir=args.txt_save_dir,
                                            k=args.folds,
                                            save=True)
    for fold_i, (train_image_label_list, val_image_label_list) in enumerate(
            zip(train_kfolds, val_kfolds)):
        print('>>>>>>>>>>>>>>>> Start Fold {} >>>>>>>>>>>>>>>>'.format(fold_i))
        # ----------------- Train setting ----------------- #
        # loss
        if args.loss == 'wbce':
            criterion = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(args.edge_weight))
        elif args.loss == 'dilatedbce':
            criterion = dilatedweightBCE(
                kernel_size=3,
                bg_weight=args.bg_weight,
                dilated_bg_weight=args.dilated_bg_weight,
                edge_weight=args.edge_weight)
        elif args.loss == 'focal':
            criterion = FocalLoss(alpha=1,
                                  gamma=2,
                                  logits=True,
                                  weight=args.edge_weight,
                                  reduce=True)
        elif args.loss == 'dice':
            criterion = DiceLoss()
        elif args.loss == 'focal_dice':
            criterion = FocalDiceLoss(alpha=1,
                                      gamma=2,
                                      logits=True,
                                      weight=args.edge_weight,
                                      reduce=True)

        # model
        if args.arch == 'unet':
            model = UNet(n_classes=args.classes,
                         bilinear=args.bilinear_up,
                         criterion=criterion).cuda()
        elif args.arch == 'resnet_unet':
            model = adoptedUNet(layer=34,
                                use_ppm=True,
                                use_attention=False,
                                up_way=args.upway,
                                num_classes=args.classes,
                                pretrained=True,
                                criterion=criterion).cuda()
        elif args.arch == 'hed':
            model = HED(criterion=criterion).cuda()
        logger.info(model)

        # model parallel
        if len(args.train_gpu) > 1:
            logger.info("%d GPU parallel" % len(args.train_gpu))
            model = nn.DataParallel(model)

        # optimizer
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.base_lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-08,
                                         weight_decay=args.weight_decay,
                                         amsgrad=False)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.base_lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        elif args.optimizer == 'radam':
            optimizer = RAdam(model.parameters(), lr=args.base_lr)
            # Wrap it with Lookahead
            optimizer = Lookahead(optimizer, sync_rate=0.5, sync_period=6)

        # checkpoint resume
        if args.resume:
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda())
                args.start_epoch = checkpoint['epoch']
                model_dict = model.state_dict()
                old_dict = {
                    k: v
                    for k, v in checkpoint['state_dict'].items()
                    if (k in model_dict)
                }
                model_dict.update(old_dict)
                model.load_state_dict(model_dict)

                # model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        # ---------------------- data loader ---------------------------- #
        train_image_label_list = train_image_label_list * 100  # repeat the file list 100x so one epoch makes many passes over the data
        save_path = os.path.join(args.model_save_dir, ('Fold' + str(fold_i)))
        global writer
        writer = SummaryWriter(save_path)
        # data loader for training
        train_data = dataset.SemData(split='train',
                                     data_root=args.data_root,
                                     data_list=train_image_label_list,
                                     transform=train_transform)
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None,
                                                   drop_last=True)
        logger.info("Train set: %d" % (len(train_data)))

        # data loader for validation
        if args.evaluate:
            val_data = dataset.SemData(split='val',
                                       data_root=args.data_root,
                                       data_list=val_image_label_list,
                                       transform=val_transform)
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=args.batch_size_val,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                sampler=None)
            logger.info("val set: %d" % (len(val_data)))

        # ----------------- Train and Val ----------------- #
        for epoch in range(args.start_epoch, args.epochs):
            epoch_log = epoch + 1
            # train
            loss_train, mAcc_train, mFscore_train = train(
                train_loader, model, optimizer, epoch)
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('mFscore_train', mFscore_train, epoch_log)

            # save model
            if epoch_log % args.save_freq == 0:

                filename = save_path + '/train_epoch_' + str(
                    epoch_log) + '.pth'
                logger.info('Saving checkpoint to: ' + filename)
                torch.save(
                    {
                        'epoch': epoch_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, filename)
                if epoch_log / args.save_freq > 20:
                    deletename = save_path + '/train_epoch_' + str(
                        epoch_log - args.save_freq * 20) + '.pth'
                    os.remove(deletename)

            # val
            if args.evaluate:
                with torch.no_grad():
                    loss_val, mAcc_val, mFscore_val, max_threshold = validate(
                        val_loader, model, criterion)
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('mFscore_val', mFscore_val, epoch_log)
                writer.add_scalar('max_threshold', max_threshold, epoch_log)
Example #4
def main():
    global args, logger
    args = get_parser()
    # check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gen_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder.replace('ss', 'video'), 'gray')

    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(
        split='test',
        data_root=args.data_root,
        data_list='./data/list/cityscapes/val_video_img_sam.lst',
        transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_gen,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.origin_pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psp18':
            from model.pspnet_18 import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           flow=False,
                           pretrained=False)

        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        elif args.arch == 'mobile':
            from model.mobile import DenseASPP
            model = DenseASPP(layers=args.layers,
                              classes=args.classes,
                              zoom_factor=args.zoom_factor,
                              flow=False)
        elif args.arch == 'antipsp18':
            from model.antipspnet18 import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           flow=False)
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.ckpt_path):
            logger.info("=> loading checkpoint '{}'".format(args.ckpt_path))
            checkpoint = torch.load(args.ckpt_path)
            student_ckpt = transfer_ckpt(checkpoint)
            # load_state_dict returns (missing_keys, unexpected_keys)
            missing_keys, unexpected_keys = model.load_state_dict(
                student_ckpt, strict=False)
            print('missing keys:', missing_keys)
            print('unexpected keys:', unexpected_keys)
            logger.info("=> loaded checkpoint '{}'".format(args.ckpt_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.ckpt_path))

        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, 1024, 2048, args.scales, gray_folder, colors)
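transfer_ckpt() is not shown in this excerpt. As a hypothetical sketch only (the real helper in this project may behave differently), such a function typically keeps the student sub-network's weights from a distillation checkpoint and renames their keys to match the standalone DataParallel-wrapped model; the 'module.student_net.' prefix below is an assumption:

def transfer_ckpt(checkpoint, prefix='module.student_net.'):
    state_dict = checkpoint.get('state_dict', checkpoint)
    student_ckpt = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # e.g. 'module.student_net.layer0.0.weight' -> 'module.layer0.0.weight'
            student_ckpt['module.' + key[len(prefix):]] = value
    return student_ckpt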
Example #5
def main():
    init()
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    if args.sync_bn:
        from util.util import convert_to_syncbn
        from lib.sync_bn import patch_replication_callback
        convert_to_syncbn(model), patch_replication_callback(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.step_epoch,
                                    gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    train_transform = transform.Compose([transform.ToTensor()])
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train',
                           data_root=args.train_full_folder,
                           num_point=args.num_point,
                           test_area=args.test_area,
                           block_size=args.block_size,
                           sample_rate=args.sample_rate,
                           transform=train_transform)
        # train_data = dataset.PointData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    elif args.data_name == 'scannet':
        train_data = ScanNet(split='train',
                             data_root=args.data_root,
                             num_point=args.num_point,
                             block_size=args.block_size,
                             sample_rate=args.sample_rate,
                             transform=train_transform)
    elif args.data_name == 'modelnet40':
        train_data = dataset.PointData(split='train',
                                       data_root=args.data_root,
                                       data_list=args.train_list,
                                       transform=train_transform,
                                       num_point=args.num_point,
                                       random_index=True)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.train_workers,
        pin_memory=True)

    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        val_data = dataset.PointData(split='val',
                                     data_root=args.data_root,
                                     data_list=args.val_list,
                                     transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.train_batch_size_val,
            shuffle=False,
            num_workers=args.train_workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch)
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
Example #6
def main():
    init()
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointnet2_paconv_seg':
        from model.pointnet2.pointnet2_paconv_seg import PointNet2SSGSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim,
                  k=args.classes,
                  use_xyz=args.use_xyz,
                  args=args)

    best_mIoU = 0.0

    if args.sync_bn:
        from util.util import convert_to_syncbn
        convert_to_syncbn(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.get('lr_multidecay', False):
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(args.epochs * 0.6),
                        int(args.epochs * 0.8)],
            gamma=args.multiplier)
    else:
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.step_epoch,
                                        gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.sync_bn:
        from lib.sync_bn import patch_replication_callback
        patch_replication_callback(model)
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            try:
                best_mIoU = checkpoint['val_mIoU']
            except Exception:
                pass
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.get('no_transformation', True):
        train_transform = None
    else:
        train_transform = transform.Compose([
            transform.RandomRotate(along_z=args.get('rotate_along_z', True)),
            transform.RandomScale(scale_low=args.get('scale_low', 0.8),
                                  scale_high=args.get('scale_high', 1.2)),
            transform.RandomJitter(sigma=args.get('jitter_sigma', 0.01),
                                   clip=args.get('jitter_clip', 0.05)),
            transform.RandomDropColor(
                color_augment=args.get('color_augment', 0.0))
        ])
    logger.info(train_transform)
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train',
                           data_root=args.train_full_folder,
                           num_point=args.num_point,
                           test_area=args.test_area,
                           block_size=args.block_size,
                           sample_rate=args.sample_rate,
                           transform=train_transform,
                           fea_dim=args.get('fea_dim', 6),
                           shuffle_idx=args.get('shuffle_idx', False))
    else:
        raise ValueError('{} dataset not supported.'.format(args.data_name))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.train_workers,
        pin_memory=True,
        drop_last=True)

    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        if args.data_name == 's3dis':
            val_data = dataset.PointData(split='val',
                                         data_root=args.data_root,
                                         data_list=args.val_list,
                                         transform=val_transform,
                                         norm_as_feat=args.get(
                                             'norm_as_feat', True),
                                         fea_dim=args.get('fea_dim', 6))
        else:
            raise ValueError('{} dataset not supported.'.format(
                args.data_name))

        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.train_batch_size_val,
            shuffle=False,
            num_workers=args.train_workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch,
            args.get('correlation_loss', False))
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'commit_id': get_git_commit_id()
                }, filename)
            if epoch_log / args.save_freq > 2:
                try:
                    deletename = args.save_path + '/train_epoch_' + str(
                        epoch_log - args.save_freq * 2) + '.pth'
                    os.remove(deletename)
                except Exception:
                    logger.info('{} Not found.'.format(deletename))

        if args.evaluate and epoch_log % args.get('eval_freq', 1) == 0:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
            if mIoU_val > best_mIoU:
                best_mIoU = mIoU_val
                filename = args.save_path + '/best_train.pth'
                logger.info('Best Model Saving checkpoint to: ' + filename)
                torch.save(
                    {
                        'epoch': epoch_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'val_mIoU': best_mIoU,
                        'commit_id': get_git_commit_id()
                    }, filename)
        scheduler.step()
Example #7
File: test.py Project: lixin666/MGL
def main():
    global args, logger
    args = get_parser('config/cod_mgl50.yaml')
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    date_str = str(datetime.datetime.now().date())
    save_folder = args.save_folder + '/' + date_str
    check_makedirs(save_folder)
    cod_folder = os.path.join(save_folder, 'cod')
    coee_folder = os.path.join(save_folder, 'coee')

    test_transform = transform.Compose([
        transform.Resize((args.test_h, args.test_w)),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    if not args.has_prediction:
        if args.arch == 'mgl':
            from model.mglnet import MGLNet
            model = MGLNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False,
                           args=args)
        #logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path, map_location='cuda:0')
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}', epoch {}".format(
                args.model_path, checkpoint['epoch']))

        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.model_path))
        test(test_loader, test_data.data_list, model, cod_folder, coee_folder)
    if args.split != 'test':
        calc_acc(test_data.data_list, cod_folder, coee_folder)
Example #8
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.kdnet import KDNet
        model = KDNet(layers=args.layers,
                      classes=args.classes,
                      zoom_factor=args.zoom_factor,
                      criterion=criterion,
                      temperature=args.temperature,
                      alpha=args.alpha)
        modules_ori = [
            model.student_net.layer0, model.student_net.layer1,
            model.student_net.layer2, model.student_net.layer3,
            model.student_net.layer4
        ]
        modules_new = [
            model.student_net.ppm, model.student_net.cls, model.student_net.aux
        ]
        teacher_net = model.teacher_loader
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu], find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
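main_process() is used throughout the distributed examples but never defined in these excerpts. A common implementation, given here as an assumption rather than the project's own code (it presumes args.ngpus_per_node is filled in before the workers are spawned), restricts logging and checkpointing to one rank per node:

def main_process():
    # In non-distributed runs every call returns True; in multiprocessing-
    # distributed runs only the first GPU of each node is the "main" process.
    return not args.multiprocessing_distributed or (
        args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0)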
Example #9
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.sync_bn:
        if args.multiprocessing_distributed:
            BatchNorm = apex.parallel.SyncBatchNorm
        else:
            from lib.sync_bn.modules import BatchNorm2d
            BatchNorm = BatchNorm2d
    else:
        BatchNorm = nn.BatchNorm2d
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    #criterion = OhemCrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'spnet':
        from models.spnet import SPNet
        model = SPNet(nclass=args.classes,
                      backbone=args.backbone,
                      pretrained=args.weight,
                      criterion=criterion,
                      norm_layer=BatchNorm,
                      spm_on=args.spm_on)
        print(model)
        modules_ori = [model.pretrained]
        modules_new = [model.head, model.auxlayer]
    else:
        if main_process():
            raise RuntimeError("=> Unknown network architecture: {}".format(
                args.arch))

    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 1
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)
        if args.use_apex:
            model, optimizer = apex.amp.initialize(
                model.cuda(),
                optimizer,
                opt_level=args.opt_level,
                keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                loss_scale=args.loss_scale)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                              device_ids=[gpu])

    else:
        model = torch.nn.DataParallel(model.cuda())

    #if args.weight:
    #    if os.path.isfile(args.weight):
    #        if main_process():
    #            logger.info("=> loading weight '{}'".format(args.weight))
    #        checkpoint = torch.load(args.weight)
    #        model.load_state_dict(checkpoint['state_dict'])
    #        if main_process():
    #            logger.info("=> loaded weight '{}'".format(args.weight))
    #    else:
    #        if main_process():
    #            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = None
    if args.dataset == 'ade20k':
        train_data = dataset.Ade20kData(split='train',
                                        data_root=args.data_root,
                                        data_list=args.train_list,
                                        transform=train_transform)
    elif args.dataset == 'cityscapes':
        train_data = dataset.CityscapesData(split='train',
                                            data_root=args.data_root,
                                            data_list=args.train_list,
                                            transform=train_transform)
    else:
        if main_process():
            raise RuntimeError("=> Unsupported dataset: {}".format(
                args.dataset))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.CityscapesData(split='val',
                                          data_root=args.data_root,
                                          data_list=args.val_list,
                                          transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
Example #10
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion,
                       args=args)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    else:
        logger = None
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume != 'none':
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            # model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # print(checkpoint['optimizer'].keys())
            if args.if_remove_cls:
                if main_process():
                    logger.info(
                        '=====!!!!!!!===== Remove cls layer in resuming...')
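                # Drop 'module.cls' / 'module.aux' weights so a checkpoint trained with a
                # different class count can still be resumed; strict=False below lets those
                # heads start from fresh initialization.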
                checkpoint['state_dict'] = {
                    x: checkpoint['state_dict'][x]
                    for x in checkpoint['state_dict'].keys()
                    if ('module.cls' not in x and 'module.aux' not in x)
                }
                # checkpoint['optimizer'] = {x: checkpoint['optimizer'][x] for x in checkpoint['optimizer'].keys() if ('module.cls' not in x and 'module.aux' not in x)}
                # if main_process():
                #     print('----', checkpoint['state_dict'].keys())
                #     print('----', checkpoint['optimizer'].keys())
                #     print('----1', checkpoint['optimizer']['state'].keys())

            model.load_state_dict(checkpoint['state_dict'], strict=False)
            if not args.if_remove_cls:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
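    # The ImageNet mean/std above are scaled by 255 because this repo's transform.ToTensor
    # keeps pixel values in [0, 255] rather than [0, 1] (semseg convention).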

    transform_list_train = []
    if args.resize:
        transform_list_train.append(
            transform.Resize((args.resize_h, args.resize_w)))
    transform_list_train += [
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(transform_list_train)
    train_data = dataset.SemData(split='val',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform,
                                 logger=logger,
                                 is_master=main_process(),
                                 args=args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    if args.evaluate:
        transform_list_val = []
        if args.resize:
            transform_list_val.append(
                transform.Resize((args.resize_h, args.resize_w)))
        transform_list_val += [
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
        val_transform = transform.Compose(transform_list_val)
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform,
                                   is_master=main_process(),
                                   args=args)
        args.read_image = val_data.read_image
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # if args.evaluate and args.val_every_iter == -1:
        #     # logger.info('Validating.....')
        #     loss_val, mIoU_val, mAcc_val, allAcc_val, return_dict = validate(val_loader, model, criterion, args)
        #     if main_process():
        #         writer.add_scalar('VAL/loss_val', loss_val, epoch_log)
        #         writer.add_scalar('VAL/mIoU_val', mIoU_val, epoch_log)
        #         writer.add_scalar('VAL/mAcc_val', mAcc_val, epoch_log)
        #         writer.add_scalar('VAL/allAcc_val', allAcc_val, epoch_log)

        #         for sample_idx in range(len(return_dict['image_name_list'])):
        #             writer.add_text('VAL-image_name/%d'%sample_idx, return_dict['image_name_list'][sample_idx], epoch)
        #             writer.add_image('VAL-image/%d'%sample_idx, return_dict['im_list'][sample_idx], epoch, dataformats='HWC')
        #             writer.add_image('VAL-color_label/%d'%sample_idx, return_dict['color_GT_list'][sample_idx], epoch, dataformats='HWC')
        #             writer.add_image('VAL-color_pred/%d'%sample_idx, return_dict['color_pred_list'][sample_idx], epoch, dataformats='HWC')

        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch, epoch_log, val_loader,
            criterion)
        if main_process():
            writer.add_scalar('TRAIN/loss_train', loss_train, epoch_log)
            writer.add_scalar('TRAIN/mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('TRAIN/mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('TRAIN/allAcc_train', allAcc_train, epoch_log)
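The workers above rely on main_process() and get_logger() without defining them. A minimal sketch of these helpers, assuming the usual semseg-style conventions (rank-0-only logging, with args.ngpus_per_node filled in by the launcher), might look like this:

import logging

def get_logger():
    # Plain stream logger; only the main process uses it.
    logger = logging.getLogger("main-logger")
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return logger

def main_process():
    # True only for the rank-0 worker, so logging/TensorBoard writes happen once.
    return not args.multiprocessing_distributed or (
        args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0)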
Esempio n. 11
0
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.seg_loss_type == 'ce':
        if args.ohem:
            min_kept = int(args.batch_size // len(args.train_gpu) *
                           args.train_h * args.train_w // 16)
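            # min_kept is roughly the number of pixels in one per-GPU batch divided by 16;
            # OHEM keeps at least this many of the hardest pixels when computing the loss.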
            seg_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                                   thresh=0.7,
                                                   min_kept=min_kept,
                                                   use_weight=False)
        else:
            seg_criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    else:
        raise NotImplementedError

    if args.derain_loss_type == 'mse':
        derain_criterion = nn.MSELoss()
    else:
        raise NotImplementedError

    if args.arch == 'iterative_derain_seg':
        from model.dic_arch_derainseg_fineutne import DIC
        model = DIC(args,
                    derain_criterion=derain_criterion,
                    seg_criterion=seg_criterion,
                    is_train=True)
        modules_ori = [model.seg_net]
        modules_new = [
            model.block, model.first_block, model.conv_in, model.conv_out,
            model.derain_final_conv
        ]
        modules_fix = [model.edge_net]
    else:
        raise NotImplementedError

    params_list = []
    for module in modules_ori:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 0))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 1))
    for module in modules_fix:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 0))
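    # A learning rate of base_lr * 0 leaves the pretrained seg_net and the fixed edge_net
    # unchanged under SGD; only the derain modules (base_lr * 1) are actually updated.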
    args.index_split_1 = 1
    args.index_split_2 = 6
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.pretrained:
        if main_process():
            logger.info("=> Loading derain first weight from '{}'\n "
                        "and '{}'\n seg weight from '{}'".format(
                            args.derain_first_pretrained_path,
                            args.derain_last_pretrained_path,
                            args.seg_pretrained_path))
        load_derain_and_seg(model, args)
        if main_process():
            logger.info("=> Loaded derain first weight from '{}'\n "
                        "and '{}'\n seg weight from '{}'".format(
                            args.derain_first_pretrained_path,
                            args.derain_last_pretrained_path,
                            args.seg_pretrained_path))

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        # transform.RandScale([args.scale_min, args.scale_max]),
        # transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        # transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.RandomVerticalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 rain_data_root=args.rain_data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler,
                                               drop_last=True)
    # if args.evaluate:
    #     val_transform = transform.Compose([
    #         transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
    #         transform.ToTensor(),
    #         transform.Normalize(mean=mean, std=std)])
    #     val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
    #     if args.distributed:
    #         val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
    #     else:
    #         val_sampler = None
    #     val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train, psnr_train, ssim_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('psnr_train', psnr_train, epoch_log)
            writer.add_scalar('ssim_train', ssim_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
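The main_worker(gpu, ngpus_per_node, argss) entry points in these examples are normally spawned once per GPU rather than called directly. A hedged sketch of such a launcher (argument names like args.train_gpu and args.world_size follow the semseg convention and are assumptions here):

import torch.multiprocessing as mp

def main():
    args = get_parser()
    ngpus_per_node = len(args.train_gpu)
    if args.multiprocessing_distributed:
        # One process per GPU; world_size becomes the total process count.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.train_gpu, ngpus_per_node, args)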
Esempio n. 12
0
def main():
    global args
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    model = PSPNet(layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   criterion=criterion,
                   pretrained=args.pretrained,
                   naive_ppm=args.naive_ppm)

    # set different learning rates on different parts of the model
    modules_ori = [
        model.layer0, model.layer1, model.layer2, model.layer3, model.layer4
    ]
    modules_new = [model.ppm, model.cls_head, model.aux_head]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    global logger, writer
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = model.cuda()
    model = torch.nn.DataParallel(model).cuda()

    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # image pre-processing and augmentation
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    train_transform = transform.Compose([
        transform.Resize((args.train_h, args.train_w)),
        # augmentation
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    # initialize dataloader
    train_data = dataset.SemData(split='trainval', transform=train_transform)
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='test', transform=val_transform)
        val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    # start training
    logger.info('Starting training.')
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        writer.add_scalar('loss/train', loss_train, epoch_log)
        writer.add_scalar('mIoU/train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc/train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc/train', allAcc_train, epoch_log)

        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            writer.add_scalar('loss/val', loss_val, epoch_log)
            writer.add_scalar('mIoU/val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc/val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc/val', allAcc_val, epoch_log)

        if (epoch_log % args.save_freq == 0):
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
def get_dataloder():
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    assert args.split in [0, 1, 2, 3, 999]
    train_transform = [
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.padding_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.padding_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=args.split, shot=args.shot, max_sp=args.max_sp, data_root=args.data_root, \
                                data_list=args.train_list, transform=train_transform, mode='train', \
                                use_coco=args.use_coco, use_split_coco=args.use_split_coco)
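    # Episodic few-shot loading (hedged reading of the PFENet/ASGNet-style dataset): `split`
    # picks the class fold, `shot` the number of support images per episode, and
    # use_split_coco switches to the COCO-20i class splits.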

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:
        if args.resized_val:
            val_transform = transform.Compose([
                transform.Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ])
        else:
            val_transform = transform.Compose([
                transform.test_Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ])
        val_data = dataset.SemData(split=args.split, shot=args.shot, max_sp=args.max_sp, data_root=args.data_root, \
                                data_list=args.val_list, transform=val_transform, mode='val', \
                                use_coco=args.use_coco, use_split_coco=args.use_split_coco)
        val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    return train_loader, val_loader
Esempio n. 14
0
def main():
    # params parser
    global args, logger
    args = get_parser()
    # params check
    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ----------------- Test setting ----------------- #
    # load model
    if args.arch == 'unet':
        model = UNet(n_classes=args.classes, bilinear=args.bilinear_up).cuda()
    elif args.arch == 'resnet_unet':
        model = adoptedUNet(layer=34,
                            use_ppm=True,
                            use_attention=False,
                            up_way=args.upway,
                            num_classes=args.classes).cuda()
    elif args.arch == 'hed':
        model = HED().cuda()
    logger.info(model)
    if len(args.train_gpu) > 1:
        model = torch.nn.DataParallel(model)
    cudnn.benchmark = False

    # ----------------- data loader ----------------- #
    value_scale = 255
    mean = args.mean
    mean = [item * value_scale for item in mean]
    std = args.std
    std = [item * value_scale for item in std]

    test_transform = transform.Compose([
        # transform.RandomBilateralFilter(p=1),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    test_image_path = os.path.join(args.test_image_dir, "*.png")
    test_image_list = glob(test_image_path)
    test_image_list = tuple(zip(test_image_list, test_image_list))
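    # Pair each image path with itself so the (image_path, label_path) tuples expected by
    # SemData stay well-formed even when no ground-truth labels are available.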
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=test_image_list,
                                transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    if len(args.model_path) != 0:

        for fold_i in range(args.folds):
            for model_i in args.model_path:
                single_model_path = args.model_save_dir + 'Fold{}/train_epoch_{}.pth'.format(
                    fold_i, model_i)
                single_save_folder = args.result_save_dir + 'Fold{}/epoch_{}/'.format(
                    fold_i, model_i)
                if os.path.isfile(single_model_path):
                    logger.info(
                        "=> loading checkpoint '{}'".format(single_model_path))
                    checkpoint = torch.load(single_model_path)
                    model.load_state_dict(checkpoint['state_dict'],
                                          strict=False)
                    logger.info(
                        "=> loaded checkpoint '{}'".format(single_model_path))
                else:
                    raise RuntimeError("=> no checkpoint found at '{}'".format(
                        args.model_path))
                # test(test_loader, test_data.data_list, model, args.classes, args.base_size,
                #      args.test_h, args.test_w, args.scales, single_save_folder)

    if len(args.model_path) != 0:
        ensemble(test_data.data_list, args.base_size, args.base_size,
                 args.ensemble_way, args.threshold)

    if args.split != 'test':
        cal_acc(test_data.data_list, args.ensemble_folder, args.classes)
Esempio n. 15
0
def main(
        config_name,
        weights_url='https://github.com/deepparrot/semseg/releases/download/0.1/pspnet50-ade20k.pth',
        weights_name='pspnet50-ade20k.pth'):

    args = config.load_cfg_from_cfg_file(config_name)
    check(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    args.data_root = './.data/vision/ade20k'
    args.val_list = './.data/vision/ade20k/validation.txt'
    args.test_list = './.data/vision/ade20k/validation.txt'

    print(args.data_root)

    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    names = []

    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers,
                           classes=args.classes,
                           zoom_factor=args.zoom_factor,
                           compact=args.compact,
                           shrink_factor=args.shrink_factor,
                           mask_h=args.mask_h,
                           mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor,
                           psa_softmax=args.psa_softmax,
                           pretrained=False)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

        local_checkpoint, _ = urllib.request.urlretrieve(
            weights_url, weights_name)

        if os.path.isfile(local_checkpoint):
            checkpoint = torch.load(local_checkpoint)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(local_checkpoint))
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             args.base_size, args.test_h, args.test_w, args.scales,
             gray_folder, color_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, args.classes, names)
Esempio n. 16
0
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss

    BatchNorm = nn.BatchNorm2d

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)

    model = eval(args.arch).Model(args)
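    # Freeze the pretrained ResNet backbone (layer0-layer4) below; only the few-shot head
    # built inside Model(args) will receive parameter updates.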

    for param in model.layer0.parameters():
        param.requires_grad = False
    for param in model.layer1.parameters():
        param.requires_grad = False
    for param in model.layer2.parameters():
        param.requires_grad = False
    for param in model.layer3.parameters():
        param.requires_grad = False
    for param in model.layer4.parameters():
        param.requires_grad = False
        
    optimizer = model._optimizer(args)
    global logger, writer
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    print(args)

    model = torch.nn.DataParallel(model.cuda(), device_ids=[0])

    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))


    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    assert args.split in [0, 1, 2, 3, 999]
    train_transform = [
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.padding_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.padding_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)]
    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=args.split, shot=args.shot, max_sp=args.max_sp, data_root=args.data_root, \
                                data_list=args.train_list, transform=train_transform, mode='train', \
                                use_coco=args.use_coco, use_split_coco=args.use_split_coco)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    if args.evaluate:
        if args.resized_val:
            val_transform = transform.Compose([
                transform.Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)])    
        else:
            val_transform = transform.Compose([
                transform.test_Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)])           
        val_data = dataset.SemData(split=args.split, shot=args.shot, max_sp=args.max_sp, data_root=args.data_root, \
                                data_list=args.val_list, transform=val_transform, mode='val', \
                                use_coco=args.use_coco, use_split_coco=args.use_split_coco)
        val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    max_iou = 0.
    filename = 'ASGNet.pth'

    for epoch in range(args.start_epoch, args.epochs):
        if args.fix_random_seed_val:
            torch.cuda.manual_seed(args.manual_seed + epoch)
            np.random.seed(args.manual_seed + epoch)
            torch.manual_seed(args.manual_seed + epoch)
            torch.cuda.manual_seed_all(args.manual_seed + epoch)
            random.seed(args.manual_seed + epoch)

        epoch_log = epoch + 1
        loss_train, aux_loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('aux_loss_train', aux_loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)     

        if args.evaluate and (epoch % 2 == 0 or args.epochs <= 50):
            loss_val, mIoU_val, mAcc_val, allAcc_val, class_miou = validate(val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('class_miou_val', class_miou, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
            if class_miou > max_iou:
                max_iou = class_miou
                if os.path.exists(filename):
                    os.remove(filename)            
                filename = args.save_path + '/train_epoch_' + str(epoch) + '_' + str(max_iou) + '.pth'
                logger.info('Saving checkpoint to: ' + filename)
                torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)

    filename = args.save_path + '/final.pth'
    logger.info('Saving checkpoint to: ' + filename)
    torch.save({'epoch': args.epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)                
Esempio n. 17
0
def main():
    global args, logger
    args = get_parser()
    if args.test_in_nyu_label_space:
        args.colors_path = 'nyu/nyu_colors.txt'
        args.names_path = 'nyu/nyu_names.txt'

    if args.if_cluster:
        args.data_root = args.data_root_cluster
        args.project_path = args.project_path_cluster
        args.data_config_path = 'data'
    for key in ['train_list', 'val_list', 'test_list', 'colors_path', 'names_path']:
        args[key] = os.path.join(args.data_config_path, args[key])
    for key in ['save_path', 'model_path', 'save_folder']:
        args[key] = os.path.join(args.project_path, args[key])
    # for key in ['save_path', 'model_path', 'save_folder']:
    #     args[key] = args[key] % args.exp_name

    check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')

    transform_list_test = []
    if args.resize:
        transform_list_test.append(transform.Resize((args.resize_h_test, args.resize_w_test)))
    transform_list_test += [
        transform.Crop([args.test_h, args.test_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(), 
        transform.Normalize(mean=mean, std=std)
    ]
    test_transform = transform.Compose(transform_list_test)
    test_data = dataset.SemData(split=args.split, data_root=args.data_root, data_list=args.test_list, transform=test_transform, is_master=True, args=args)
    # test_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=test_transform, is_master=True, args=args)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step, len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    names = [line.rstrip('\n') for line in open(args.names_path)]

    args.read_image = test_data.read_image


    if not args.has_prediction:
        if args.arch == 'psp':
            from model.pspnet import PSPNet
            model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, pretrained=False)
        elif args.arch == 'psa':
            from model.psanet import PSANet
            model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, compact=args.compact,
                           shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                           normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, pretrained=False)
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
        pred_path_list, target_path_list = test(test_loader, test_data.data_list, model, args.classes, mean, std, args.base_size, args.test_h, args.test_w, args.scales, gray_folder, color_folder, colors)
    if args.split != 'test' or (args.split == 'test' and args.test_has_gt):
        cal_acc(test_data.data_list, gray_folder, args.classes, names, pred_path_list=pred_path_list, target_path_list=target_path_list)
Esempio n. 18
0
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss

    ## step.1 Set up distributed-training parameters
    # 1.1 Initialize distributed training
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)  # distributed initialization

    ## step.2 Build the network
    # ---------------------------------------------- adapt to your own setup ---------------------------------------------#
    criterion = nn.CrossEntropyLoss(
        ignore_index=args.ignore_label)  # cross-entropy loss; adjust as needed
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    # ---------------------------------------------------- END ---------------------------------------------------#

    ## step.3 Set up the optimizer
    params_list = []  # list of model parameter groups
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(),
                                lr=args.base_lr))  # original backbone network, learning rate 0.01
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(),
                 lr=args.base_lr * 10))  # newly added prediction head, learning rate 0.1
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)  # SGD optimizer
    # 3.x Set up sync_bn via torch.nn.SyncBatchNorm
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    ## step.4 Multi-process distributed work
    # 4.1 Check whether this is the main process; if so, run the following
    if main_process():
        global logger, writer
        logger = get_logger()  # set up the logger
        writer = SummaryWriter(args.save_path)  # set up the TensorBoard writer
        logger.info(args)  # log the argument list
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)  # log the network definition
    # 4.2 Distributed setup
    if args.distributed:
        torch.cuda.set_device(gpu)  # select the GPU with index `gpu`
        args.batch_size = int(args.batch_size /
                              ngpus_per_node)  # training batch size per GPU
        args.batch_size_val = int(args.batch_size_val /
                                  ngpus_per_node)  # validation batch size per GPU
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)  # workers per GPU
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu])  # wrap with torch DistributedDataParallel
    else:
        model = torch.nn.DataParallel(model.cuda())  # plain data parallelism

    ## step.5 Load network weights
    # 5.1 Load pretrained weights directly
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))
    # 5.2 Resume weights from a previously interrupted run
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    ## step.7 Set up the data loaders
    # 7.1 Loader parameters
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])  # composed data preprocessing pipeline

    # 7.2 Training data; modify or rewrite as needed
    # ---------------------------------------------- adapt to your own setup ---------------------------------------------#
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    # ---------------------------------------------------- END ---------------------------------------------------#
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)  # data sampler for distributed training
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:  # validation data
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    ## step.8 Main loop
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # 8.1 Training function
        # ---------------------------------------------- adapt to your own setup ---------------------------------------------#
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        # ---------------------------------------------------- END ---------------------------------------------------#

        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        # 8.2 Save a checkpoint
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        # evaluate after each training epoch
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
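The train() functions referenced throughout these examples usually decay the learning rate with a poly schedule keyed off index_split, so the backbone and the newly added heads keep their 1x / 10x ratio. A hedged sketch following the semseg util convention (args.power, current_iter and max_iter are assumed to be provided by the training loop):

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    # Standard poly decay: the LR shrinks smoothly to 0 over max_iter iterations.
    return base_lr * (1 - float(curr_iter) / max_iter) ** power

# Inside train(): groups before args.index_split are the backbone, the rest are
# the modules that were registered above with a 10x learning rate.
current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter, power=args.power)
for index in range(0, args.index_split):
    optimizer.param_groups[index]['lr'] = current_lr
for index in range(args.index_split, len(optimizer.param_groups)):
    optimizer.param_groups[index]['lr'] = current_lr * 10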
Esempio n. 19
0
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss

    BatchNorm = nn.BatchNorm2d

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)

    model = PFENet(layers=args.layers, classes=2, zoom_factor=8, \
        criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=BatchNorm, \
        pretrained=True, shot=args.shot, ppm_scales=args.ppm_scales, vgg=args.vgg)

    global logger, writer
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    print(args)

    model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    assert args.split in [0, 1, 2, 3, 999]

    if args.resized_val:
        val_transform = transform.Compose([
            transform.Resize(size=args.val_size),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
    else:
        val_transform = transform.Compose([
            transform.test_Resize(size=args.val_size),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
    val_data = dataset.SemData(split=args.split, shot=args.shot, data_root=args.data_root, \
                            data_list=args.val_list, transform=val_transform, mode='val', \
                            use_coco=args.use_coco, use_split_coco=args.use_split_coco)
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size_val,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    loss_val, mIoU_val, mAcc_val, allAcc_val, class_miou = validate(
        val_loader, model, criterion)
Esempio n. 20
0
def main():
    global args, logger
    args = get_parser()
    # check(args)
    logger = get_logger()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(x) for x in args.test_gpu)
    logger.info(args)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    gray_folder = os.path.join(args.save_folder, 'gray')
    color_folder = os.path.join(args.save_folder, 'color')
    derain_folder = os.path.join(args.save_folder, 'derain')
    edge_folder = os.path.join(args.save_folder, 'edge')
    result_txt_path = os.path.join(args.save_folder, 'results.txt')

    test_transform = transform.Compose(
        [transform.ToTensor(),
         transform.Normalize(mean=mean, std=std)])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                rain_data_root=args.rain_data_root,
                                data_list=args.test_list,
                                transform=test_transform)
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    colors = np.loadtxt(args.colors_path).astype('uint8')
    names = [line.rstrip('\n') for line in open(args.names_path)]

    if not args.has_prediction:
        if args.arch == 'iterative_derain_seg':
            from model.dic_arch_derainseg import DIC
            model = DIC(args, is_train=False)
        else:
            raise NotImplementedError
        logger.info(model)
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
        if os.path.isfile(args.model_path):
            logger.info("=> loading checkpoint '{}'".format(args.model_path))
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            logger.info("=> loaded checkpoint '{}'".format(args.model_path))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.model_path))
        test(test_loader, test_data.data_list, model, args.classes, mean, std,
             gray_folder, color_folder, derain_folder, edge_folder, colors)
    if args.split != 'test':
        cal_acc(test_data.data_list, gray_folder, derain_folder, args.classes,
                names, result_txt_path)
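The index_start / index_step handling above shards the test list so that inference can be split across several independent runs. The same slicing logic, pulled out into a small helper for illustration (shard is a hypothetical name, not part of the repository):

def shard(data_list, index_start, index_step):
    """Return the portion of data_list this run should process.

    index_step == 0 means "everything from index_start to the end";
    otherwise at most index_step items starting at index_start.
    """
    if index_step == 0:
        index_end = len(data_list)
    else:
        index_end = min(index_start + index_step, len(data_list))
    return data_list[index_start:index_end]

assert shard(list(range(10)), 4, 3) == [4, 5, 6]
assert shard(list(range(10)), 4, 0) == [4, 5, 6, 7, 8, 9]
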
def main_worker(argss):
    global args
    args = argss

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)

    # Initialize the model
    model = FSSNet(layers=args.layers, classes=2, criterion=nn.CrossEntropyLoss(ignore_index=255),
                   pretrained=True, shot=args.shot, ppm_scales=args.ppm_scales, vgg=args.vgg, FPN=args.FPN)

    # Set up the optimizer for the backbone
    optimizer = backbone_optimizer(model, args)

    global logger, writer
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    print(args)
    # Move the model to GPU (single GPU here, no DataParallel)
    model = model.cuda()
    # Load pretrained weights for fine-tuning or testing
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))
    # Resume training from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # Normalization: ImageNet mean/std scaled to the 0-255 range
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    assert args.split in [0, 1, 2, 999]

    # Set up the training transform, train data and train loader
    train_transform = [
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.padding_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.padding_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)]

    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=args.split, shot=args.shot, normal=args.normal, data_root=args.data_root, \
                                data_list=args.train_list, nom_list=args.trainnom_list, transform=train_transform, mode='train')

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    # Set up the validation transform, val data and val loader
    if args.evaluate:
        if args.resized_val:
            val_transform = transform.Compose([
                transform.Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)])    
        else:
            val_transform = transform.Compose([
                transform.test_Resize(size=args.val_size),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)])           
        val_data = dataset.SemData(split=args.split, shot=args.shot, normal=args.normal, data_root=args.data_root, \
                                data_list=args.val_list, nom_list=args.valnom_list, transform=val_transform, mode='val')
        val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    max_iou = 0.
    max_fbiou = 0
    best_epoch = 0
    filename = 'FSSNet.pth'

    # Train and evaluate epoch by epoch
    for epoch in range(args.start_epoch, args.epochs):
        # Fix random seeds so that validation is reproducible
        if args.fix_random_seed_val:
            torch.cuda.manual_seed(args.manual_seed + epoch)
            np.random.seed(args.manual_seed + epoch)
            torch.manual_seed(args.manual_seed + epoch)
            torch.cuda.manual_seed_all(args.manual_seed + epoch)
            random.seed(args.manual_seed + epoch)

        epoch_log = epoch + 1
        # Train
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        # Evaluate every other epoch, or every epoch when training for at most 50 epochs
        if args.evaluate and (epoch % 2 == 0 or (args.epochs <= 50 and epoch % 1 == 0)):
            loss_val, mIoU_val, mAcc_val, allAcc_val, class_miou = validate(val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('class_miou_val', class_miou, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
            if class_miou > max_iou:
                max_iou = class_miou
                best_epoch = epoch
                if os.path.exists(filename):
                    os.remove(filename)
                filename = args.save_path + '/train_epoch_' + str(epoch) + '_' + str(max_iou) + '.pth'
                logger.info('Saving checkpoint to: ' + filename)
                torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            if mIoU_val > max_fbiou:
                max_fbiou = mIoU_val

            logger.info('Best Epoch {} Best IoU {:.4f} Best FB-IoU {:.4f}'.format(best_epoch, max_iou, max_fbiou))

    filename = args.save_path + '/final.pth'
    logger.info('Saving checkpoint to: ' + filename)
    torch.save({'epoch': args.epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
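The checkpoints saved above store the epoch, model state and optimizer state under the keys 'epoch', 'state_dict' and 'optimizer'. A hedged sketch of how such a file could be restored later (the helper name and path are illustrative; it mirrors the resume branch earlier in this example):

import torch

def load_training_checkpoint(model, optimizer, path):
    # Keys match the dict passed to torch.save in the training loop above.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']

# Example: start_epoch = load_training_checkpoint(model, optimizer, args.save_path + '/final.pth')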