def load_val_dataset():
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.FreeScale(cfg.VAL.IMG_SIZE)
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    print('=' * 50)
    print('Prepare Data...')
    val_set = CityScapes('val', list_filename='cityscapes_all.txt',
                         simul_transform=val_simul_transform,
                         transform=img_transform, target_transform=target_transform)
    target_loader = DataLoader(val_set, batch_size=cfg.VAL.IMG_BATCH_SIZE, num_workers=16, shuffle=True)

    # only the validation loader is built here; there is no source loader in this function
    return target_loader, restore_transform
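# The snippets below lean on a few custom transforms (MaskToTensor, DeNormalize,
# ChangeLabel) from the repo's expanded/extended_transforms modules. For orientation,
# here is a minimal sketch of what such transforms typically look like -- the bodies
# are assumptions, not the repository's actual code:
import numpy as np
import torch

class MaskToTensor(object):
    """Convert a PIL label mask to a LongTensor of class indices (no 0-1 scaling)."""
    def __call__(self, mask):
        return torch.from_numpy(np.array(mask, dtype=np.int32)).long()

class DeNormalize(object):
    """Invert torchvision's Normalize: x * std + mean, channel by channel."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor

class ChangeLabel(object):
    """Remap one label id to another, e.g. the ignore label to num_classes - 1."""
    def __init__(self, old_label, new_label):
        self.old_label, self.new_label = old_label, new_label

    def __call__(self, mask):
        mask[mask == self.old_label] = self.new_label
        return mask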
Example #2
def main():
    net = U_Net(img_ch=1, num_classes=3).to(device)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(384),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    center_crop = joint_transforms.CenterCrop(crop_size)
    train_input_transform = extended_transforms.ImgToTensor()

    target_transform = extended_transforms.MaskToTensor()
    make_dataset_fn = bladder.make_dataset_v2
    train_set = bladder.Bladder(data_path,
                                'train',
                                joint_transform=train_joint_transform,
                                center_crop=center_crop,
                                transform=train_input_transform,
                                target_transform=target_transform,
                                make_dataset_fn=make_dataset_fn)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    if loss_name == 'dice_':
        criterion = SoftDiceLossV2(activation='sigmoid',
                                   num_classes=3).to(device)
    elif loss_name == 'bcew_':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        raise ValueError('unknown loss_name: ' + loss_name)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train(train_loader, net, criterion, optimizer, n_epoch, 0)
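# The train() call above is defined elsewhere in the repo. A minimal epoch loop
# consistent with its call signature (train_loader, net, criterion, optimizer,
# n_epoch, start_epoch) might look like the sketch below; the logging details and
# the reuse of the module-level `device` are assumptions:
def train(train_loader, net, criterion, optimizer, n_epoch, start_epoch):
    net.train()
    for epoch in range(start_epoch, n_epoch):
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)  # forward + loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('epoch %d | avg loss %.4f' % (epoch + 1, running_loss / len(train_loader)))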
Example #3
def main():
    net = AFENet(classes=19, pretrained_model_path=None).cuda()
    net.load_state_dict(
        torch.load(os.path.join(args['model_save_path'], args['snapshot'])))

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    dataset_path = args['dataset_path']

    test_set = cityscapes.CityScapes(dataset_path,
                                     'fine',
                                     'test',
                                     transform=input_transform,
                                     target_transform=target_transform,
                                     val_scale=args['scale'])
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=1,
                             shuffle=False)
    test(test_loader, net, input_transform, restore_transform, args['scale'])
Example #4
def main(args):
    # initialize the pretrained network
    pspnet = PSPNet(n_classes=cityscapes.num_classes).cuda(gpu0)
    pspnet.load_pretrained_model(model_path=pspnet_path)
    # transforms and dataset loading
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = standard_transforms.Compose(
        [extended_transforms.MaskToTensor()])

    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])

    visualize = standard_transforms.ToTensor()
    val_set = cityscapes.CityScapes('val',
                                    transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False)
    validate(pspnet, val_loader, cityscapes.num_classes, args,
             restore_transform, visualize)
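# The ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0]) statistics are Caffe-style BGR
# channel means on a 0-255 scale, which is why the pipeline flips channels and
# multiplies by 255 before Normalize. FlipChannels plausibly looks like the sketch
# below (an assumption; the real class lives in extended_transforms):
import numpy as np
from PIL import Image

class FlipChannels(object):
    """Swap RGB <-> BGR on a PIL image by reversing the channel axis."""
    def __call__(self, img):
        img = np.asarray(img)[:, :, ::-1].copy()  # copy: reversed slice is non-contiguous
        return Image.fromarray(img)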
Example #5
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    src_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
    ])
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    src_dataset = GTA5DataSetLMDB(
        args.data_dir, args.data_list,
        joint_transform=joint_transform,
        transform=src_input_transform, 
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(
        src_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )
    tgt_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform, 
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(
        tgt_dataset, batch_size=args.batch_size, shuffle=True, 
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )

    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_val, args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, drop_last=False
    )

    style_trans = StyleTrans(args)
    style_trans.train(src_loader, tgt_loader, val_loader, writer)

    writer.close()
Example #6
def main():
    net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=False).cuda()
    # net = deeplab_resnet.Res_Deeplab().cuda()
    # net = dl.Res_Deeplab().cuda()

    # net.load_state_dict(torch.load(os.path.join(ckpt_path, args['exp_name'], args['snapshot'])))
    net.load_state_dict(torch.load(os.path.join(root, model, pth)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    test_set = voc.VOC('eval', transform=val_input_transform,target_transform=target_transform)
    test_loader = DataLoader(test_set, batch_size=1, num_workers=8, shuffle=False)

    #check_mkdir(os.path.join(ckpt_path, args['exp_name'], 'test'))
    predictions = []
    masks = []
    ious = np.array([])
    for vi, data in enumerate(test_loader):
        img_name, img, msk = data
        img_name = img_name[0]
        
        H,W = img.size()[2:]
        L = min(H,W)
        interp_before = nn.UpsamplingBilinear2d(size=(L,L))
        interp_after = nn.UpsamplingBilinear2d(size=(H,W))
        
        img = Variable(img, volatile=True).cuda()
        msk = Variable(msk, volatile=True).cuda()
        masks.append(msk.data.squeeze_(0).cpu().numpy())
        
        #img = interp_before(img)
        output = net(img)
        #output = interp_after(output[3])

        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        #prediction = output.data.max(1)[1].squeeze().cpu().numpy()
        ious = np.append(ious,get_iou(prediction,masks[-1]))
        
        predictions.append(prediction)
        # prediction.save(os.path.join(ckpt_path, args['exp_name'], 'test', img_name + '.png'))
        prediction = voc.colorize_mask(prediction)
        prediction.save(os.path.join(root, 'segmented-images', model + '-' + img_name + '.png'))
        #if vi == 10:
        #    break
        print('%d / %d' % (vi + 1, len(test_loader)))
    results = evaluate(predictions,masks,voc.num_classes)
    print('mean iou = {}'.format(results[2]))
    print(ious.mean())
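# get_iou and evaluate are defined elsewhere in the repo; a common per-image mean-IoU
# over the classes present in either mask looks like this (an assumed implementation,
# not the repository's code; num_classes=21 matches Pascal VOC):
import numpy as np

def get_iou(pred, gt, num_classes=21):
    """Mean IoU over the classes that appear in prediction or ground truth."""
    ious = []
    for cls in range(num_classes):
        pred_c, gt_c = pred == cls, gt == cls
        union = np.logical_or(pred_c, gt_c).sum()
        if union == 0:
            continue  # class absent from both masks; skip it
        ious.append(np.logical_and(pred_c, gt_c).sum() / float(union))
    return np.mean(ious)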
Example #7
def load_dataset():
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if cfg.TRAIN.DATA_AUG:
        source_simul_transform = simul_transforms.Compose([
            simul_transforms.FreeScale(cfg.TRAIN.IMG_SIZE),
            simul_transforms.RandomHorizontallyFlip(),
            simul_transforms.PhotometricDistort()
        ])
        target_simul_transform = simul_transforms.Compose([
            simul_transforms.FreeScale(cfg.TRAIN.IMG_SIZE),
            simul_transforms.RandomHorizontallyFlip(),
            simul_transforms.PhotometricDistort()
        ])
    else:
        source_simul_transform = simul_transforms.Compose([
            simul_transforms.FreeScale(cfg.TRAIN.IMG_SIZE),
            simul_transforms.RandomHorizontallyFlip()
        ])
        target_simul_transform = simul_transforms.Compose([
            simul_transforms.FreeScale(cfg.TRAIN.IMG_SIZE),
            simul_transforms.RandomHorizontallyFlip()
        ])

    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    print('=' * 50)
    print('Prepare Data...')
    source_set = []
    if cfg.TRAIN.SOURCE_DOMAIN == 'GTA5':
        source_set = GTA5('train', list_filename='GTA5_' + cfg.DATA.SSD_GT + '.txt',
                          simul_transform=source_simul_transform, transform=img_transform,
                          target_transform=target_transform)
    elif cfg.TRAIN.SOURCE_DOMAIN == 'SYN':
        source_set = SYN('train', list_filename='SYN_' + cfg.DATA.SSD_GT + '.txt',
                         simul_transform=source_simul_transform, transform=img_transform,
                         target_transform=target_transform)

    source_loader = DataLoader(source_set, batch_size=cfg.TRAIN.IMG_BATCH_SIZE, num_workers=16, shuffle=True, drop_last=True)

    target_set = CityScapes('train', list_filename='cityscapes_' + cfg.DATA.SSD_GT + '.txt',
                            simul_transform=target_simul_transform, transform=img_transform,
                            target_transform=target_transform)
    target_loader = DataLoader(target_set, batch_size=cfg.TRAIN.IMG_BATCH_SIZE, num_workers=16, shuffle=True, drop_last=True)

    return source_loader, target_loader, restore_transform
Example #8
    def get_train_eval_dataloaders(self):
        input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*self.mean_std)
        ])
        target_transform = extended_transforms.MaskToTensor()

        # restore_transform = standard_transforms.Compose([
        #     extended_transforms.DeNormalize(*self.mean_std),
        #     standard_transforms.ToPILImage(),
        # ])
        #
        # visualize = standard_transforms.Compose([
        #     standard_transforms.Scale(400),
        #     standard_transforms.CenterCrop(400),
        #     standard_transforms.ToTensor()
        # ])

        train_set = VOC(self.config['data_path'],
                        'train',
                        transform=input_transform,
                        target_transform=target_transform,
                        rezise=self.config['resize_to'])  # sic: keyword is spelled 'rezise' in this VOC wrapper
        train_loader = DataLoader(train_set,
                                  batch_size=self.config['batch_size'],
                                  num_workers=8,
                                  shuffle=True)

        val_set = VOC(self.config['data_path'],
                      'val',
                      transform=input_transform,
                      target_transform=target_transform,
                      rezise=self.config['resize_to'])
        val_loader = DataLoader(val_set,
                                batch_size=self.config['batch_size'],
                                num_workers=8,
                                shuffle=False)

        return train_loader, val_loader
Example #9
def get_transforms(scale_size, input_size, region_size, supervised, test,
                   al_algorithm, full_res, dataset):
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if scale_size == 0:
        print('(Data loading) Not scaling the data')
        print('(Data loading) Random crops of ' + str(input_size) +
              ' in training')
        print('(Data loading) No crops in validation')
        if supervised:
            train_joint_transform = joint_transforms.Compose([
                joint_transforms.RandomCrop(input_size),
                joint_transforms.RandomHorizontallyFlip()
            ])
        else:
            train_joint_transform = joint_transforms.ComposeRegion([
                joint_transforms.RandomCropRegion(input_size,
                                                  region_size=region_size),
                joint_transforms.RandomHorizontallyFlip()
            ])
        if (not test and al_algorithm == 'ralis') and not full_res:
            val_joint_transform = joint_transforms.Scale(1024)
        else:
            val_joint_transform = None
        al_train_joint_transform = joint_transforms.ComposeRegion([
            joint_transforms.CropRegion(region_size, region_size=region_size),
            joint_transforms.RandomHorizontallyFlip()
        ])
    else:
        print('(Data loading) Scaling training data: ' + str(scale_size) +
              ' width dimension')
        print('(Data loading) Random crops of ' + str(input_size) +
              ' in training')
        print('(Data loading) No crops nor scale_size in validation')
        if supervised:
            train_joint_transform = joint_transforms.Compose([
                joint_transforms.Scale(scale_size),
                joint_transforms.RandomCrop(input_size),
                joint_transforms.RandomHorizontallyFlip()
            ])
        else:
            train_joint_transform = joint_transforms.ComposeRegion([
                joint_transforms.Scale(scale_size),
                joint_transforms.RandomCropRegion(input_size,
                                                  region_size=region_size),
                joint_transforms.RandomHorizontallyFlip()
            ])
        al_train_joint_transform = joint_transforms.ComposeRegion([
            joint_transforms.Scale(scale_size),
            joint_transforms.CropRegion(region_size, region_size=region_size),
            joint_transforms.RandomHorizontallyFlip()
        ])
        if dataset == 'gta_for_camvid':
            val_joint_transform = joint_transforms.ComposeRegion(
                [joint_transforms.Scale(scale_size)])
        else:
            val_joint_transform = None
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    return input_transform, target_transform, train_joint_transform, val_joint_transform, al_train_joint_transform
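# How the transforms returned above are typically wired into a segmentation dataset:
# the joint transform sees (image, mask) together so crops and flips stay aligned,
# then each side gets its own tensor transform. An illustrative wrapper (the repo's
# dataset classes are assumed to do the equivalent):
from torch.utils.data import Dataset

class SegPairDataset(Dataset):
    """Applies a joint transform to (image, mask), then per-side transforms."""
    def __init__(self, pairs, joint_transform, input_transform, target_transform):
        self.pairs = pairs  # list of (PIL image, PIL mask) tuples
        self.joint_transform = joint_transform
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        img, mask = self.pairs[index]
        if self.joint_transform is not None:
            img, mask = self.joint_transform(img, mask)  # same crop/flip on both
        return self.input_transform(img), self.target_transform(mask)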
Example #10
def main(train_args):
    net = PSPNet(num_classes=cityscapes.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse-extra)-psp_net', 'xx.pth')))
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('coarse', 'train', simul_transform=train_simul_transform,
                                      transform=train_input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('coarse', 'val', simul_transform=val_simul_transform, transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=train_args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=cityscapes.ignore_label).cuda()

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, restore_transform, visualize)
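# Several examples here build two parameter groups so biases train at twice the base
# learning rate and skip weight decay (an old Caffe convention). A reusable helper
# capturing that pattern (a sketch, not part of the original repo):
def bias_split_param_groups(net, lr, weight_decay):
    """Biases at 2x lr without weight decay; all other parameters at lr with decay."""
    biases = [p for n, p in net.named_parameters() if n.endswith('bias')]
    others = [p for n, p in net.named_parameters() if not n.endswith('bias')]
    return [
        {'params': biases, 'lr': 2 * lr},
        {'params': others, 'lr': lr, 'weight_decay': weight_decay},
    ]

# e.g. optimizer = optim.SGD(bias_split_param_groups(net, train_args['lr'],
#                            train_args['weight_decay']), momentum=train_args['momentum'])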
Example #11
def main():
    net = FCN32VGG(num_classes=mapillary.num_classes).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = mapillary.Mapillary('semantic',
                                    'training',
                                    joint_transform=train_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True,
                              pin_memory=True)
    val_set = mapillary.Mapillary('semantic',
                                  'validation',
                                  joint_transform=val_joint_transform,
                                  transform=input_transform,
                                  target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False,
                            pin_memory=True)

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()).replace(':', '-') + '.txt'),
        'w').write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=args['lr_patience'],
                                  min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)

    torch.save(net.state_dict(), PATH)  # PATH: final checkpoint destination, defined elsewhere
Example #12
def main(train_args):
    if cuda.is_available():
        net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=False).cuda()
        #net = MBO.MBO().cuda()
        #net = deeplab_resnet.Res_Deeplab().cuda()
    else:
        print('cuda is not available')
        net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=True)

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        set='benchmark',
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=bsz,
                              num_workers=8,
                              shuffle=True)

    val_set = voc.VOC('val',
                      set='voc',
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=voc.ignore_label).cuda()
    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr']
    }],
                           betas=(train_args['momentum'], 0.999))
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=2,
                                  min_lr=1e-10,
                                  verbose=True)

    lr0 = 1e-7
    max_epoch = 50
    max_iter = max_epoch * len(train_loader)
    # optimizer = optim.SGD(net.parameters(), lr=lr0, momentum=0.9, weight_decay=0.0005)
    # NOTE: assigning a StepLR here would overwrite the ReduceLROnPlateau above and break
    # the scheduler.step(val_loss) call below (StepLR.step takes no metric).
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    log_dir = os.path.join(root, 'logs', 'voc-fcn')
    time = datetime.datetime.now().strftime('%d-%m-%H-%M')
    train_file = 'train_log' + time + '.txt'
    val_file = 'val_log' + time + '.txt'
    os.makedirs(log_dir, exist_ok=True)  # ensure the log directory exists before opening files

    training_log = open(os.path.join(log_dir, train_file), 'w')
    val_log = open(os.path.join(log_dir, val_file), 'w')

    curr_epoch = 1
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args,
              training_log, max_iter, lr0)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize, val_log)

        scheduler.step(val_loss)

        lr_tmp = 0.0
        k = 0
        for param_group in optimizer.param_groups:
            lr_tmp += param_group['lr']
            k += 1
        val_log.write('learning rate = {}'.format(str(lr_tmp / k)) + '\n')
Example #13
def main():
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
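# restore_transform above inverts the Caffe-style input pipeline step by step
# (de-normalize, rescale to 0-1, to PIL, flip channels back to RGB). Hypothetical usage:
# pil_img = restore_transform(img_tensor.clone())  # clone: DeNormalize mutates in place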
Example #14
# swap channels BGR -> RGB; the .copy() matters, otherwise tmp is a view of channel 0
# and gets clobbered by the next assignment
tmp = img[:, :, 0].copy()
img[:, :, 0] = img[:, :, 2]
img[:, :, 2] = tmp
image = Image.fromarray(np.uint8(img))
mask = Image.open(mask_path)

mean_std = ([0.408, 0.457, 0.481], [1, 1, 1])

joint_transform_train = joint_transforms.Compose([
    joint_transforms.RandomCrop((321,321))
])

joint_transform_test = joint_transforms.Compose([
    joint_transforms.RandomCrop((512,512))
])

input_transform = standard_transforms.Compose([
    #standard_transforms.Resize((321,321)),
    #standard_transforms.RandomCrop(224),
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std)
])
target_transform = standard_transforms.Compose([
    #standard_transforms.Resize((224,224)),
    extended_transforms.MaskToTensor()
])

image = input_transform(image) * 255
mask = target_transform(mask)
img, mask = joint_transform_test(image, mask)
np.save('im.npy', img.numpy())
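# The channel swap at the top of this snippet can also be written as one reversed slice:
# img = img[:, :, ::-1].copy()  # BGR -> RGB; .copy() yields a contiguous array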
Example #15
def main(train_args):
    print('No of classes', doc.num_classes)
    if train_args['network'] == 'psp':
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    elif train_args['network'] == 'mfcn':
        net = MFCN(num_classes=doc.num_classes, use_aux=True).cuda()
    elif train_args['network'] == 'psppen':
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    print("number of cuda devices = ", torch.cuda.device_count())
    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net = torch.nn.DataParallel(net,
                                device_ids=list(
                                    range(torch.cuda.device_count())))
    net.train()
    mean_std = ([0.9584, 0.9588, 0.9586], [0.1246, 0.1223, 0.1224])
    weight = torch.FloatTensor(doc.num_classes)  # uninitialized; only meaningful if the weighted criterion below is enabled
    # weight[0] = 1.0 / 0.5511  # background
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(3),
        #simul_transforms.RandomHorizontallyFlip()
        simul_transforms.Scale(train_args['input_size']),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = doc.DOC('train',
                        Dataroot,
                        joint_transform=train_simul_transform,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    train_args['No_train_images'] = len(train_set)  # dataset size; equivalent to iterating a batch-size-1 loader
    if train_args['No_train_images'] == 87677:
        train_args['Type_of_train_image'] = 'All Synthetic slide image'
    elif train_args['No_train_images'] == 150:
        train_args['Type_of_train_image'] = 'Real image'
    elif train_args['No_train_images'] == 84641:
        train_args['Type_of_train_image'] = 'Synthetic image'
    elif train_args['No_train_images'] == 151641:
        train_args['Type_of_train_image'] = 'Real + Synthetic image'
    val_set = doc.DOC('val',
                      Dataroot,
                      joint_transform=val_simul_transform,
                      transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)
    #criterion = CrossEntropyLoss2d(weight = weight, size_average = True, ignore_index = doc.ignore_label).cuda()
    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=doc.ignore_label).cuda()
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                          momentum=train_args['momentum'])
    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(train_args) + '\n\n')
    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader)
Example #16
def main():
    net = AFENet(classes=19,
                 pretrained_model_path=args['pretrained_model_path']).cuda()
    net_ori = [net.layer0, net.layer1, net.layer2, net.layer3, net.layer4]
    net_new = [
        net.ppm, net.cls, net.aux, net.ppm_reduce, net.aff1, net.aff2, net.aff3
    ]

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    dataset_path = args['dataset_path']

    # Loading dataset
    train_set = cityscapes.CityScapes(dataset_path,
                                      'fine',
                                      'train',
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=2,
                              shuffle=True)
    val_set = cityscapes.CityScapes(dataset_path,
                                    'fine',
                                    'val',
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=2,
                            shuffle=False)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(args['model_save_path'],
                                    args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    params_list = []
    for module in net_ori:
        params_list.append(dict(params=module.parameters(), lr=args['lr']))
    for module in net_new:
        params_list.append(dict(params=module.parameters(),
                                lr=args['lr'] * 10))
    args['index_split'] = 5

    criterion = torch.nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = torch.optim.SGD(params_list,
                                lr=args['lr'],
                                momentum=args['momentum'],
                                weight_decay=args['weight_decay'])
    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(args['model_save_path'],
                             'opt_' + args['snapshot'])))

    check_makedirs(args['model_save_path'])

    all_iter = args['epoch_num'] * len(train_loader)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, optimizer, epoch, all_iter)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
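# train() receives all_iter; in PSPNet-style training loops this usually drives a
# polynomial learning-rate decay per iteration. A typical formula (an assumption
# about this repo's train, not its confirmed behavior):
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """Polynomial LR decay commonly paired with PSPNet training."""
    return base_lr * (1 - float(cur_iter) / max_iter) ** power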
Example #17
def main():
    """Create the model and start the training."""
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    w_target, h_target = map(int, args.input_size_target.split(','))

    # Create network
    student_net = FCN8s(args.num_classes, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)

    student_net = student_net.cuda()

    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # show img
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])
    visualize = standard_transforms.ToTensor()

    if '5' in args.data_dir:  # crude dataset switch: the GTA5 path is assumed to contain '5'
        src_dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        src_dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        # no val resize
        # joint_transform=val_joint_transform,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # optimizer = optim.Adam(
    #     student_net.parameters(), lr=args.learning_rate,
    #     weight_decay=args.weight_decay
    # )

    student_params = list(student_net.parameters())

    # interp = partial(
    #     nn.functional.interpolate,
    #     size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True
    # )
    # interp_tgt = partial(
    #     nn.functional.interpolate,
    #     size=(h_target, w_target), mode='bilinear', align_corners=True
    # )
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    n_class = args.num_classes

    # src_criterion = torch.nn.CrossEntropyLoss(
    #     ignore_index=255, reduction='sum')
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              size_average=False)

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):

        cls_loss_rec = AverageMeter()
        aug_loss_rec = AverageMeter()
        mask_rec = AverageMeter()
        confidence_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()
        # load_time_rec = AverageMeter()
        # trans_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source

            # src_images, src_label, src_img_name, (load_time, trans_time) = src_data
            src_images, src_label, src_img_name = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)
            # src_output = interp(src_output)

            # Segmentation Loss
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]

            total_loss = cls_loss_value
            total_loss.backward()
            optimizer.step()

            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, 19)

            cls_loss_rec.update(cls_loss_value.detach_().item())
            miu_rec.update(mean_iu)
            # load_time_rec.update(torch.mean(load_time).item())
            # trans_time_rec.update(torch.mean(trans_time).item())

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    # f'Load: {load_time_rec.avg:.2f}   '
                    # f'Trans: {trans_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info.json')
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
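# AverageMeter above is the usual running-average helper; a minimal assumed version:
class AverageMeter(object):
    """Track the running average of a scalar (loss, time, metric)."""
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count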
Example #18
def main(train_args):
    net = FCN8s(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=1,
                              num_workers=4,
                              shuffle=True)
    val_set = voc.VOC('val',
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=voc.ignore_label).cuda()

    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                           betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)
Example #19
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean',  # 'elementwise_mean' was renamed 'mean' in PyTorch 1.0
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
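    # biases get twice the base LR and no weight decay; all other parameters
    # train at the base LR with weight decay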
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    open(os.path.join(save_folder,
                      str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
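    # Softmax2d normalizes over the channel (class) dimension at every spatial location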
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
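        # polynomial LR decay: lr * (1 - iter / max_iter) ** lr_decay,
        # keeping the doubled LR for the bias group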
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
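            # NOTE: next(iter(loader)) builds a fresh iterator (and worker pool)
            # every step; with shuffle=True this still yields a random batch,
            # but a persistent iterator would be cheaper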
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)
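            # net returns (main head, auxiliary head) logits; the aux loss is
            # weighted by 0.4 below, the usual PSPNet deep-supervision recipe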

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
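            # if filtering removed every sample, 0 * tensor.sum() yields a
            # differentiable zero so backward() still works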
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f], [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f], [lr %.10f]' % (
                curr_iter, args['max_iter'], train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
Example #20
def main():
    # epoch = 100
    # info = "ATONet_final3_loss3_5_BN_batch=4_use_ohem=0_bins=8_4_2epoch=100"
    # snapshot = "epoch_98_loss_0.12540_acc_0.95847_acc-cls_0.78683_mean-iu_0.70210_fwavacc_0.92424_lr_0.0000453781.pth"
    # epoch=200
    info = "ATONet_final3_loss3_5_BN_batch=4_use_ohem=False_bins=8_4_2epoch=200"
    snapshot = "epoch_193_loss_0.11953_acc_0.96058_acc-cls_0.79683_mean-iu_0.71272_fwavacc_0.92781_lr_0.0000798490.pth"

    model_save_path = './save_models/cityscapes/{}'.format(info)
    print(model_save_path)

    net = ATONet(classes=19, bins=(8, 4, 2), use_ohem=False).cuda()

    net.load_state_dict(torch.load(os.path.join(model_save_path, snapshot)))

    net.eval()
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    #
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    root = '/titan_data1/caokuntao/data/cityscapes'

    test_set = cityscapes.CityScapes(root,
                                     'fine',
                                     'test',
                                     transform=input_transform,
                                     target_transform=target_transform)
    test_loader = DataLoader(test_set,
                             batch_size=args['test_batch_size'],
                             num_workers=4,
                             shuffle=False)

    if not os.path.exists(model_save_path):
        os.mkdir(model_save_path)

    trainid_to_id = {
        0: 7,
        1: 8,
        2: 11,
        3: 12,
        4: 13,
        5: 17,
        6: 19,
        7: 20,
        8: 21,
        9: 22,
        10: 23,
        11: 24,
        12: 25,
        13: 26,
        14: 27,
        15: 28,
        16: 31,
        17: 32,
        18: 33
    }

    gts_all, predictions_all, img_name_all = [], [], []
    with torch.no_grad():
        for vi, data in enumerate(test_loader):
            inputs, img_name = data
            N = inputs.size(0)
            inputs = inputs.cuda()  # Variable is a no-op since PyTorch 0.4

            outputs = net(inputs)
            predictions = outputs.data.max(1)[1].squeeze_(1).cpu().numpy()

            predictions_all.append(predictions)
            img_name_all.append(img_name)

        print('done')
        predictions_all = np.concatenate(predictions_all)
        img_name_all = np.concatenate(img_name_all)

        to_save_dir = os.path.join(model_save_path, exp_file)
        if not os.path.exists(to_save_dir):
            os.mkdir(to_save_dir)

        for idx, data in enumerate(zip(img_name_all, predictions_all)):
            if data[0] is None:
                continue
            img_name = data[0]
            pred = data[1]
            pred_copy = pred.copy()
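            # map train IDs (0-18) back to the official Cityscapes label IDs
            # expected by the evaluation server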
            for k, v in trainid_to_id.items():
                pred_copy[pred == k] = v
            pred = Image.fromarray(pred_copy.astype(np.uint8))
            pred.save(os.path.join(to_save_dir, img_name))
def main(train_args):
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    net = FCN8s(num_classes=plant.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    mean_std = ([0.385, 0.431, 0.452], [0.289, 0.294, 0.285])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(500),
        standard_transforms.CenterCrop(500),
        standard_transforms.ToTensor()
    ])

    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    train_set = plant.Plant('train',
                            augmentations=data_aug,
                            transform=input_transform,
                            target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=4,
                              num_workers=2,
                              shuffle=True)
    val_set = plant.Plant('val',
                          transform=input_transform,
                          target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=2,
                            shuffle=False)

    weights = torch.FloatTensor(cfg.train_weights)
    criterion = CrossEntropyLoss2d(weight=weights,
                                   size_average=False,
                                   ignore_index=plant.ignore_label).cuda()

    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                           betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(train_args) + '\n\n')

    #train_args['best_record']['mean_iu'] = 0.50
    #curr_epoch = 100
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        scheduler.step(val_loss)
def main(train_args):
    backbone = ResNet()
    backbone.load_state_dict(torch.load(
        './weight/resnet34-333f7ec4.pth'), strict=False)
    net = Decoder34(num_classes=13, backbone=backbone).cuda()
    D = discriminator(input_channels=16).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(
            ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()
    D.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = wp.Wp('train', transform=input_transform,
                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=4,
                              num_workers=4, shuffle=True)
    # val_set = wp.Wp('val', transform=input_transform,
    #                 target_transform=target_transform)
    # XR: so val can't actually be used here? Why not use a separate val dataset?
    val_loader = DataLoader(train_set, batch_size=1,
                            num_workers=4, shuffle=False)
    criterion = DiceLoss().cuda()
    criterion_D = nn.BCELoss().cuda()
    optimizer_AE = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))
    optimizer_D = optim.Adam([
        {'params': [param for name, param in D.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in D.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer_AE.load_state_dict(torch.load(os.path.join(
            ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer_AE.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer_AE.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) +
                      '.txt'), 'w').write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(
        optimizer_AE, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, D, criterion, criterion_D, optimizer_AE,
              optimizer_D, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer_AE,
                            epoch, train_args, restore_transform, visualize)
        scheduler.step(val_loss)
Example #23
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
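    # input sizes are passed as "width,height" strings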
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
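        # heuristic: the GTA5 data path contains a '5', the Cityscapes path does not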
        dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    loader = data.DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    net = resnet101_ibn_a_deeplab(args.model_path_prefix,
                                  n_classes=args.n_classes)
    # optimizer = get_seg_optimizer(net, args)
    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net)
    # size_average=False is deprecated; reduction='sum' is the equivalent
    criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=args.ignore_index)

    num_batches = len(loader)
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            iteration = batch_index + 1 + epoch * num_batches

            # poly_lr_scheduler(
            #     optimizer=optimizer,
            #     init_lr=args.learning_rate,
            #     iter=iteration - 1,
            #     lr_decay_iter=args.lr_decay,
            #     max_iter=args.num_epoch*num_batches,
            #     power=args.poly_power,
            # )

            net.train()
            # net.module.freeze_bn()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if show_fig:
                base_lr = optimizer.param_groups[0]["lr"]
                output = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                axes = axes.flat
                axes[0].imshow(colorize_mask(output.numpy()))
                axes[0].set_title(name[0])
                axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                axes[1].set_title(f'seg_true_{base_lr:.6f}')
                writer.add_figure('A_seg', fig, iteration)

        mean_iu = test_miou(net, val_loader, upsample,
                            './ae_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix,
                         f'{epoch:d}_{mean_iu*100:.0f}.pth'))

    writer.close()
Example #24
    def __len__(self):
        return len(self.imgs)


if __name__ == "__main__":
    # make_dataset(root, mode)
    joint_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(128)
    ])
    # transform = transforms.Compose([
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # ])
    transform = extended_transforms.ImgToTensor()
    target_transform = extended_transforms.MaskToTensor()
    make_dataset_fn = make_dataset_dcm
    root = './hospital_data/MRI_T2'
    mode = 'val'
    batch_size = 2
    dataset = Bladder(root,
                      mode,
                      make_dataset_fn=make_dataset_fn,
                      joint_transform=joint_transform,
                      transform=transform,
                      target_transform=target_transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for x in data_loader:
        print(type(x))
        # break
Example #25
def testing(net, train_args, curr_iter):
    net.eval()

    mean_std = ([0.9584, 0.9588, 0.9586], [0.1246, 0.1223, 0.1224])

    test_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    scaleTest_transform = simul_transforms.Scale(512)
    test_simul_transform = simul_transforms.Scale(train_args['input_size'])
    #val_input_transform = standard_transforms.Compose([
    #    standard_transforms.ToTensor(),
    #    standard_transforms.Normalize(*mean_std)
    #])
    #test_set = doc.DOC('val', Dataroot, transform=val_input_transform,
    #                   target_transform=target_transform)

    # segmentation on test images
    test_set = doc.DOC('test',
                       ckpt_path,
                       joint_transform=test_simul_transform,
                       transform=test_input_transform)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=1,
                             shuffle=False)

    check_mkdir(os.path.join(ckpt_path, exp_name, str(curr_iter)))
    listOfImgs = [
        '2018_101008181.jpg', '2018_101001826.jpg', '2016_101000408.jpg',
        '2009_052067571.jpg', '2010_055401143.jpg'
    ]
    for vi, data in enumerate(test_loader):
        img_name, img = data
        img_name = img_name[0]
        with torch.no_grad():
            img = img.cuda()
            output = net(img)

            prediction = output.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
            input_img = img.squeeze_(0).cpu().numpy()
            prediction = doc.colorize_mask_combine(
                prediction, ckpt_path + 'data/img/' + img_name)
            #if img_name in listOfImgs:
            if curr_iter > 1:
                prediction.save(
                    os.path.join(ckpt_path, exp_name, str(curr_iter),
                                 img_name))
    call([
        "rsync", "-avz",
        os.path.join(ckpt_path, exp_name),
        "[email protected]:/home/jobinkv/Documents/r1/19wavc/"
    ])
    #print '%d / %d' % (vi + 1, len(test_loader))

    test_set_eva = doc.DOC('test_eva',
                           ckpt_path,
                           joint_transform=test_simul_transform,
                           transform=test_input_transform,
                           target_transform=target_transform)
    test_loader_eva = DataLoader(test_set_eva,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

    gts_all, predictions_all = [], []

    for vi, data in enumerate(test_loader_eva):
        img, gts = data
        with torch.no_grad():
            img = img.cuda()
            gts = gts.cuda()

            output = net(img)
            prediction = output.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

            gts_all.append(gts.squeeze_(0).cpu().numpy())
            predictions_all.append(prediction)

    acc, acc_cls, mean_iu, fwavacc, sep_iu = evaluate(predictions_all, gts_all,
                                                      doc.num_classes)
    del predictions_all
    print(
        '--------------------------------------------------------------------')
    print(
        '[test acc %.5f], [test acc_cls %.5f], [test mean_iu %.5f], [test fwavacc %.5f]'
        % (acc, acc_cls, mean_iu, fwavacc))
    print(
        '--------------------------------------------------------------------')
    sep_iu = sep_iu.tolist()
    sep_iu.append(mean_iu)
    sep_iu.insert(0, curr_iter)
    sep_iou_test.append(sep_iu)
Example #26
def main(network,
         train_batch_size=4,
         val_batch_size=4,
         epoch_num=50,
         lr=2e-2,
         weight_decay=1e-4,
         momentum=0.9,
         factor=10,
         val_scale=None,
         model_save_path='./save_models/cityscapes',
         data_type='Cityscapes',
         snapshot='',
         accumulation_steps=1):

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    # Loading dataset
    if data_type == 'Cityscapes':
        # dataset_path = '/home/caokuntao/data/cityscapes'
        # dataset_path = '/titan_data2/ckt/datasets/cityscapes'  # 23341
        dataset_path = '/titan_data1/caokuntao/data/cityscapes'  # new 23341
        train_set = cityscapes.CityScapes(dataset_path,
                                          'fine',
                                          'train',
                                          transform=input_transform,
                                          target_transform=target_transform)
        val_set = cityscapes.CityScapes(dataset_path,
                                        'fine',
                                        'val',
                                        val_scale=val_scale,
                                        transform=input_transform,
                                        target_transform=target_transform)
    else:
        dataset_path = '/home/caokuntao/data/camvid'
        # dataset_path = '/titan_data1/caokuntao/data/camvid'  # new 23341
        train_set = camvid.Camvid(dataset_path,
                                  'train',
                                  transform=input_transform,
                                  target_transform=target_transform)
        val_set = camvid.Camvid(dataset_path,
                                'test',
                                val_scale=val_scale,
                                transform=input_transform,
                                target_transform=target_transform)

    train_loader = DataLoader(train_set,
                              batch_size=train_batch_size,
                              num_workers=train_batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=val_batch_size,
                            num_workers=val_batch_size,
                            shuffle=False)

    if len(snapshot) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        logger.info('training resumes from ' + snapshot)
        network.load_state_dict(
            torch.load(os.path.join(model_save_path, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    criterion = torch.nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    paras = dict(network.named_parameters())
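    # backbone conv / downsample weights train at lr/factor with
    # weight_decay/factor; everything else uses the full values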
    paras_new = []
    for k, v in paras.items():
        if 'layer' in k and ('conv' in k or 'downsample.0' in k):
            paras_new += [{
                'params': [v],
                'lr': lr / factor,
                'weight_decay': weight_decay / factor
            }]
        else:
            paras_new += [{
                'params': [v],
                'lr': lr,
                'weight_decay': weight_decay
            }]

    optimizer = torch.optim.SGD(paras_new, momentum=momentum)
    lr_sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                             epoch_num,
                                                             eta_min=1e-6)

    # if len(snapshot) > 0:
    #     optimizer.load_state_dict(torch.load(os.path.join(model_save_path, 'opt_' + snapshot)))

    check_makedirs(model_save_path)

    all_iter = epoch_num * len(train_loader)

    #
    # validate(val_loader, network, criterion, optimizer, curr_epoch, restore_transform, model_save_path)
    # return

    for epoch in range(curr_epoch, epoch_num + 1):
        train(train_loader, network, optimizer, epoch, all_iter,
              accumulation_steps)
        validate(val_loader, network, criterion, optimizer, epoch,
                 restore_transform, model_save_path)
        lr_sheduler.step()

    # 1024 x 2048
    # dataset_path = '/titan_data1/caokuntao/data/cityscapes'  # new 23341
    # val_set = cityscapes.CityScapes(dataset_path, 'fine', 'val', val_scale=True, transform=input_transform,
    #                                 target_transform=target_transform)
    # val_loader = DataLoader(val_set, batch_size=val_batch_size, num_workers=val_batch_size, shuffle=False)
    # validate(val_loader, network, criterion, optimizer, epoch, restore_transform, model_save_path)

    return
    # NOTE: everything below is unreachable benchmarking code, kept for reference
    # # cityscapes
    # val_set = val_cityscapes(dataset_path, 'fine', 'val')
    # val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)
    n = len(val_loader)
    device = torch.device('cuda')
    network.eval()
    with torch.no_grad():
        # torch.cuda.synchronize()
        time_all = 0
        for vi, inputs in enumerate(val_loader):
            inputs = inputs[0].to(device)
            t0 = 1000 * time.time()
            outputs = network(inputs)
            torch.cuda.synchronize()  # wait for the GPU before stopping the timer
            t1 = 1000 * time.time()
            time_all = time_all + t1 - t0
            # predictions = outputs.data.max(1)[1].squeeze_(1).cpu()
            # torch.cuda.synchronize()

        fps = (1000 * n) / time_all
        # images per second
        print(fps)
Example #27
def main():
    net = FCN8ResNet(num_classes=num_classes).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    weight = torch.ones(num_classes)
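    # the ignore label was remapped to class num_classes - 1 (see ChangeLabel);
    # zero its weight so it contributes no loss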
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # don't use weight_decay for bias
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and 'fconv' in name
        ],
        'lr':
        2 * train_args['new_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and 'fconv' in name
        ],
        'lr':
        train_args['new_lr'],
        'weight_decay':
        train_args['weight_decay']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and 'fconv' not in name
        ],
        'lr':
        2 * train_args['pretrained_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and 'fconv' not in name
        ],
        'lr':
        train_args['pretrained_lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                          momentum=0.9,
                          nesterov=True)

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
Example #28
def main(train_args):
    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight.data, mean=0, std=0.01)
            torch.nn.init.constant_(m.bias.data, 0)

    net = VGG(num_classes=VOC.num_classes)
    net.apply(weights_init)
    net_dict = net.state_dict()
    pretrain = torch.load('./vgg16_20M.pkl')

    pretrain_dict = pretrain.state_dict()
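    # prefix the pretrained VGG keys with 'features.' and keep only those
    # present in this model's state dict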
    pretrain_dict = {
        'features.' + k: v
        for k, v in pretrain_dict.items() if 'features.' + k in net_dict
    }

    net_dict.update(pretrain_dict)
    net.load_state_dict(net_dict)

    net = nn.DataParallel(net)
    net = net.cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    mean_std = ([0.408, 0.457, 0.481], [1, 1, 1])

    joint_transform_train = joint_transforms.Compose(
        [joint_transforms.RandomCrop((321, 321))])

    joint_transform_test = joint_transforms.Compose(
        [joint_transforms.RandomCrop((512, 512))])

    input_transform = standard_transforms.Compose([
        #standard_transforms.Resize((321,321)),
        #standard_transforms.RandomCrop(224),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        #standard_transforms.Resize((224,224)),
        extended_transforms.MaskToTensor()
    ])
    #target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = VOC.VOC('train',
                        joint_transform=joint_transform_train,
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=20,
                              num_workers=4,
                              shuffle=True)
    val_set = VOC.VOC('val',
                      joint_transform=joint_transform_test,
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=VOC.ignore_label).cuda()

    #optimizer = optim.SGD(net.parameters(), lr = train_args['lr'], momentum=0.9,weight_decay=train_args['weight_decay'])
    # the original call passed the last two groups outside the list and used
    # overlapping membership tests; PyTorch rejects parameters that appear in
    # more than one group, so the four groups are made disjoint here
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and 'voc' not in name
        ],
        'lr': 2 * train_args['lr'],
        'momentum': train_args['momentum'],
        'weight_decay': 0
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and 'voc' not in name
        ],
        'lr': train_args['lr'],
        'momentum': train_args['momentum'],
        'weight_decay': train_args['weight_decay']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-8:] == 'voc.bias'
        ],
        'lr': 20 * train_args['lr'],
        'momentum': train_args['momentum'],
        'weight_decay': 0
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-10:] == 'voc.weight'
        ],
        'lr': 10 * train_args['lr'],
        'momentum': train_args['momentum'],
        'weight_decay': train_args['weight_decay']
    }])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(train_args) + '\n\n')

    #scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    scheduler = StepLR(optimizer, step_size=13, gamma=0.1)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        #scheduler.step(val_loss)
        scheduler.step()
Example #29
def main():
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

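    # a common fine-tuning trick: bias parameters get twice the base lr and no
    # weight decay, while all remaining weights use the base lr with decay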
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': args['lr'],
         'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
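        # loading the optimizer state also restores its old learning rates, so
        # reset both param groups to the lrs requested for this run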
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
Example #30
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

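    # crude dataset switch: a '5' in the path is taken to mean GTA5 (the
    # source domain); anything else is treated as Cityscapes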
    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )

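    # the network predicts at the training input size; bilinearly upsample the
    # logits to the target-domain resolution (h_target, w_target) before
    # computing mIoU on the Cityscapes val set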
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

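    # the usual PSPNet recipe: the pretrained backbone keeps the base lr while
    # the freshly initialised head and auxiliary branch train at 10x that rate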
    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate*10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate*10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()
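    # DataParallelModel/DataParallelCriterion (torch-encoding style wrappers)
    # keep the per-GPU outputs in place and evaluate the loss on each device,
    # avoiding the gather onto GPU 0 that plain nn.DataParallel performs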

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader), logger=logger,
                                   lr_step=args.lr_step)
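    # unlike the epoch-based schedulers above, this LR_Scheduler is invoked
    # once per batch (see the call at the top of the inner loop), so poly/step
    # decay is applied per iteration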

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
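    # note: best_pred is passed to the scheduler but never updated from the
    # mean_iu computed below, so it stays 0.0 for the whole run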
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            scheduler(optimizer, batch_index, epoch, best_pred)
            show_fig = (batch_index+1) % args.show_img_freq == 0
            iteration = batch_index+1+epoch*num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time()-tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     # base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     # fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     # axes = axes.flat
            #     # axes[0].imshow(colorize_mask(output.numpy()))
            #     # axes[0].set_title(name[0])
            #     # axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     # axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     # writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0,...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

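        # evaluate mIoU on the target-domain val set once per epoch and embed
        # both the epoch and the mIoU percentage in the checkpoint filename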
        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()