Example #1
def main():

    check_dir = './' + name

    if not os.path.exists(check_dir):
        os.mkdir(check_dir)

    # data
    train_loader = torch.utils.data.DataLoader(
        VOC(opt.train_dir, split='train',
            transform=transforms.Compose([transforms.Resize((img_size, img_size))]),
            mean=mean, std=std),
        batch_size=opt.b, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        VOC(opt.train_dir, split='val',
            transform=transforms.Compose([transforms.Resize((img_size, img_size))]),
            mean=mean, std=std),
        batch_size=opt.b // 2, shuffle=True, num_workers=4, pin_memory=True)

    # models
    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda()  # pixels labeled -1 are excluded from the loss
    if opt.model == 'FCN':
        net = models.FCN(pretrained=True, c_output=n_classes, base=opt.base).cuda()
    else:
        net = getattr(models, opt.model)(pretrain=True, c_output=n_classes).cuda()
    optimizer = torch.optim.Adam([
        {'params': net.parameters(), 'lr': 1e-4}
    ])
    logs = {'best_ep': 0, 'best': 0}
    for epoch in range(opt.e):
        train(epoch, train_loader, optimizer, criterion, net, logs)
        miou, pacc = validate(val_loader, net, os.path.join(check_dir, 'results'),
                              os.path.join(opt.train_dir, 'SegmentationClass'))
        logs[epoch] = {'mIOU': miou, 'pixelAcc': pacc}
        if miou > logs['best']:
            logs['best'] = miou
            logs['best_ep'] = epoch
            torch.save(net.state_dict(), '%s/net.pth' % (check_dir))
            with open(os.path.join(check_dir, 'logs.json'), 'w') as outfile:
                json.dump(logs, outfile)
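The train and validate helpers are imported from elsewhere in the repository and not shown. A minimal sketch of what train plausibly does, assuming the loader yields (image, label) batches and the net returns per-pixel logits (hypothetical, not the repository's actual implementation):

def train(epoch, loader, optimizer, criterion, net, logs):
    # hypothetical sketch: a standard supervised segmentation training step
    net.train()
    for it, (img, lbl) in enumerate(loader):
        img, lbl = img.cuda(), lbl.cuda()
        pred = net(img)              # class logits of shape (N, C, H, W)
        loss = criterion(pred, lbl)  # CrossEntropyLoss skips -1 labels
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()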
Example #2
sal_val_loader = torch.utils.data.DataLoader(Saliency(sal_val_img_dir,
                                                      sal_val_gt_dir,
                                                      crop=None,
                                                      flip=False,
                                                      rotate=None,
                                                      size=image_size,
                                                      training=False),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)

voc_val_loader = torch.utils.data.DataLoader(VOC(voc_val_img_dir,
                                                 voc_val_gt_dir,
                                                 voc_val_split,
                                                 crop=None,
                                                 flip=False,
                                                 rotate=None,
                                                 size=image_size,
                                                 training=False),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)


def val_sal():
    net.eval()
    with torch.no_grad():
        for it, (img, gt, batch_name, WW,
                 HH) in tqdm(enumerate(sal_val_loader), desc='val'):
            img = (img.cuda() - mean) / std
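mean and std are not defined in this snippet; for the division above to broadcast over an NCHW batch, they are presumably (1, 3, 1, 1) CUDA tensors, as Example #9 below constructs them:

import torch

# ImageNet statistics, shaped (1, 3, 1, 1) to broadcast over NCHW batches
mean = torch.Tensor([0.485, 0.456, 0.406])[None, ..., None, None].cuda()
std = torch.Tensor([0.229, 0.224, 0.225])[None, ..., None, None].cuda()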
Example #3
                                                    size=opt.imageSize,
                                                    mean=opt.mean,
                                                    std=opt.std,
                                                    training=False),
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)

c_output = 21

voc_train_loader = torch.utils.data.DataLoader(VOC(voc_train_img_dir,
                                                   voc_train_gt_dir,
                                                   voc_train_split,
                                                   crop=0.9,
                                                   flip=True,
                                                   rotate=None,
                                                   size=opt.imageSize,
                                                   mean=opt.mean,
                                                   std=opt.std,
                                                   training=True),
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

voc_val_loader = torch.utils.data.DataLoader(VOC(voc_val_img_dir,
                                                 voc_val_gt_dir,
                                                 voc_val_split,
                                                 crop=None,
                                                 flip=False,
                                                 rotate=None,
                                                 size=opt.imageSize,
                                                 mean=opt.mean,
                                                 std=opt.std,
                                                 training=False),
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
Example #4
sal_train_loader = torch.utils.data.DataLoader(
    Saliency(sal_train_img_dir, sal_train_gt_dir,
             crop=0.9, flip=True, rotate=10, size=image_size, training=True),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

sal_val_loader = torch.utils.data.DataLoader(
    Saliency(sal_val_img_dir, sal_val_gt_dir,
           crop=None, flip=False, rotate=None, size=image_size, training=False), 
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

voc_train_loader = torch.utils.data.DataLoader(
    VOCSELF(voc_train_img_dir, voc_train_gt_dir, voc_train_plbl_dir, voc_train_split,
           crop=None, flip=False, rotate=None, size=image_size, training=True),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

voc_val_loader = torch.utils.data.DataLoader(
    VOC(voc_val_img_dir, voc_val_gt_dir, voc_val_split,
           crop=None, flip=False, rotate=None, size=image_size, training=False),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

def val_sal():
    net.eval()
    with torch.no_grad():
        for it, (img, gt, batch_name, WW, HH) in tqdm(enumerate(sal_val_loader), desc='val'):
            img = (img.cuda() - mean) / std
            pred_seg, v_sal, _ = net(img)
            pred_seg = torch.softmax(pred_seg, 1)
            bg = pred_seg[:, :1]
            fg = (pred_seg[:, 1:] * v_sal[:, 1:]).sum(1, keepdim=True)
            fg = fg.squeeze(1)
            fg = fg * 255
            for n, name in enumerate(batch_name):
                msk = fg[n]
Example #5
voc_train_img_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/JPEGImages'%home
voc_train_gt_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/SegmentationClassAug'%home

voc_val_img_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/JPEGImages'%home
voc_val_gt_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/SegmentationClass'%home

voc_train_split = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/ImageSets/Segmentation/argtrain.txt'%home
voc_val_split = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt'%home

label = "" # label of model parameters to load

c_output = 21


voc_train_loader = torch.utils.data.DataLoader(
    VOC(voc_train_img_dir, voc_train_gt_dir, voc_train_split,
           crop=None, flip=False, rotate=None, size=opt.imageSize,
           mean=opt.mean, std=opt.std, training=-1),
    batch_size=opt.batchSize, shuffle=False, num_workers=4, pin_memory=True)

voc_val_loader = torch.utils.data.DataLoader(
    VOC(voc_val_img_dir, voc_val_gt_dir, voc_val_split,
           crop=None, flip=False, rotate=None, size=opt.imageSize,
           mean=opt.mean, std=opt.std, training=False),
    batch_size=opt.batchSize, shuffle=True, num_workers=4, pin_memory=True)


def test(model):
    print("============================= TEST ============================")
    model.switch_to_eval()
    for i, (img, name, WW, HH) in tqdm(enumerate(voc_val_loader), desc='testing'):
        model.test(img, name, WW, HH)
Example #6
train_img_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/JPEGImages'%home
train_gt_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/SegmentationClassAug'%home

val_img_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/JPEGImages'%home
val_gt_dir = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/SegmentationClass'%home

train_split = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/ImageSets/Segmentation/argtrain.txt'%home
val_split = '%s/data/datasets/segmentation_Dataset/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt'%home

c_output = 21


train_loader = torch.utils.data.DataLoader(
    VOC(train_img_dir, train_gt_dir, train_split,
           crop=0.9, flip=True, rotate=None, size=opt.imageSize,
           mean=opt.mean, std=opt.std, training=True),
    batch_size=opt.batchSize, shuffle=True, num_workers=4, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    VOC(val_img_dir, val_gt_dir, val_split,
           crop=None, flip=False, rotate=None, size=opt.imageSize,
           mean=opt.mean, std=opt.std, training=False),
    batch_size=opt.batchSize, shuffle=True, num_workers=4, pin_memory=True)


def test(model):
    print("============================= TEST ============================")
    model.switch_to_eval()
    for i, (img, name, WW, HH) in tqdm(enumerate(val_loader), desc='testing'):
        model.test(img, name, WW, HH)
Example #7
def main():
    training_batch_size = 8
    validation_batch_size = 8
    epoch_num = 200
    iter_freq_print_training_log = 50
    lr = 1e-4

    net = SegNet(pretrained=True, num_classes=num_classes).cuda()
    curr_epoch = 0

    # net = FCN8VGG(pretrained=False, num_classes=num_classes).cuda()
    # snapshot = 'epoch_41_validation_loss_2.1533_mean_iu_0.5225.pth'
    # net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    # split_res = snapshot.split('_')
    # curr_epoch = int(split_res[1])

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_simultaneous_transform = SimultaneousCompose([
        SimultaneousRandomHorizontallyFlip(),
        SimultaneousRandomScale((0.9, 1.1)),
        SimultaneousRandomCrop((300, 500))
    ])
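    # the Simultaneous* transforms draw one set of random parameters and apply
    # it to image and mask together, so the augmented pair stays aligned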
    train_transform = transforms.Compose([
        RandomGaussianBlur(),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    val_simultaneous_transform = SimultaneousCompose(
        [SimultaneousScale((300, 500))])
    val_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    restore = transforms.Compose(
        [DeNormalize(*mean_std),
         transforms.ToPILImage()])

    train_set = VOC(train_path,
                    simultaneous_transform=train_simultaneous_transform,
                    transform=train_transform,
                    target_transform=MaskToTensor())
    train_loader = DataLoader(train_set,
                              batch_size=training_batch_size,
                              num_workers=8,
                              shuffle=True)
    val_set = VOC(val_path,
                  simultaneous_transform=val_simultaneous_transform,
                  transform=val_transform,
                  target_transform=MaskToTensor())
    val_loader = DataLoader(val_set,
                            batch_size=validation_batch_size,
                            num_workers=8)

    criterion = CrossEntropyLoss2d(ignored_label=ignored_label)
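    # two parameter groups: biases without weight decay, all other weights with 5e-4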
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'weight_decay': 5e-4}
    ], lr=lr, momentum=0.9, nesterov=True)

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)

    best = [1e9, -1, -1]  # [best_val_loss, best_mean_iu, best_epoch]

    for epoch in range(curr_epoch, epoch_num):
        train(train_loader, net, criterion, optimizer, epoch,
              iter_freq_print_training_log)
        if (epoch + 1) % 20 == 0:
            lr /= 3
            adjust_lr(optimizer, lr)
        validate(epoch, val_loader, net, criterion, restore, best)
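adjust_lr is not defined in this snippet; a plausible sketch, assuming it simply overwrites the learning rate of every parameter group:

def adjust_lr(optimizer, lr):
    # assumption: apply the new learning rate to all parameter groups
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr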
Example #8
sal_val_img_dir = '/home/zeng/data/datasets/saliency/ECSSD/images'
sal_val_gt_dir = '/home/zeng/data/datasets/saliency/ECSSD/masks'

sal_train_loader = torch.utils.data.DataLoader(
    Saliency(sal_train_img_dir, sal_train_gt_dir,
           crop=0.9, flip=True, rotate=10, size=image_size, training=True),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

sal_val_loader = torch.utils.data.DataLoader(
    Saliency(sal_val_img_dir, sal_val_gt_dir,
           crop=None, flip=False, rotate=None, size=image_size, training=False), 
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

voc_train_loader = torch.utils.data.DataLoader(
    VOC(voc_train_img_dir, voc_train_gt_dir, voc_train_split,
           crop=0.9, flip=True, rotate=10, size=image_size, training=True),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

voc_val_loader = torch.utils.data.DataLoader(
    VOC(voc_val_img_dir, voc_val_gt_dir, voc_val_split,
           crop=None, flip=False, rotate=None, size=image_size, training=False),
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=False)

def val_sal():
    net.eval()
    with torch.no_grad():
        for it, (img, gt, batch_name, WW, HH) in tqdm(enumerate(sal_val_loader), desc='val'):
            img = (img.cuda() - mean) / std
            pred_seg, v_sal, _ = net(img)
            pred_seg = torch.softmax(pred_seg, 1)
            bg = pred_seg[:, :1]
Example #9
net = JLSFCN(c_output).cuda()
net.load_state_dict(torch.load(path_save_checkpoints))

mean = torch.Tensor([0.485, 0.456, 0.406])[None, ..., None, None].cuda()
std = torch.Tensor([0.229, 0.224, 0.225])[None, ..., None, None].cuda()
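# indexing with [None, ..., None, None] yields shape (1, 3, 1, 1), which
# broadcasts over NCHW image batches during normalization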

voc_train_img_dir = '/home/zeng/data/datasets/segmentation/VOCdevkit/VOC2012/JPEGImages'
voc_train_gt_dir = '/home/zeng/data/datasets/segmentation/VOCdevkit/VOC2012/SegmentationClassAug'
voc_train_split = '/home/zeng/data/datasets/segmentation/VOCdevkit/VOC2012/ImageSets/Segmentation/trainaug.txt'

voc_loader = torch.utils.data.DataLoader(VOC(voc_train_img_dir,
                                             voc_train_gt_dir,
                                             voc_train_split,
                                             crop=0.9,
                                             flip=True,
                                             rotate=10,
                                             size=image_size,
                                             training=False,
                                             tproc=True),
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)


def val_voc():
    net.eval()
    with torch.no_grad():
        for t in range(10):
            for it, (img, gt, batch_name, WW,
Example #10
    parser.add_argument('--show_image',
                        '-s',
                        default=False,
                        action='store_true')
    parser.add_argument('--save_path', '-p', default=None)

    args = parser.parse_args()
    all_dev = parse_devices(args.devices)

    mp_ctx = mp.get_context('spawn')
    network = DFN(config.num_classes,
                  criterion=None,
                  aux_criterion=None,
                  alpha=config.aux_loss_alpha)
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source
    }
    dataset = VOC(data_setting, 'val', None)

    with torch.no_grad():
        segmentor = SegEvaluator(dataset, config.num_classes,
                                 config.image_mean, config.image_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.verbose, args.save_path,
                                 args.show_image)
        segmentor.run(config.snapshot_dir, args.epochs, config.val_log_file,
                      config.link_val_log_file)
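parse_devices is not shown in this snippet; a plausible sketch, assuming it turns a device string such as "0,1,3" into a list of GPU ids:

def parse_devices(input_devices):
    # assumption: comma-separated device string -> list of integer GPU ids
    return [int(d) for d in input_devices.split(',') if d.strip()]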