Example #1
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10_test'), train=False, download=True,
                                   transform=test_transform)

        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist_test'), train=False, download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                 ]), download=True)
            else:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.ToTensor(),
                                 ]), download=True)

            test_dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba_test'), split='test',
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                  ]), download=True)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'), split='train', download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn_test'), split='test', download=True,
                                transform=test_transform)

        elif self.config.data.dataset == 'NYUv2':
            if self.config.data.random_flip is False:
                nyu_train_transform = nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])
            else:
                nyu_train_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ToTensor()
                ])
                nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])

            dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=True, download=True,
                            rgb_transform=nyu_train_transform, depth_transform=nyu_train_transform)
            test_dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=False, download=True,
                                 rgb_transform=nyu_test_transform, depth_transform=nyu_test_transform)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=0)  # changed num_workers from 4 to 0
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=0, drop_last=True)  # changed num_workers from 4 to 0

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

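        # Noise levels form a geometric sequence from sigma_begin down to sigma_end (log-linear spacing, one per class).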
        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
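                # Uniform dequantization: rescale pixels from [0, 1] to [0, 255/256] and add uniform noise in [0, 1/256).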
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                if self.config.data.dataset == 'NYUv2':
                    # concatenate depth map with image
                    y = y[0]
                    # code to see resized depth map
                    # input_gt_depth_image = y[0][0].data.cpu().numpy().astype(np.float32)
                    # plot.imsave('gt_depth_map_{}.png'.format(i), input_gt_depth_image,
                    #             cmap="viridis")
                    y = y.to(self.config.device)
                    X = torch.cat((X, y), 1)

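                # Sample one noise-level index uniformly at random for each image in the batch.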
                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)

                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

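                # Every 100 steps, log the DSM loss on one held-out test batch.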
                if step % 100 == 0:
                    score.eval()
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)

                    test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.
                    if self.config.data.logit_transform:
                        test_X = self.logit_transform(test_X)

                    if self.config.data.dataset == 'NYUv2':
                        test_y = test_y[0]
                        test_y = test_y.to(self.config.device)
                        test_X = torch.cat((test_X, test_y), 1)

                    test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                                                                    self.config.training.anneal_power)

                    tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
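
These examples call anneal_dsm_score_estimation without including its definition. As a rough reference only, below is a minimal sketch of the annealed denoising score matching objective it is assumed to compute; the name and signature are taken from the calls above, while the body is an approximation and the actual implementation used by these snippets may differ in details.

import torch


def anneal_dsm_score_estimation(scorenet, samples, labels, sigmas, anneal_power=2.):
    # Per-sample noise level, reshaped so it broadcasts over the image dimensions.
    used_sigmas = sigmas[labels].view(samples.shape[0], *([1] * (samples.dim() - 1)))
    # Perturb the (dequantized) images with Gaussian noise at the chosen scale.
    perturbed = samples + torch.randn_like(samples) * used_sigmas
    # Score of the Gaussian perturbation kernel, used as the regression target.
    target = -(perturbed - samples) / (used_sigmas ** 2)
    scores = scorenet(perturbed, labels)
    # Per-sample squared error, weighted by sigma ** anneal_power to balance noise levels.
    loss = 0.5 * ((scores - target).view(samples.shape[0], -1) ** 2).sum(dim=-1) \
           * used_sigmas.view(-1) ** anneal_power
    return loss.mean(dim=0)
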
Example #2
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

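        # LoadDataset is assumed to be a user-defined dataset that yields image tensors without labels,
        # which is why the training loop below unpacks a single value per batch.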
        dataset = LoadDataset('./ground turth', tran_transform)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)


        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, X in enumerate(dataloader):
                X = X.float()  # avoid wrapping an existing tensor with torch.tensor (copies and warns)
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.
                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)
                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))

                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
Example #3
    def train(self):
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True,
                              download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                                'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=test_transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                            train=True,
                            download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets',
                                              'mnist_test'),
                                 train=False,
                                 download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]),
                    download=False)
            else:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.ToTensor(),
                    ]),
                    download=False)

            test_dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba_test'),
                split='test',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]),
                download=False)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train',
                           download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets',
                                             'svhn_test'),
                                split='test',
                                download=True,
                                transform=test_transform)

        dataloader = DataLoader(dataset,
                                batch_size=self.config.training.batch_size,
                                shuffle=True,
                                num_workers=4)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.config.training.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 drop_last=True)

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size**2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(
                np.linspace(np.log(self.config.model.sigma_begin),
                            np.log(self.config.model.sigma_end),
                            self.config.model.num_classes))).float().to(
                                self.config.device)

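        # Record wall-clock time per optimization step so mean/std can be logged every 100 steps.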
        time_record = []
        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0,
                                       len(sigmas), (X.shape[0], ),
                                       device=X.device)
                if self.config.training.algo == 'dsm':
                    t = time.time()
                    loss = anneal_dsm_score_estimation(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'dsm_tracetrick':
                    t = time.time()
                    loss = anneal_dsm_score_estimation_TraceTrick(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    t = time.time()
                    loss = anneal_sliced_score_estimation_vr(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet':
                    t = time.time()
                    loss = anneal_ESM_scorenet(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet_VR':
                    t = time.time()
                    loss = anneal_ESM_scorenet_VR(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t = time.time() - t
                time_record.append(t)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    logging.info(
                        "step: {}, loss: {}, time per step: {:.3f} +- {:.3f} ms"
                        .format(step, loss.item(),
                                np.mean(time_record) * 1e3,
                                np.std(time_record) * 1e3))

                    # if step % 2000 == 0:
                    #     score.eval()
                    #     try:
                    #         test_X, test_y = next(test_iter)
                    #     except StopIteration:
                    #         test_iter = iter(test_loader)
                    #         test_X, test_y = next(test_iter)

                    #     test_X = test_X.to(self.config.device)
                    #     test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.

                    #     if self.config.data.logit_transform:
                    #         test_X = self.logit_transform(test_X)

                    #     test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    #     #if self.config.training.algo == 'dsm':
                    #     with torch.no_grad():
                    #         test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                    #                                                         self.config.training.anneal_power)

                    #     tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)
                    #     logging.info("step: {}, test dsm loss: {}".format(step, test_dsm_loss.item()))

                    # elif self.config.training.algo == 'ssm':
                    #     test_ssm_loss = anneal_sliced_score_estimation_vr(score, test_X, test_labels, sigmas,
                    #                                          n_particles=self.config.training.n_particles)

                    #     tb_logger.add_scalar('test_ssm_loss', test_ssm_loss, global_step=step)
                    #     logging.info("step: {}, test ssm loss: {}".format(step, test_ssm_loss.item()))

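                # Checkpoints are only written after 140k steps, at the configured snapshot frequency.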
                if step >= 140000 and step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(
                        states,
                        os.path.join(self.args.log,
                                     'checkpoint_{}.pth'.format(step)))
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint.pth'))