Ejemplo n.º 1
0
    def __init__(self, config, hparams):
        """Set up data loaders, a frozen segmentation model and a WGAN pair.

        Builds train/test DataLoaders for ``config.dataset``, constructs the
        DeepLabV3 variant named by ``config.model``, loads its weights from
        ``hparams.model_checkpoint`` and switches it to eval mode, then
        creates a DCGAN generator/critic (``self.G`` / ``self.D``) with
        RMSprop optimizers and a TensorBoard ``SummaryWriter``.

        Args:
            config: Run configuration. Reads ``log_dir``, ``dataset``,
                ``dataset_path``, ``test_mode`` (NYUv2 only), ``model`` and
                ``num_classes``.
            hparams: Hyper-parameters. Reads ``train_batch_size``,
                ``test_batch_size``, ``model_checkpoint``, ``img_size``,
                ``nz`` and ``lr``.

        Raises:
            Exception: If ``config.log_dir`` already exists. This is checked
                before any data or checkpoint is loaded so the run fails fast.
            ValueError: If ``config.dataset`` or ``config.model`` is not a
                supported name.
        """
        self.config = config
        self.hparams = hparams

        # Refuse to reuse an existing run directory before doing any
        # expensive dataset / checkpoint work.
        if os.path.exists(self.config.log_dir):
            raise Exception("Log directory exists")

        # All supported datasets share identical train/test pipelines.
        ext = utils.ext_transforms
        train_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtRandomCrop(128, pad_if_needed=True),
            ext.ExtRandomHorizontalFlip(),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])
        test_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])

        dataset_name = self.config.dataset
        if dataset_name == 'CamVid':
            train_dataset = CamVid(self.config.dataset_path,
                                   split='train',
                                   transform=train_transforms)
            # NOTE(review): NYUv2 honours config.test_mode here while CamVid
            # is pinned to 'val' -- confirm this is intended.
            test_dataset = CamVid(self.config.dataset_path,
                                  split='val',
                                  transform=test_transforms)
        elif dataset_name == 'Nyu':
            train_dataset = NYUv2(self.config.dataset_path,
                                  split='train',
                                  transform=train_transforms)
            test_dataset = NYUv2(self.config.dataset_path,
                                 split=self.config.test_mode,
                                 transform=test_transforms)
        elif dataset_name == 'Cityscapes':
            train_dataset = Cityscapes(self.config.dataset_path,
                                       split='train',
                                       mode='fine',
                                       transform=train_transforms)
            # NOTE(review): Cityscapes test-split annotations are not publicly
            # released, so evaluating on 'test' is presumably not meaningful
            # -- confirm 'test' (rather than 'val') is intended.
            test_dataset = Cityscapes(self.config.dataset_path,
                                      split='test',
                                      mode='fine',
                                      transform=test_transforms)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name!r}")

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.hparams.train_batch_size,
            shuffle=True,
            num_workers=6,
            drop_last=True)
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.hparams.test_batch_size,
            shuffle=False,
            num_workers=6)

        # NOTE(review): num_classes is hard-coded to 13 for the model, while
        # config.num_classes sizes diversity_gt below -- confirm they agree
        # for every dataset.
        model_name = self.config.model
        if model_name == 'resnet50_pretrained':
            self.model = deeplabv3.deeplabv3_resnet50(
                num_classes=13, dropout_p=0.5, pretrained_backbone=True)
        elif model_name == 'resnet100_pretrained':
            # Despite the option name, this builds the ResNet-101 variant.
            self.model = deeplabv3.deeplabv3_resnet101(
                num_classes=13, dropout_p=0.5, pretrained_backbone=True)
        elif model_name == 'resnet50':
            self.model = deeplabv3.deeplabv3_resnet50(
                num_classes=13, dropout_p=0.5, pretrained_backbone=False)
        elif model_name == 'mobilenet':
            self.model = deeplabv3.deeplabv3_mobilenet(
                num_classes=13, dropout_p=0.5, pretrained_backbone=False)
        else:
            raise ValueError(f"Unsupported model: {model_name!r}")

        self.device = torch.device("cuda")
        # The model is still on the CPU here. Loading the checkpoint onto the
        # CPU avoids a temporary GPU copy and does not depend on which GPU
        # the checkpoint was saved from.
        model_checkpoint = torch.load(self.hparams.model_checkpoint,
                                      map_location='cpu')
        self.model.load_state_dict(model_checkpoint['state_dict'])
        print("Teacher checkpoint loaded successfully.")
        self.model = self.model.to(self.device)
        self.model.eval()

        # DCGAN-style generator and critic for 3-channel images.
        self.G = wgan.DCGAN_G(self.hparams.img_size,
                              self.hparams.nz,
                              nc=3,
                              ngf=64,
                              ngpu=1)
        self.D = wgan.DCGAN_D(self.hparams.img_size,
                              self.hparams.nz,
                              nc=3,
                              ndf=64,
                              ngpu=1)
        self.G = self.G.to(self.device)
        self.D = self.D.to(self.device)
        self.G.apply(self.weights_init)
        self.D.apply(self.weights_init)

        # Fixed latent batch, presumably for tracking G's samples over time.
        self.fixed_noise = torch.randn(self.hparams.train_batch_size,
                                       self.hparams.nz,
                                       1,
                                       1,
                                       device=self.device)
        # +1 / -1 float32 scalars; presumably used as backward() gradient
        # seeds for the WGAN losses -- confirm in the training loop.
        self.one = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        self.m_one = torch.tensor([-1.0],
                                  dtype=torch.float32,
                                  device=self.device)
        self.gen_iterations = 0
        # Uniform distribution over the classes.
        self.diversity_gt = torch.ones(
            self.config.num_classes,
            device=self.device) / self.config.num_classes

        # RMSprop for both critic and generator (an Adam setup with
        # betas=(0.5, 0.999) was previously used here).
        self.optimizerD = torch.optim.RMSprop(self.D.parameters(),
                                              lr=self.hparams.lr)
        self.optimizerG = torch.optim.RMSprop(self.G.parameters(),
                                              lr=self.hparams.lr)

        self.logger = SummaryWriter(log_dir=self.config.log_dir)
Ejemplo n.º 2
0
    def __init__(self, config, hparams):
        """Set up data, a frozen teacher, a pretrained generator and a student.

        Builds train/test DataLoaders for ``config.dataset``, constructs the
        teacher (``config.teacher``) and student (``config.model``) DeepLabV3
        variants, and loads the teacher's weights from
        ``config.teacher_checkpoint`` (eval mode). It also loads a DCGAN
        generator from ``config.generator_checkpoint`` (eval mode) and
        creates an SGD optimizer for the student, an optional StepLR
        scheduler and a TensorBoard ``SummaryWriter``.

        Args:
            config: Run configuration. Reads ``log_dir``, ``dataset``,
                ``dataset_path``, ``test_mode``, ``teacher``, ``model``,
                ``teacher_checkpoint`` and ``generator_checkpoint``.
            hparams: Hyper-parameters. Reads ``train_batch_size``,
                ``test_batch_size``, ``lr``, ``weight_decay``, ``momentum``,
                ``lr_scheduler``, ``scheduler_step`` and ``scheduler_gamma``.
                ``hparams.nz`` is overwritten with the generator's latent
                size from its checkpoint.

        Raises:
            Exception: If ``config.log_dir`` already exists. This is checked
                before any data or checkpoint is loaded so the run fails fast.
            ValueError: If ``config.dataset``, ``config.teacher`` or
                ``config.model`` is not a supported name.
        """
        self.config = config
        self.hparams = hparams

        # Refuse to reuse an existing run directory before doing any
        # expensive dataset / checkpoint work.
        if os.path.exists(self.config.log_dir):
            raise Exception("Log directory exists")

        # All supported datasets share identical train/test pipelines.
        ext = utils.ext_transforms
        train_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtRandomCrop(128, pad_if_needed=True),
            ext.ExtRandomHorizontalFlip(),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])
        test_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])

        dataset_name = self.config.dataset
        if dataset_name == 'CamVid':
            train_dataset = CamVid(self.config.dataset_path,
                                   split='train',
                                   transform=train_transforms)
            test_dataset = CamVid(self.config.dataset_path,
                                  split=self.config.test_mode,
                                  transform=test_transforms)
        elif dataset_name == 'Nyu':
            train_dataset = NYUv2(self.config.dataset_path,
                                  split='train',
                                  transform=train_transforms)
            test_dataset = NYUv2(self.config.dataset_path,
                                 split=self.config.test_mode,
                                 transform=test_transforms)
        elif dataset_name == 'Cityscapes':
            train_dataset = Cityscapes(self.config.dataset_path,
                                       split='train',
                                       mode='fine',
                                       transform=train_transforms)
            test_dataset = Cityscapes(self.config.dataset_path,
                                      split='val',
                                      mode='fine',
                                      transform=test_transforms)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name!r}")

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.hparams.train_batch_size,
            shuffle=True,
            num_workers=6)
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.hparams.test_batch_size,
            shuffle=False,
            num_workers=6)

        def build_deeplab(name):
            """Construct the DeepLabV3 variant selected by ``name``.

            NOTE(review): num_classes is hard-coded to 13 for every variant;
            confirm this matches every supported dataset.
            """
            if name == 'resnet50_pretrained':
                return deeplabv3.deeplabv3_resnet50(
                    num_classes=13, dropout_p=0.5, pretrained_backbone=True)
            if name == 'resnet100_pretrained':
                # Despite the option name, this builds the ResNet-101 variant.
                return deeplabv3.deeplabv3_resnet101(
                    num_classes=13, dropout_p=0.5, pretrained_backbone=True)
            if name == 'resnet50':
                return deeplabv3.deeplabv3_resnet50(
                    num_classes=13, dropout_p=0.5, pretrained_backbone=False)
            if name == 'mobilenet':
                return deeplabv3.deeplabv3_mobilenet(
                    num_classes=13, dropout_p=0.5, pretrained_backbone=False)
            raise ValueError(f"Unsupported model: {name!r}")

        # Teacher is built before the student, as before.
        self.teacher = build_deeplab(self.config.teacher)
        self.model = build_deeplab(self.config.model)

        self.device = torch.device("cuda")
        # Move the student to its device *before* building the optimizer, as
        # recommended by the torch.optim documentation.
        self.model = self.model.to(self.device)

        # Uses SGD optimizer for better performance than Adam
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            momentum=self.hparams.momentum)

        if self.hparams.lr_scheduler:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.hparams.scheduler_step,
                self.hparams.scheduler_gamma)

        # Checkpoints are loaded onto the CPU (the receiving modules are still
        # on the CPU), so loading does not depend on the GPU they were saved
        # from.
        teacher_checkpoint = torch.load(self.config.teacher_checkpoint,
                                        map_location='cpu')
        self.teacher.load_state_dict(teacher_checkpoint['state_dict'])
        self.teacher = self.teacher.to(self.device)
        self.teacher.eval()
        print("Teacher checkpoint loaded successfully.")

        generator_checkpoint = torch.load(self.config.generator_checkpoint,
                                          map_location='cpu')
        generator_nz = generator_checkpoint['hparams']['nz']
        # Keep hparams.nz in sync with the loaded generator's latent size.
        self.hparams.nz = generator_nz
        self.G = dcgan.DcGanGenerator(nz=generator_nz)
        self.G.load_state_dict(generator_checkpoint['g_state_dict'])
        self.G = self.G.to(self.device)
        self.G.eval()
        print("Generator model loaded successfully")

        self.logger = SummaryWriter(log_dir=self.config.log_dir)
Ejemplo n.º 3
0
    def __init__(self, config, hparams):
        """Set up data loaders, a DeepLabV3 model, its optimizer and a logger.

        Builds train/test DataLoaders for ``config.dataset``, constructs the
        DeepLabV3 variant named by ``config.model`` on the GPU, and creates an
        SGD optimizer, an optional StepLR scheduler and a TensorBoard
        ``SummaryWriter``.

        Args:
            config: Run configuration. Reads ``log_dir``, ``dataset``,
                ``dataset_path``, ``test_mode`` and ``model``.
            hparams: Hyper-parameters. Reads ``train_batch_size``,
                ``test_batch_size``, ``lr``, ``weight_decay``, ``momentum``,
                ``lr_scheduler``, ``scheduler_step`` and ``scheduler_gamma``.

        Raises:
            Exception: If ``config.log_dir`` already exists. This is checked
                before any data is loaded so the run fails fast.
            ValueError: If ``config.dataset`` or ``config.model`` is not a
                supported name.
        """
        # Store config and hparams
        self.config = config
        self.hparams = hparams

        # Refuse to reuse an existing run directory before doing any
        # expensive dataset / model construction.
        if os.path.exists(self.config.log_dir):
            raise Exception("Log directory exists")

        # All supported datasets share identical train/test pipelines.
        ext = utils.ext_transforms
        train_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtRandomCrop(128, pad_if_needed=True),
            ext.ExtRandomHorizontalFlip(),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])
        test_transforms = ext.ExtCompose([
            ext.ExtResize(256),
            ext.ExtToTensor(),
            ext.ExtNormalize((0.5, ), (0.5, )),
        ])

        dataset_name = self.config.dataset
        if dataset_name == 'CamVid':
            train_dataset = CamVid(self.config.dataset_path,
                                   split='train',
                                   transform=train_transforms)
            test_dataset = CamVid(self.config.dataset_path,
                                  split=self.config.test_mode,
                                  transform=test_transforms)
        elif dataset_name == 'Nyu':
            train_dataset = NYUv2(self.config.dataset_path,
                                  split='train',
                                  transform=train_transforms)
            test_dataset = NYUv2(self.config.dataset_path,
                                 split=self.config.test_mode,
                                 transform=test_transforms)
        elif dataset_name == 'Cityscapes':
            train_dataset = Cityscapes(self.config.dataset_path,
                                       split='train',
                                       mode='fine',
                                       transform=train_transforms)
            # NOTE(review): Cityscapes test-split annotations are not publicly
            # released, so evaluating on 'test' is presumably not meaningful
            # -- confirm 'test' (rather than 'val') is intended.
            test_dataset = Cityscapes(self.config.dataset_path,
                                      split='test',
                                      mode='fine',
                                      transform=test_transforms)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name!r}")

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.hparams.train_batch_size,
            shuffle=True,
            num_workers=6)
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.hparams.test_batch_size,
            shuffle=False,
            num_workers=6)

        # NOTE(review): num_classes is hard-coded to 13 for every variant;
        # confirm this matches every supported dataset.
        model_name = self.config.model
        if model_name == 'resnet50_pretrained':
            self.model = deeplabv3.deeplabv3_resnet50(
                num_classes=13, dropout_p=0.5, pretrained_backbone=True)
        elif model_name == 'resnet100_pretrained':
            # Despite the option name, this builds the ResNet-101 variant.
            self.model = deeplabv3.deeplabv3_resnet101(
                num_classes=13, dropout_p=0.5, pretrained_backbone=True)
        elif model_name == 'resnet50':
            self.model = deeplabv3.deeplabv3_resnet50(
                num_classes=13, dropout_p=0.5, pretrained_backbone=False)
        elif model_name == 'mobilenet':
            self.model = deeplabv3.deeplabv3_mobilenet(
                num_classes=13, dropout_p=0.5, pretrained_backbone=False)
        else:
            raise ValueError(f"Unsupported model: {model_name!r}")

        # Run on GPU only. Move the model there *before* building the
        # optimizer, as recommended by the torch.optim documentation.
        self.device = torch.device("cuda")
        self.model = self.model.to(self.device)

        # Uses SGD optimizer for better performance than Adam
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            momentum=self.hparams.momentum)

        # Optional step-wise learning-rate decay.
        if self.hparams.lr_scheduler:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.hparams.scheduler_step,
                self.hparams.scheduler_gamma)

        # TensorBoard logger writing into the (new) log directory.
        self.logger = SummaryWriter(log_dir=self.config.log_dir)