Example #1
    def val_dataloader(self) -> DataLoader:
        """Cityscapes val set."""
        transforms = self.val_transforms or self._default_transforms()
        target_transforms = self.target_transforms or self._default_target_transforms()

        dataset = Cityscapes(
            self.data_dir,
            split="val",
            target_type=self.target_type,
            mode=self.quality_mode,
            transform=transforms,
            target_transform=target_transforms,
            **self.extra_args,
        )

        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )
        return loader
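
The _default_transforms and _default_target_transforms helpers are referenced above but not shown in the excerpt; a minimal sketch of what they could look like, assuming torchvision.transforms is imported as transform_lib (the actual implementation may differ):

    def _default_transforms(self):
        # Plausible default: image -> float tensor normalized with ImageNet statistics.
        return transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _default_target_transforms(self):
        # Plausible default: mask -> integer class-id tensor (PILToTensor keeps the raw ids).
        return transform_lib.Compose([
            transform_lib.PILToTensor(),
            transform_lib.Lambda(lambda t: t.squeeze(0).long()),
        ])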
Example #2
    def __init__(self, dataset_root, split, download=False, integrity_check=True):
        assert not download, 'Downloading of CityScapes is not implemented in torchvision'
        assert split in (SPLIT_TRAIN, SPLIT_VALID), f'Invalid split {split}'
        self.integrity_check = integrity_check

        self.ds = Cityscapes(dataset_root, split=split, mode='fine', target_type='semantic')

        self.sample_names = []
        self.sample_id_to_name, self.sample_name_to_id = {}, {}
        for i, path in enumerate(self.ds.images):
            name = self._sample_name(path)
            self.sample_names.append(name)
            self.sample_id_to_name[i] = name
            self.sample_name_to_id[name] = i

        self.transforms = None

        module_dir = os.path.dirname(__file__)  # directory containing this module
        path_points = os.path.join(module_dir, 'cityscapes_synthetic_clicks.json')
        with open(path_points, 'r') as f:
            self.ds_clicks = json.load(f)

        self._semseg_class_colors = [clsdesc.color for clsdesc in self.ds.classes if not clsdesc.ignore_in_eval]
        self._semseg_class_names = [clsdesc.name for clsdesc in self.ds.classes if not clsdesc.ignore_in_eval]
        self._id_2_trainid = {clsdesc.id: clsdesc.train_id for clsdesc in self.ds.classes}
        self._trainid_2_id = [clsdesc.id for clsdesc in self.ds.classes if not clsdesc.ignore_in_eval]
        self._semseg_class_histogram = self._compute_histogram()

        if integrity_check:
            n_samples = len(self.sample_names)
            assert n_samples == {SPLIT_TRAIN: 2975, SPLIT_VALID: 500}[split], f'Wrong number of samples {n_samples}'
            for i, name in enumerate(self.sample_names):
                assert name in self.ds_clicks, f'Sample {i} name {name} path {self.ds.images[i]} does not have a click'
            self.integrity_check = False
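
The _sample_name helper used in the loop above is not part of the excerpt; a plausible definition, assuming the standard Cityscapes file naming city_sequence_frame_leftImg8bit.png:

    @staticmethod
    def _sample_name(path):
        # Hypothetical helper: strip the directory and the '_leftImg8bit.png' suffix,
        # e.g. '.../aachen/aachen_000000_000019_leftImg8bit.png' -> 'aachen_000000_000019'.
        return os.path.basename(path).replace('_leftImg8bit.png', '')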
Example #3
    def test_dataloader(self):
        """
        Cityscapes test set
        """
        transforms = self.test_transforms or self._default_transforms()
        target_transforms = self.target_transforms or self._default_target_transforms()

        dataset = Cityscapes(self.data_dir,
                             split='test',
                             target_type=self.target_type,
                             mode=self.quality_mode,
                             transform=transforms,
                             target_transform=target_transforms,
                             **self.extra_args)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            drop_last=True,
                            pin_memory=True)
        return loader
Example #4
    def train_dataloader(self) -> DataLoader:
        """
        Cityscapes train set
        """
        transforms = self.train_transforms or self._default_transforms()
        target_transforms = self.target_transforms or self._default_target_transforms()

        dataset = Cityscapes(self.data_dir,
                             split='train',
                             target_type=self.target_type,
                             mode=self.quality_mode,
                             transform=transforms,
                             target_transform=target_transforms,
                             **self.extra_args)

        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=self.shuffle,
                            num_workers=self.num_workers,
                            drop_last=self.drop_last,
                            pin_memory=self.pin_memory)
        return loader
Example #5
def train(save_path, checkpoint, data_root, batch_size, dataset):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    transform = transforms.Compose(
        [transforms.Resize((128, 128)),
         transforms.ToTensor()])
    target_transform = transforms.Compose(
        [transforms.Resize((128, 128)),
         ToTensor()])
    if dataset == 'cityscapes':
        train_data = Cityscapes(str(data_root),
                                split='train',
                                mode='fine',
                                target_type='semantic',
                                transform=transform,
                                target_transform=transform)
        eG = 35
        dG = [35, 35, 20, 14, 10, 4, 1]
        eC = 8
        dC = 280
        n_classes = len(Cityscapes.classes)
        update_lr = update_lr_default
        epochs = 200
    else:
        train_data = Deepfashion(str(data_root),
                                 split='train',
                                 transform=transform,
                                 target_transform=transform)
        n_classes = len(Deepfashion.eclasses)
        eG = 8
        eC = 64
        dG = [8, 8, 4, 4, 2, 2, 1]
        dC = 160
        update_lr = update_lr_deepfashion
        epochs = 100
    data_loader = torch.utils.data.DataLoader(train_data,
                                              batch_size=batch_size,
                                              num_workers=1)

    os.makedirs(save_path, exist_ok=True)

    n_channels = 3
    encoder = Encoder(n_classes * n_channels, C=eC, G=eG)
    decoder = Decoder(8 * eG, n_channels, n_classes, C=dC, Gs=dG)
    discriminator = Discriminator(n_classes + n_channels)
    vgg = Vgg19().eval()

    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)
    discriminator = torch.nn.DataParallel(discriminator)
    vgg = torch.nn.DataParallel(vgg)

    gen_opt = optim.Adam(list(encoder.parameters()) +
                         list(decoder.parameters()),
                         lr=0.0001,
                         betas=(0, 0.9))
    dis_opt = optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0, 0.9))
    gen_scheduler = optim.lr_scheduler.LambdaLR(gen_opt, update_lr)
    dis_scheduler = optim.lr_scheduler.LambdaLR(dis_opt, update_lr)
    params = [
        'encoder', 'decoder', 'discriminator', 'gen_opt', 'dis_opt',
        'gen_scheduler', 'dis_scheduler'
    ]

    if os.path.exists(checkpoint):
        cp = torch.load(checkpoint)
        print(f'Load checkpoint: {checkpoint}')
        for param in params:
            eval(param).load_state_dict(cp[param])
        # encoder.load_state_dict(cp['encoder'])
        # decoder.load_state_dict(cp['decoder'])
        # discriminator.load_state_dict(cp['discriminator'])
        # gen_opt.load_state_dict(cp['gen_opt'])
        # dis_opt.load_state_dict(cp['dis_opt'])
        # gen_scheduler.load_state_dict(cp['gen_scheduler'])
        # dis_scheduler.load_state_dict(cp['dis_scheduler'])

    def to_device_optimizer(opt):
        for state in opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    to_device_optimizer(gen_opt)
    to_device_optimizer(dis_opt)

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    discriminator = discriminator.to(device)
    vgg = vgg.to(device)
    print(len(data_loader))
    for epoch in range(epochs):
        e_g_loss = []
        e_d_loss = []
        for i, batch in tqdm(enumerate(data_loader)):
            x, sem = batch
            x = x.to(device)
            sem = sem.to(device)
            sem = sem * 255.0
            sem = sem.long()
            s = split_class(x, sem, n_classes)
            sem_target = sem.clone()
            del sem
            sem = torch.zeros(x.size()[0],
                              n_classes,
                              sem_target.size()[2],
                              sem_target.size()[3],
                              device=x.device)
            sem.scatter_(1, sem_target, 1)
            s = s.detach()
            s = s.to(device)
            mu, sigma = encoder(s)
            # reparameterization trick: z = mu + std * eps with eps ~ N(0, I)
            z = mu + torch.exp(0.5 * sigma) * torch.randn(mu.size(),
                                                          device=mu.device)
            gen = decoder(z, sem)
            d_fake = discriminator(gen, sem)
            d_real = discriminator(x, sem)
            l1loss = nn.L1Loss()
            gen_opt.zero_grad()
            loss_gen = 0.5 * d_fake[0][-1].mean() + 0.5 * d_fake[1][-1].mean()
            loss_fm = sum([
                sum([l1loss(f, g) for f, g in zip(fs, rs)])
                for fs, rs in zip(d_fake, d_real)
            ]).mean()

            f_fake = vgg(gen)
            f_real = vgg(x)
            # loss_p = 1.0 / 32 * l1loss(f_fake.relu1_2, f_real.relu1_2) + \
            #     1.0 / 16 * l1loss(f_fake.relu2_2, f_real.relu2_2) + \
            #     1.0 / 8 * l1loss(f_fake.relu3_3, f_real.relu3_3) + \
            #     1.0 / 4 * l1loss(f_fake.relu4_3, f_real.relu4_3) + \
            #     l1loss(f_fake.relu5_3, f_real.relu5_3)
            loss_p = 1.0 / 32 * l1loss(f_fake[0], f_real[0]) + \
                1.0 / 16 * l1loss(f_fake[1], f_real[1]) + \
                1.0 / 8 * l1loss(f_fake[2], f_real[2]) + \
                1.0 / 4 * l1loss(f_fake[3], f_real[3]) + \
                l1loss(f_fake[4], f_real[4])
            loss_kl = -0.5 * torch.sum(1 + sigma - mu * mu - torch.exp(sigma))
            loss = loss_gen + 10.0 * loss_fm + 10.0 * loss_p + 0.05 * loss_kl
            loss.backward(retain_graph=True)
            gen_opt.step()

            dis_opt.zero_grad()
            loss_dis = torch.mean(-torch.mean(torch.min(d_real[0][-1] - 1, torch.zeros_like(d_real[0][-1]))) +
                                  -torch.mean(torch.min(-d_fake[0][-1] - 1, torch.zeros_like(d_fake[0][-1])))) + \
                                  torch.mean(-torch.mean(torch.min(d_real[1][-1] - 1, torch.zeros_like(d_real[1][-1]))) +
                                  -torch.mean(torch.min(-d_fake[1][-1] - 1, torch.zeros_like(d_fake[1][-1]))))
            loss_dis.backward()
            dis_opt.step()

            e_g_loss.append(loss.item())
            e_d_loss.append(loss_dis.item())
            #plt.imshow((gen.detach().cpu().numpy()[0]).transpose(1, 2, 0))
            #plt.pause(.01)
            #print(i, 'g_loss', e_g_loss[-1], 'd_loss', e_d_loss[-1])
            os.makedirs(save_path / str(epoch), exist_ok=True)

            Image.fromarray((gen.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                             255.0).astype(np.uint8)).save(
                                 save_path / str(epoch) / f'{i}.png')
        print('g_loss', np.mean(e_g_loss), 'd_loss', np.mean(e_d_loss))

        # save
        cp = {}
        for param in params:
            cp[param] = eval(param).state_dict()
        torch.save(cp, save_path / 'latest.pth')
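
Since save_path is combined with the / operator inside the function, it needs to be a pathlib.Path; a hypothetical invocation (all paths and the batch size are placeholders):

from pathlib import Path

# Hypothetical call for illustration only; adjust paths, batch size and dataset name.
train(save_path=Path('./runs/cityscapes'),
      checkpoint='./runs/cityscapes/latest.pth',
      data_root='./data/cityscapes',
      batch_size=4,
      dataset='cityscapes')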
Example #6
def get_dataset(dataset, data_root, crop_size):
    """ Dataset And Augmentation
    """
    root_full_path = data_root

    # train_transform = et.ExtCompose([
    #     # et.ExtResize( 512 ),
    #     # et.ExtRandomCrop(size=(crop_size, crop_size)),
    #     # et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    #     # et.ExtRandomHorizontalFlip(),
    #     et.ExtToTensor(),
    #     et.ExtNormalize(mean=[0.485, 0.456, 0.406],
    #                     std=[0.229, 0.224, 0.225]),
    # ])

    # val_transform = et.ExtCompose([
    #     # et.ExtResize( 512 ),
    #     et.ExtToTensor(),
    #     et.ExtNormalize(mean=[0.485, 0.456, 0.406],
    #                     std=[0.229, 0.224, 0.225]),
    # ])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
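    # Note: these torchvision transforms are applied to the image and the mask
    # independently, so the random crop/flip parameters will not stay aligned between
    # them; the commented-out et.ExtCompose pipelines above apply them jointly.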

    if dataset.lower() == "cityscapes":
        print(f"[INFO] Fetching Cityscapes dataset from: {root_full_path}")
        train_dst = Cityscapes(
            root=data_root,
            split='train',
            transform=train_transform,
            target_transform=train_transform,
            #    transforms=(train_transform, train_transform)
        )
        val_dst = Cityscapes(
            root=data_root,
            split='val',
            transform=val_transform,
            target_transform=val_transform,
            #  transforms=(train_transform, train_transform)
        )
    else:
        print(f"[INFO] Fetching ApolloScape dataset from: {root_full_path}")
        train_dst = Apolloscape(root=root_full_path,
                                road="road02_seg",
                                transform=train_transform,
                                normalize_poses=True,
                                pose_format='quat',
                                train=True,
                                cache_transform=True,
                                stereo=False)

        val_dst = Apolloscape(root=root_full_path,
                              road="road02_seg",
                              transform=val_transform,
                              normalize_poses=True,
                              pose_format='quat',
                              train=False,
                              cache_transform=True,
                              stereo=False)

    return train_dst, val_dst
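
A short usage sketch for get_dataset; the crop size, batch size and paths are illustrative placeholders:

train_dst, val_dst = get_dataset('cityscapes', './data/cityscapes', crop_size=224)
train_loader = torch.utils.data.DataLoader(train_dst, batch_size=8, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dst, batch_size=8, shuffle=False, num_workers=4)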
Example #7
# Get data and define transformations
in_t = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

out_t = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.functional.pil_to_tensor,
])

tr_data = Cityscapes(
    './data',
    target_type=['semantic'],
    split='train',
    transform=in_t,
    target_transform=out_t,
)
v_data = Cityscapes(
    './data',
    target_type=['semantic'],
    split='val',
    transform=in_t,
    target_transform=out_t,
)

train_data = DataLoader(tr_data, batch_size=BATCH_SIZE, shuffle=True)
val_data = DataLoader(v_data, batch_size=BATCH_SIZE, shuffle=False)
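
The targets produced by pil_to_tensor are uint8 tensors of shape [B, 1, H, W]; for a loss such as nn.CrossEntropyLoss they are typically squeezed and cast to int64 (note also that resizing label masks normally needs nearest-neighbour interpolation so class ids are not mixed):

imgs, masks = next(iter(train_data))
masks = masks.squeeze(1).long()  # [B, 1, H, W] uint8 -> [B, H, W] int64 class ids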

# define model
Example #8
from pytorch_lightning import LightningModule
from torchvision.datasets import Cityscapes
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models
import cv2

DATA_PATH = '/Users/eugennekhai/Downloads/133571_lumens-segmentation-dt2_Cityscapes/'

train_dataset = Cityscapes(DATA_PATH,
                           split='train',
                           mode='fine',
                           target_type='semantic')

val_dataset = Cityscapes(DATA_PATH,
                         split='val',
                         mode='fine',
                         target_type='semantic')

test_dataset = Cityscapes(DATA_PATH,
                          split='test',
                          mode='fine',
                          target_type='semantic')

img, smnt = val_dataset[0]

print(smnt)

img.show()
# while (True):
#     fr = cv2.resize(img, (600, 800))
#     cv2.imshow('test', fr)
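
DeepLabHead and models are imported in this example but not used in the excerpt; the usual way to adapt a pretrained DeepLabV3 to the Cityscapes class count looks roughly like this (a sketch, not part of the original code; 19 is the number of Cityscapes train ids):

import torch.nn as nn
from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead


def build_deeplab(num_classes: int = 19) -> nn.Module:
    # Start from a pretrained DeepLabV3-ResNet50 and replace the head so it
    # predicts the desired number of classes (2048 = backbone output channels).
    model = models.segmentation.deeplabv3_resnet50(pretrained=True)
    model.classifier = DeepLabHead(2048, num_classes)
    return model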
Example #9
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets


preprocess = transforms.Compose([
    transforms.RandomResizedCrop(513, scale=(0.63, 1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
                                                          0.225]),
])

train_dataset = Cityscapes(DATA_PATH,
                           split='train',
                           mode='fine',
                           target_type='semantic',
                           transforms=preprocess)

val_dataset = Cityscapes(DATA_PATH,
                         split='val',
                         mode='fine',
                         target_type='semantic',
                         transforms=preprocess)

test_dataset = Cityscapes(DATA_PATH,
                          split='test',
                          mode='fine',
                          target_type='semantic',
                          transforms=preprocess)
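
The cat_list helper used by the collate function above is not shown; a common implementation, in the spirit of the torchvision segmentation reference scripts, pads every tensor in the batch to the largest spatial size before stacking:

def cat_list(images, fill_value=0):
    # Pad each tensor to the per-dimension maximum shape in the batch, then stack.
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched = images[0].new_full(batch_shape, fill_value)
    for img, pad_img in zip(images, batched):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched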
Example #10
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 10
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 30

dataset = Cityscapes(DataPath.CityScapes.HOME,
                     transform=transforms.Compose([
                         transforms.Resize(image_size),
                         transforms.CenterCrop(image_size),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     ]),
                     target_transform=transforms.Compose([
                         transforms.Resize(image_size),
                         transforms.CenterCrop(image_size),
                         transforms.ToTensor()
                     ]))

# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=workers)

labels_list: List[int] = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22,
    23, 33
Example #11
 transforms.CenterCrop(128),                #[3]
 transforms.ToTensor(),                     #[4]
 transforms.Normalize(                      #[5]
 mean=[0.485, 0.456, 0.406],                #[6]
 std=[0.229, 0.224, 0.225]                  #[7]
 )])

mask_transform = transforms.Compose([            #[1]
 transforms.Resize(128),                    #[2]
 transforms.CenterCrop(128),                #[3]
 transforms.ToTensor()                 #[4]
 ])


# trainset = CocoStuff10k(root="/home/nazar/PycharmProjects/coco", transform=transform)
trainset = Cityscapes(DataPath.HOME_STREET, transform=transform, target_transform=mask_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=20, shuffle=True, num_workers=12)


class MaskClassifier(nn.Module):

    def __init__(self):

        super(MaskClassifier, self).__init__()
        self.resnet = resnet50(pretrained=True).to(ParallelConfig.MAIN_DEVICE)
        for param in self.resnet.parameters():
            param.requires_grad = False

    def classify(self, img: Tensor) -> Tensor:
        return self.resnet(img).softmax(dim=1)
Example #12
    def __init__(self, split):
        super(CityscapesLoader, self).__init__()
        self.tensors_dataset = Cityscapes(root='./data/cityscapes', split=split, mode='fine', target_type='semantic',
                                          transform=to_resized_tensor, target_transform=to_resized_tensor)
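
to_resized_tensor is referenced above but not defined in this excerpt; a plausible definition, with the target resolution being an assumption:

from torchvision import transforms

# Plausible stand-in for the transform used above; the 256x512 size is only an assumption.
to_resized_tensor = transforms.Compose([
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
])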