Code example #1
    def addentry(self, step, score, train_losses):
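        # Every 100 steps, log the training losses and a test DSM loss (computed on an
        # EMA copy of the score network) to TensorBoard, the log file, and the console.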
        if step % 100 == 0:
            test_score = self.ema_helper.ema_copy(score, self.args.device)
            test_score.eval()
            try:
                test_X, _ = next(self.test_iter)
            except StopIteration:
                self.test_iter = iter(self.testloader)
                test_X, _ = next(self.test_iter)

            test_X = test_X.to(self.args.device)
            test_X = data_transform(self.config.data, test_X)

            if self.hook is not None:
                self.hook.update_step(step)

            test_dsm_loss, _, _ = anneal_dsm_score_estimation(self.args, test_score, test_X, self.sigmas, hook=self.hook)

            if self.args.adversarial:
                self.tb_logger.add_scalar('loss_D', train_losses[0], global_step=step)
                self.tb_logger.add_scalar('loss_G_DAE', train_losses[1], global_step=step)
                self.tb_logger.add_scalar('loss_G_GAN', train_losses[2], global_step=step)
            else:
                self.tb_logger.add_scalar('train_loss', train_losses[0], global_step=step)

            self.tb_logger.add_scalar('test_loss', test_dsm_loss, global_step=step)
            if self.args.adversarial:
                _train_loss_string = "tloss_D: {:10.4f}  |  tloss_dae: {:10.4f}  |  tloss_adv: {:10.4f}"
            else:
                _train_loss_string = "train_loss: {:10.4f}"

            string = "step: {:8d}  |  " + _train_loss_string + "  ||  test_loss: {:10.4f}"
            logging.info(string.format(step, *train_losses, test_dsm_loss.item()))
            print(string.format(step, *train_losses, test_dsm_loss.item()))
Code example #2
    def test(self):
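        # Evaluate each saved checkpoint from begin_ckpt to end_ckpt (every 5000 steps)
        # on the test set, reporting the average annealed DSM loss per checkpoint;
        # the EMA weights are used when model.ema is enabled.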
        score = get_model(self.config)
        score = torch.nn.DataParallel(score)

        sigmas = get_sigmas(self.config)

        dataset, test_dataset = get_dataset(self.args, self.config)
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=self.config.test.batch_size,
                                     shuffle=True,
                                     num_workers=self.config.data.num_workers,
                                     drop_last=True)

        verbose = False
        for ckpt in tqdm.tqdm(range(self.config.test.begin_ckpt,
                                    self.config.test.end_ckpt + 1, 5000),
                              desc="processing ckpt:"):
            states = torch.load(os.path.join(self.args.log_path,
                                             f'checkpoint_{ckpt}.pth'),
                                map_location=self.config.device)

            if self.config.model.ema:
                ema_helper = EMAHelper(mu=self.config.model.ema_rate)
                ema_helper.register(score)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(score)
            else:
                score.load_state_dict(states[0])

            score.eval()

            step = 0
            mean_loss = 0.
            mean_grad_norm = 0.
            average_grad_scale = 0.
            for x, y in test_dataloader:
                step += 1

                x = x.to(self.config.device)
                x = data_transform(self.config, x)

                with torch.no_grad():
                    test_loss = anneal_dsm_score_estimation(
                        score, x, sigmas, None,
                        self.config.training.anneal_power)
                    if verbose:
                        logging.info("step: {}, test_loss: {}".format(
                            step, test_loss.item()))

                    mean_loss += test_loss.item()

            mean_loss /= step
            mean_grad_norm /= step
            average_grad_scale /= step

            logging.info("ckpt: {}, average test loss: {}".format(
                ckpt, mean_loss))
Code example #3
    def train(self):
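        # Training loop for a conditional score network (CondRefineNetDilated) on data
        # loaded through the custom LoadDataset class, using either denoising score
        # matching (dsm) or sliced score matching (ssm) as the objective.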
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        dataset = LoadDataset('./ground turth', tran_transform)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)


        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

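        # Geometric (log-linearly spaced) sequence of noise levels from sigma_begin
        # to sigma_end, one per noise-conditioning class.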
        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, X in enumerate(dataloader):
                X = X.float()  # the dataloader already yields tensors; just cast to float32
                step += 1
                score.train()
                X = X.to(self.config.device)
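                # Uniform dequantization: rescale pixel values into [0, 1) and add
                # uniform noise of width 1/256.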
                X = X / 256. * 255. + torch.rand_like(X) / 256.
                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)
                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

                
                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))

                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
Code example #4
    def train(self):
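        # Same structure as the previous training loop, but with standard torchvision
        # datasets (CIFAR10 / MNIST / CelebA / SVHN) and NYUv2, whose depth map is
        # concatenated with the RGB image along the channel dimension.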
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10_test'), train=False, download=True,
                                   transform=test_transform)

        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist_test'), train=False, download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                 ]), download=True)
            else:
                dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba'), split='train',
                                 transform=transforms.Compose([
                                     transforms.CenterCrop(140),
                                     transforms.Resize(self.config.data.image_size),
                                     transforms.ToTensor(),
                                 ]), download=True)

            test_dataset = CelebA(root=os.path.join(self.args.run, 'datasets', 'celeba_test'), split='test',
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                  ]), download=True)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'), split='train', download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn_test'), split='test', download=True,
                                transform=test_transform)

        elif self.config.data.dataset == 'NYUv2':
            if self.config.data.random_flip is False:
                nyu_train_transform = nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])
            else:
                nyu_train_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ToTensor()
                ])
                nyu_test_transform = transforms.Compose([
                    transforms.CenterCrop((400, 400)),
                    transforms.Resize(32),
                    transforms.ToTensor()
                ])

            dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=True, download=True,
                            rgb_transform=nyu_train_transform, depth_transform=nyu_train_transform)
            test_dataset = NYUv2(os.path.join(self.args.run, 'datasets', 'nyuv2'), train=False, download=True,
                                 rgb_transform=nyu_test_transform, depth_transform=nyu_test_transform)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=0)  # changed num_workers from 4 to 0
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=0, drop_last=True)  # changed num_workers from 4 to 0

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(self.config.model.sigma_begin), np.log(self.config.model.sigma_end),
                               self.config.model.num_classes))).float().to(self.config.device)

        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                if self.config.data.dataset == 'NYUv2':
                    # concatenate depth map with image
                    y = y[0]
                    # code to see resized depth map
                    # input_gt_depth_image = y[0][0].data.cpu().numpy().astype(np.float32)
                    # plot.imsave('gt_depth_map_{}.png'.format(i), input_gt_depth_image,
                    #             cmap="viridis")
                    y = y.to(self.config.device)
                    X = torch.cat((X, y), 1)

                labels = torch.randint(0, len(sigmas), (X.shape[0],), device=X.device)

                if self.config.training.algo == 'dsm':
                    loss = anneal_dsm_score_estimation(score, X, labels, sigmas, self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    loss = anneal_sliced_score_estimation_vr(score, X, labels, sigmas,
                                                             n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tb_logger.add_scalar('loss', loss, global_step=step)
                logging.info("step: {}, loss: {}".format(step, loss.item()))

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    score.eval()
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)

                    test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.
                    if self.config.data.logit_transform:
                        test_X = self.logit_transform(test_X)

                    if self.config.data.dataset == 'NYUv2':
                        test_y = test_y[0]
                        test_y = test_y.to(self.config.device)
                        test_X = torch.cat((test_X, test_y), 1)

                    test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                                                                    self.config.training.anneal_power)

                    tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))
Code example #5
    def train(self):
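        # Training loop with optional EMA of the score network weights, a periodic test
        # DSM loss computed on an EMA copy, per-sigma loss hooks, checkpointing, and
        # optional snapshot sampling via annealed Langevin dynamics.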
        dataset, test_dataset = get_dataset(self.args, self.config)
        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                num_workers=self.config.data.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=self.config.data.num_workers, drop_last=True)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_logger = self.config.tb_logger

        score = get_model(self.config)

        score = torch.nn.DataParallel(score)
        optimizer = get_optimizer(self.config, score.parameters())

        start_epoch = 0
        step = 0

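        # Maintain an exponential moving average of the score network parameters;
        # the EMA copy is what gets evaluated and sampled from below.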
        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            ### Make sure we can resume with different eps
            states[1]['param_groups'][0]['eps'] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(states[4])

        sigmas = get_sigmas(self.config)

        if self.config.training.log_all_sigmas:
            ### Commented out training time logging to save time.
            test_loss_per_sigma = [None for _ in range(len(sigmas))]

            def hook(loss, labels):
                # for i in range(len(sigmas)):
                #     if torch.any(labels == i):
                #         test_loss_per_sigma[i] = torch.mean(loss[labels == i])
                pass

            def tb_hook():
                # for i in range(len(sigmas)):
                #     if test_loss_per_sigma[i] is not None:
                #         tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                #                              global_step=step)
                pass

            def test_hook(loss, labels):
                for i in range(len(sigmas)):
                    if torch.any(labels == i):
                        test_loss_per_sigma[i] = torch.mean(loss[labels == i])

            def test_tb_hook():
                for i in range(len(sigmas)):
                    if test_loss_per_sigma[i] is not None:
                        tb_logger.add_scalar('test_loss_sigma_{}'.format(i), test_loss_per_sigma[i],
                                             global_step=step)

        else:
            hook = test_hook = None

            def tb_hook():
                pass

            def test_tb_hook():
                pass

        for epoch in range(start_epoch, self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                score.train()
                step += 1

                X = X.to(self.config.device)
                X = data_transform(self.config, X)

                loss = anneal_dsm_score_estimation(score, X, sigmas, None,
                                                   self.config.training.anneal_power,
                                                   hook)
                tb_logger.add_scalar('loss', loss, global_step=step)
                tb_hook()

                logging.info("step: {}, loss: {}".format(step, loss.item()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(score)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    if self.config.model.ema:
                        test_score = ema_helper.ema_copy(score)
                    else:
                        test_score = score

                    test_score.eval()
                    try:
                        test_X, test_y = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, test_y = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    test_X = data_transform(self.config, test_X)

                    with torch.no_grad():
                        test_dsm_loss = anneal_dsm_score_estimation(test_score, test_X, sigmas, None,
                                                                    self.config.training.anneal_power,
                                                                    hook=test_hook)
                        tb_logger.add_scalar('test_loss', test_dsm_loss, global_step=step)
                        test_tb_hook()
                        logging.info("step: {}, test_loss: {}".format(step, test_dsm_loss.item()))

                        del test_score

                if step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log_path, 'checkpoint.pth'))

                    if self.config.training.snapshot_sampling:
                        if self.config.model.ema:
                            test_score = ema_helper.ema_copy(score)
                        else:
                            test_score = score

                        test_score.eval()

                        ## Different part from NeurIPS 2019.
                        ## Random state will be affected because of sampling during training time.
                        init_samples = torch.rand(36, self.config.data.channels,
                                                  self.config.data.image_size, self.config.data.image_size,
                                                  device=self.config.device)
                        init_samples = data_transform(self.config, init_samples)

                        all_samples = anneal_Langevin_dynamics(init_samples, test_score, sigmas.cpu().numpy(),
                                                               self.config.sampling.n_steps_each,
                                                               self.config.sampling.step_lr,
                                                               final_only=True, verbose=True,
                                                               denoise=self.config.sampling.denoise)

                        sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                      self.config.data.image_size,
                                                      self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, 6)
                        save_image(image_grid,
                                   os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(step)))
                        torch.save(sample, os.path.join(self.args.log_sample_path, 'samples_{}.pth'.format(step)))

                        del test_score
                        del all_samples
Code example #6
    def train(self):
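        # Variant of the training loop that times every optimization step to compare
        # the cost of the available objectives (dsm, dsm_tracetrick, ssm, esm_scorenet,
        # esm_scorenet_VR).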
        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                           'cifar10'),
                              train=True,
                              download=True,
                              transform=tran_transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets',
                                                'cifar10_test'),
                                   train=False,
                                   download=True,
                                   transform=test_transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                            train=True,
                            download=True,
                            transform=tran_transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets',
                                              'mnist_test'),
                                 train=False,
                                 download=True,
                                 transform=test_transform)

        elif self.config.data.dataset == 'CELEBA':
            if self.config.data.random_flip:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]),
                    download=False)
            else:
                dataset = CelebA(
                    root=os.path.join(self.args.run, 'datasets', 'celeba'),
                    split='train',
                    transform=transforms.Compose([
                        transforms.CenterCrop(140),
                        transforms.Resize(self.config.data.image_size),
                        transforms.ToTensor(),
                    ]),
                    download=False)

            test_dataset = CelebA(
                root=os.path.join(self.args.run, 'datasets', 'celeba_test'),
                split='test',
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                ]),
                download=False)

        elif self.config.data.dataset == 'SVHN':
            dataset = SVHN(os.path.join(self.args.run, 'datasets', 'svhn'),
                           split='train',
                           download=True,
                           transform=tran_transform)
            test_dataset = SVHN(os.path.join(self.args.run, 'datasets',
                                             'svhn_test'),
                                split='test',
                                download=True,
                                transform=test_transform)

        dataloader = DataLoader(dataset,
                                batch_size=self.config.training.batch_size,
                                shuffle=True,
                                num_workers=4)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.config.training.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 drop_last=True)

        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size**2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        score = CondRefineNetDilated(self.config).to(self.config.device)

        score = torch.nn.DataParallel(score)

        optimizer = self.get_optimizer(score.parameters())

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            optimizer.load_state_dict(states[1])

        step = 0

        sigmas = torch.tensor(
            np.exp(
                np.linspace(np.log(self.config.model.sigma_begin),
                            np.log(self.config.model.sigma_end),
                            self.config.model.num_classes))).float().to(
                                self.config.device)

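        # Per-step wall-clock times, reported as mean +/- std (in ms) every 100 steps.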
        time_record = []
        for epoch in range(self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                step += 1
                score.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.

                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                labels = torch.randint(0,
                                       len(sigmas), (X.shape[0], ),
                                       device=X.device)
                if self.config.training.algo == 'dsm':
                    t = time.time()
                    loss = anneal_dsm_score_estimation(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'dsm_tracetrick':
                    t = time.time()
                    loss = anneal_dsm_score_estimation_TraceTrick(
                        score, X, labels, sigmas,
                        self.config.training.anneal_power)
                elif self.config.training.algo == 'ssm':
                    t = time.time()
                    loss = anneal_sliced_score_estimation_vr(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet':
                    t = time.time()
                    loss = anneal_ESM_scorenet(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)
                elif self.config.training.algo == 'esm_scorenet_VR':
                    t = time.time()
                    loss = anneal_ESM_scorenet_VR(
                        score,
                        X,
                        labels,
                        sigmas,
                        n_particles=self.config.training.n_particles)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t = time.time() - t
                time_record.append(t)

                if step >= self.config.training.n_iters:
                    return 0

                if step % 100 == 0:
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    logging.info(
                        "step: {}, loss: {}, time per step: {:.3f} +- {:.3f} ms"
                        .format(step, loss.item(),
                                np.mean(time_record) * 1e3,
                                np.std(time_record) * 1e3))

                    # if step % 2000 == 0:
                    #     score.eval()
                    #     try:
                    #         test_X, test_y = next(test_iter)
                    #     except StopIteration:
                    #         test_iter = iter(test_loader)
                    #         test_X, test_y = next(test_iter)

                    #     test_X = test_X.to(self.config.device)
                    #     test_X = test_X / 256. * 255. + torch.rand_like(test_X) / 256.

                    #     if self.config.data.logit_transform:
                    #         test_X = self.logit_transform(test_X)

                    #     test_labels = torch.randint(0, len(sigmas), (test_X.shape[0],), device=test_X.device)

                    #     #if self.config.training.algo == 'dsm':
                    #     with torch.no_grad():
                    #         test_dsm_loss = anneal_dsm_score_estimation(score, test_X, test_labels, sigmas,
                    #                                                         self.config.training.anneal_power)

                    #     tb_logger.add_scalar('test_dsm_loss', test_dsm_loss, global_step=step)
                    #     logging.info("step: {}, test dsm loss: {}".format(step, test_dsm_loss.item()))

                    # elif self.config.training.algo == 'ssm':
                    #     test_ssm_loss = anneal_sliced_score_estimation_vr(score, test_X, test_labels, sigmas,
                    #                                          n_particles=self.config.training.n_particles)

                    #     tb_logger.add_scalar('test_ssm_loss', test_ssm_loss, global_step=step)
                    #     logging.info("step: {}, test ssm loss: {}".format(step, test_ssm_loss.item()))

                if step >= 140000 and step % self.config.training.snapshot_freq == 0:
                    states = [
                        score.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(
                        states,
                        os.path.join(self.args.log,
                                     'checkpoint_{}.pth'.format(step)))
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint.pth'))
Code example #7
    def run(self):

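        # Main training entry point: trains the score network with annealed DSM and,
        # when self.args.adversarial is set, alternates GAN-style discriminator updates
        # with score network updates; periodically checkpoints and draws snapshot samples.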
        self.config.input_dim = self.config.data.channels * self.config.data.image_size ** 2

        dataloader, testloader = self.get_dataloader(bs=self.config.training.batch_size, training=True)
        sample_loader = self.get_dataloader(bs=36)

        sigmas = self.get_sigmas(training=True)
        sampling_sigmas = self.get_sigmas(npy=True)
        score = self.get_model()
        optimizer = self.get_optimizer(score.parameters())

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(score)
        else:
            ema_helper = DummyEMA()

        if self.args.adversarial:
            D = self._setup_discriminator()
            optimizerD = self.get_optimizer(D.parameters(), adv=True)
            D_loss_function, G_loss_function = adv_loss(self.config.adversarial.adv_loss, self.config.training, self.args.device)
        else:
            D = optimizerD = D_loss_function = G_loss_function = None

        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'))
            score.load_state_dict(states[0])
            ### Make sure we can resume with different eps
            states[1]['param_groups'][0]['eps'] = self.config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            ema_helper.load_state_dict(states[4])
            if self.args.adversarial:
                D.load_state_dict(states[5])
        else:
            start_epoch = 0
            step = 0

        hook = Hook(self.args.tb_logger, len(sigmas)) if self.config.training.log_all_sigmas else None
        testlogger = Logger(self.args, self.config, sigmas, testloader, ema_helper)

        kwargs = {'sigmas': sampling_sigmas, 'final_only': True, 'nsigma': self.config.sampling.nsigma,
                  'step_lr': self.config.sampling.step_lr, 'target': self.args.target, 'noise_first': self.config.sampling.noise_first}

        # Estimate maximum sigma that we should be using
        if self.args.compute_approximate_sigma_max:
            with torch.no_grad():
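                # Track the largest pairwise Euclidean distance observed within any
                # training batch; the running maximum serves as the sigma_max estimate.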
                current_max_dist = 0
                for i, (X, y) in enumerate(dataloader):

                    X = X.to(self.args.device)
                    X = data_transform(self.config.data, X)
                    X_ = X.view(X.shape[0], -1)
                    max_dist = torch.cdist(X_, X_).max().item()

                    if current_max_dist < max_dist:
                        current_max_dist = max_dist
                    print(current_max_dist)
                print('Final, max Euclidean distance: {}'.format(current_max_dist))
                return ()

        D_step = 0
        for epoch in range(start_epoch, self.config.training.n_epochs):
            for i, (X, y) in enumerate(dataloader):
                score.train()

                X = X.to(self.args.device)
                X = data_transform(self.config.data, X)

                ##### Discriminator steps #####
                if self.args.adversarial:
                    D_step += 1
                    """ GAN Discriminator update (at every 'scorenetwork' update)"""
                    loss_D = self.update_adversarial_discriminator(X, y, sigmas, score, D, optimizerD, D_loss_function)

                ##### Score Network step #####
                if not self.args.adversarial or D_step >= self.config.adversarial.D_steps:  # Only update Score network if Discriminator has done all its steps
                    D_step = 0  # We reset the discriminator counter
                    step += 1

                    """ Score network update """
                    if hook is not None:
                        hook.update_step(step)
                    loss_dae, fake_denoised_X, scores_ = anneal_dsm_score_estimation(self.args, score, X, sigmas, hook=hook)

                    if self.args.adversarial:
                        # loss_adv measures how "real" fake_denoised_X looks to the discriminator:
                        # a high value means the discriminator easily identified the samples as fake.
                        if self.config.adversarial.adv_clamp:
                            fake_denoised_X_ = fake_denoised_X.clamp(0, 1)
                        else:
                            fake_denoised_X_ = fake_denoised_X
                        y_pred = D(X)
                        y_pred_fake = D(fake_denoised_X_)
                        loss_adv = self.config.adversarial.lambda_G_gan * G_loss_function(y_pred, y_pred_fake)
                        loss = self.config.adversarial.lambda_dae * loss_dae + loss_adv
                        _losses = [loss_D.item(), loss_dae.item(), loss_adv.item()]
                    else:
                        loss = loss_dae
                        _losses = [loss.item()]

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    ema_helper.update(score)
                    testlogger.addentry(step, score, _losses)

                    _force_end = step == self.config.training.n_iters
                    if step % self.config.training.snapshot_freq == 0 or _force_end:
                        states = [score.state_dict(), optimizer.state_dict(), epoch, step, ema_helper.state_dict()]

                        torch.save(states, os.path.join(self.args.log_path, 'checkpoint_{}.pth'.format(step)))
                        torch.save(states, os.path.join(self.args.log_path, 'checkpoint.pth'))
                        if _force_end:
                            return 1

                    if step % self.config.training.snapshot_sampling_freq == 0 and self.config.training.snapshot_sampling:
                        test_score = ema_helper.ema_copy(score, self.args.device)
                        test_score.eval()
                        kwargs['scorenet'] = test_score
                        self.sample(sample_loader, saveimages=True, kwargs=kwargs, gridsize=36, bs=36, ckpt_id=step)