def train(self, conditional=True):
        if conditional:
            print('USING CONDITIONAL DSM')

        if self.config.data.random_flip is False:
            tran_transform = test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])
        else:
            tran_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor()
            ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                              transform=tran_transform)

        elif self.config.data.dataset == 'MNIST':
            print('RUNNING REDUCED MNIST')
            dataset = MNIST(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                            transform=tran_transform)

        elif self.config.data.dataset == 'FashionMNIST':
            dataset = FashionMNIST(os.path.join(self.args.run, 'datasets'), train=True, download=True,
                                   transform=tran_transform)

        elif self.config.data.dataset == 'MNIST_transferBaseline':
            # use the same dataset as transfer_nets.py; the train split would work equally well
            # since these digits were unseen during pretraining
            dataset = MNIST(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                            transform=test_transform)
            print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))

        elif self.config.data.dataset == 'CIFAR10_transferBaseline':
            # use the same dataset as transfer_nets.py; the train split would work equally well
            # since these classes were unseen during pretraining
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                              transform=test_transform)
            print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))

        elif self.config.data.dataset == 'FashionMNIST_transferBaseline':
            # use the same dataset as transfer_nets.py; the train split would work equally well
            # since these classes were unseen during pretraining
            dataset = FashionMNIST(os.path.join(self.args.run, 'datasets'), train=False, download=True,
                                   transform=test_transform)
            print('TRANSFER BASELINES !! Subset size: ' + str(self.subsetSize))

        else:
            raise ValueError('Unknown config dataset {}'.format(self.config.data.dataset))

        # apply collation
        if self.config.data.dataset in ['MNIST', 'CIFAR10', 'FashionMNIST']:
            collate_helper = lambda batch: my_collate(batch, nSeg=self.nSeg)
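            # Note (assumption): my_collate / my_collate_rev are defined elsewhere in the repo; they
            # presumably filter each batch to the first nSeg classes (my_collate) or to the held-out
            # classes used by the transfer baselines (my_collate_rev).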
            print('Subset size: ' + str(self.subsetSize))
            id_range = list(range(self.subsetSize))
            dataset = torch.utils.data.Subset(dataset, id_range)
            dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=0,
                                    collate_fn=collate_helper)

        elif self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline',
                                          'FashionMNIST_transferBaseline']:
            # trains a model on only digits 8,9 from scratch
            print('Subset size: ' + str(self.subsetSize))
            id_range = list(range(self.subsetSize))
            dataset = torch.utils.data.Subset(dataset, id_range)
            dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=0,
                                    drop_last=True, collate_fn=my_collate_rev)
            print('loaded reduced subset')
        else:
            raise ValueError('Unknown config dataset {}'.format(self.config.data.dataset))

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        # define the g network
        energy_net_finalLayer = torch.ones((self.config.data.image_size * self.config.data.image_size, self.nSeg)).to(
            self.config.device)
        energy_net_finalLayer.requires_grad_()
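        # energy_net_finalLayer is a learnable matrix of shape (image_size * image_size, nSeg) that
        # plays the role of the g readout described above; it is optimized jointly with f (it is
        # appended to the optimizer's parameter list below) and passed to conditional_dsm, presumably
        # to couple the per-pixel features with the nSeg labels.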

        # define the f network
        enet = RefineNetDilated(self.config).to(self.config.device)
        enet = torch.nn.DataParallel(enet)

        # training
        optimizer = self.get_optimizer(list(enet.parameters()) + [energy_net_finalLayer])
        step = 0
        loss_track_epochs = []
        for epoch in range(self.config.training.n_epochs):
            loss_vals = []
            for i, (X, y) in enumerate(dataloader):
                step += 1

                enet.train()
                X = X.to(self.config.device)
                X = X / 256. * 255. + torch.rand_like(X) / 256.  # uniform dequantization: scale to [0, 255/256] and add U(0, 1/256) noise
                if self.config.data.logit_transform:
                    X = self.logit_transform(X)

                y -= y.min()  # shift labels so they start at zero
                if conditional:
                    loss = conditional_dsm(enet, X, y, energy_net_finalLayer, sigma=0.01)
                else:
                    loss = dsm(enet, X, sigma=0.01)
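                # Both branches are denoising score matching: X is perturbed with Gaussian noise of
                # scale sigma and the network is trained to match the score of the noisy distribution,
                # roughly E || s(X + sigma * eps) + eps / sigma ||^2 (up to scaling); conditional_dsm
                # presumably additionally conditions the energy on y via energy_net_finalLayer.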

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                logging.info("step: {}, loss: {}, maxLabel: {}".format(step, loss.item(), y.max()))
                loss_vals.append(loss.item())
                loss_track_epochs.append(loss.item())

                if step >= self.config.training.n_iters:
                    # save final checkpoints for distribution!
                    states = [
                        enet.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint.pth'))
                    torch.save([energy_net_finalLayer], os.path.join(self.args.checkpoints, 'finalLayerweights_.pth'))
                    pickle.dump(energy_net_finalLayer,
                                open(os.path.join(self.args.checkpoints, 'finalLayerweights.p'), 'wb'))
                    return 0

                if step % self.config.training.snapshot_freq == 0:
                    print('checkpoint at step: {}'.format(step))
                    # save checkpoint for transfer learning!
                    torch.save([energy_net_finalLayer], os.path.join(self.args.log, 'finalLayerweights_.pth'))
                    pickle.dump(energy_net_finalLayer,
                                open(os.path.join(self.args.log, 'finalLayerweights.p'), 'wb'))
                    states = [
                        enet.state_dict(),
                        optimizer.state_dict(),
                    ]
                    torch.save(states, os.path.join(self.args.log, 'checkpoint_{}.pth'.format(step)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))

            if self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline',
                                            'FashionMNIST_transferBaseline']:
                # save loss track during epoch for transfer baseline
                pickle.dump(loss_vals,
                            open(os.path.join(self.args.run, self.args.dataset + '_Baseline_Size' + str(
                                self.subsetSize) + "_Seed" + str(self.seed) + '.p'), 'wb'))

        if self.config.data.dataset in ['MNIST_transferBaseline', 'CIFAR10_transferBaseline',
                                        'FashionMNIST_transferBaseline']:
            # save loss track during epoch for transfer baseline
            pickle.dump(loss_track_epochs,
                        open(os.path.join(self.args.run, self.args.dataset + '_Baseline_epochs_Size' + str(
                            self.subsetSize) + "_Seed" + str(self.seed) + '.p'), 'wb'))

        # save final checkpoints for distribution!
        states = [
            enet.state_dict(),
            optimizer.state_dict(),
        ]
        torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint_{}.pth'.format(step)))
        torch.save(states, os.path.join(self.args.checkpoints, 'checkpoint.pth'))
        torch.save([energy_net_finalLayer], os.path.join(self.args.checkpoints, 'finalLayerweights_.pth'))
        pickle.dump(energy_net_finalLayer,
                    open(os.path.join(self.args.checkpoints, 'finalLayerweights.p'), 'wb'))
    def finalize(
        self, dkef, tb_logger, train_data, val_data, test_data, collate_fn, train_mode
    ):
        lambda_params = [
            param for (name, param) in dkef.named_parameters() if "lambd" in name
        ]
        optimizer = optim.Adam(lambda_params, lr=0.001)
        batch_size = self.config.training.fval_batch_size
        val_loader = DataLoader(
            val_data,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=collate_fn,
        )
        test_loader = DataLoader(
            test_data,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=collate_fn,
        )
        dkef.save_alpha_matrices(
            train_data, collate_fn, self.config.device, override=True
        )

        def energy_net(inputs):
            return -dkef(None, inputs, stage="finalize")
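        # Finalize stage: only the "lambd" parameters collected above are updated, for at most 1000
        # steps on validation batches, using whichever score matching variant was used for training;
        # exact score matching on the validation and test sets is then reported below.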

        step = 0
        while step < 1000:
            for val_batch in val_loader:
                if step >= 1000:
                    break
                val_batch = val_batch.to(self.config.device)

                if train_mode == "exact":
                    val_loss = exact_score_matching(energy_net, val_batch, train=True)
                elif train_mode == "sliced":
                    val_loss, _, _ = single_sliced_score_matching(energy_net, val_batch)
                elif train_mode == "sliced_fd":
                    val_loss = efficient_score_matching_conjugate(energy_net, val_batch)
                elif train_mode == "sliced_VR":
                    val_loss, _, _ = sliced_VR_score_matching(energy_net, val_batch)
                elif train_mode == "dsm":
                    val_loss = dsm(energy_net, val_batch, sigma=self.dsm_sigma)
                elif train_mode == "dsm_fd":
                    val_loss = dsm_fd(energy_net, val_batch, sigma=self.dsm_sigma)
                elif train_mode == "kingma":
                    logp, grad1, grad2 = dkef.approx_bp_forward(
                        None, val_batch, stage="finalize", mode=train_mode
                    )
                    val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
                elif train_mode == "CP":
                    logp, grad1, S_r, S_i = dkef.approx_bp_forward(
                        None, val_batch, stage="finalize", mode=train_mode
                    )
                    grad2 = S_r ** 2 - S_i ** 2
                    val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
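                # For the 'kingma' and 'CP' modes, grad1 and grad2 appear to be the first derivative
                # of the log-density and an estimate of its diagonal second derivative, so the loss
                # is the Hyvarinen score matching objective 0.5 * ||grad1||^2 + sum(grad2), computed
                # per sample before the .mean() below.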

                val_loss = val_loss.mean()
                optimizer.zero_grad()
                val_loss.backward()
                tn = nn.utils.clip_grad_norm_(dkef.parameters(), 1.0)
                optimizer.step()
                logging.info("Val loss: {:.3f}".format(val_loss))
                tb_logger.add_scalar("finalize/loss", val_loss, global_step=step)
                step += 1

        val_losses = []
        for data_v in val_loader:
            data_v = data_v.to(self.config.device)
            batch_val_loss = exact_score_matching(energy_net, data_v, train=False)
            val_losses.append(batch_val_loss.mean())
        val_loss = sum(val_losses) / len(val_losses)
        logging.info("Overall val exact score matching: {:.3f}".format(val_loss))
        tb_logger.add_scalar("finalize/final_valid_score", val_loss, global_step=0)
        self.results["final_valid_score"] = np.asscalar(val_loss.cpu().numpy())

        test_losses = []
        for data_t in test_loader:
            data_t = data_t.to(self.config.device)
            batch_test_loss = exact_score_matching(energy_net, data_t, train=False)
            test_losses.append(batch_test_loss.mean())
        test_loss = sum(test_losses) / len(test_losses)
        logging.info("Overall test exact score matching: {:.3f}".format(test_loss))
        tb_logger.add_scalar("finalize/final_test_score", test_loss, global_step=0)
        self.results["final_test_score"] = np.asscalar(test_loss.cpu().numpy())
    def train_stage1(
        self, dkef, tb_logger, train_data, val_data, collate_fn, train_mode
    ):
        optimizer = self.get_optimizer(dkef.parameters())

        step = 0
        num_mb = len(train_data) // self.config.training.batch_size
        split_size = self.config.training.batch_size // 2
        best_val_step = 0
        best_val_loss = 1e5
        best_model = None
        train_losses = np.zeros(30)
        val_loss_window = np.zeros(15)
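        # Rolling buffers: train_losses keeps the last 30 training losses for the smoothed loss
        # logged below; val_loss_window keeps the last 15 validation losses (its smoothed value is
        # computed but early stopping is driven by the raw best validation loss and the patience
        # setting).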
        torch.cuda.synchronize()
        prev_time = time.time()

        val_batch_size = len(val_data)
        num_val_iters = 1
        val_loader = DataLoader(
            val_data,
            batch_size=val_batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=collate_fn,
        )
        train_loader = DataLoader(
            train_data,
            batch_size=split_size,
            shuffle=True,
            num_workers=2,
            collate_fn=collate_fn,
        )
        train_iter = iter(train_loader)
        val_iter = iter(val_loader)
        total_time = 0.0
        time_dur = 0.0
        secs_per_it = []
        for _ in range(self.config.training.n_epochs):
            for _ in range(num_mb):
                train_iter, X_t = self.sample(train_iter, train_loader)
                train_iter, X_v = self.sample(train_iter, train_loader)

                start_point = time.time()

                def energy_net(inputs):
                    return -dkef(X_t, inputs)

                if train_mode == "exact":
                    train_loss = exact_score_matching(energy_net, X_v, train=True)
                elif train_mode == "sliced":
                    train_loss, _, _ = single_sliced_score_matching(energy_net, X_v)
                elif train_mode == "sliced_fd":
                    train_loss = efficient_score_matching_conjugate(energy_net, X_v)
                elif train_mode == "sliced_VR":
                    train_loss, _, _ = sliced_VR_score_matching(energy_net, X_v)
                elif train_mode == "dsm":
                    train_loss = dsm(energy_net, X_v, sigma=self.dsm_sigma)
                elif train_mode == "dsm_fd":
                    train_loss = dsm_fd(energy_net, X_v, sigma=self.dsm_sigma)
                elif train_mode == "kingma":
                    logp, grad1, grad2 = dkef.approx_bp_forward(
                        X_t, X_v, stage="train", mode=train_mode
                    )
                    train_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
                elif train_mode == "CP":
                    logp, grad1, S_r, S_i = dkef.approx_bp_forward(
                        X_t, X_v, stage="train", mode=train_mode
                    )
                    grad2 = S_r ** 2 - S_i ** 2
                    train_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)

                train_loss = train_loss.mean()
                optimizer.zero_grad()
                train_loss.backward()
                train_losses[step % 30] = train_loss.detach()

                # The reference implementation clips the overall gradient norm at 100; here it is clipped at 1.0.
                tn = nn.utils.clip_grad_norm_(dkef.parameters(), 1.0)
                optimizer.step()
                time_dur += time.time() - start_point

                idx = np.random.choice(len(train_data), 1000, replace=False)
                train_data_for_val = torch.utils.data.Subset(train_data, idx)
                dkef.save_alpha_matrices(
                    train_data_for_val, collate_fn, self.config.device
                )
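                # Presumably the alpha matrices (the DKEF's closed-form kernel expansion weights) are
                # refit on a random 1000-example subset of the training data so that the validation
                # energy below can be evaluated without passing the full training set at every step.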

                # Compute validation loss
                def energy_net_val(inputs):
                    return -dkef(None, inputs, stage="eval")

                val_losses = []
                for val_step in range(num_val_iters):
                    val_iter, data_v = self.sample(val_iter, val_loader)
                    if train_mode == "exact":
                        batch_val_loss = exact_score_matching(
                            energy_net_val, data_v, train=False
                        )
                    elif train_mode == "sliced":
                        batch_val_loss, _, _ = single_sliced_score_matching(
                            energy_net_val, data_v, detach=True
                        )
                    elif train_mode == "sliced_fd":
                        batch_val_loss = efficient_score_matching_conjugate(
                            energy_net_val, data_v, detach=True
                        )
                    elif train_mode == "sliced_VR":
                        batch_val_loss, _, _ = sliced_VR_score_matching(
                            energy_net_val, data_v, detach=True
                        )
                    elif train_mode == "dsm":
                        batch_val_loss = dsm(
                            energy_net_val, data_v, sigma=self.dsm_sigma
                        )
                    elif train_mode == "dsm_fd":
                        batch_val_loss = dsm_fd(
                            energy_net_val, data_v, sigma=self.dsm_sigma
                        )
                    elif train_mode == "kingma":
                        logp, grad1, grad2 = dkef.approx_bp_forward(
                            None, X_v, stage="eval", mode=train_mode
                        )
                        batch_val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)
                    elif train_mode == "CP":
                        logp, grad1, S_r, S_i = dkef.approx_bp_forward(
                            None, X_v, stage="eval", mode=train_mode
                        )
                        grad2 = S_r ** 2 - S_i ** 2
                        batch_val_loss = (0.5 * grad1 ** 2).sum(1) + grad2.sum(1)

                    val_losses.append(batch_val_loss.mean())

                val_loss = sum(val_losses) / len(val_losses)
                val_loss_window[step % 15] = val_loss.detach()
                smoothed_val_loss = (
                    val_loss_window[: step + 1].mean()
                    if step < 15
                    else val_loss_window.mean()
                )

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_step = step
                    best_model = copy.deepcopy(dkef.state_dict())
                elif step - best_val_step > self.config.training.patience:
                    self.results["secs_per_it"] = sum(secs_per_it) / len(secs_per_it)
                    self.results["its_per_sec"] = 1.0 / self.results["secs_per_it"]
                    logging.info(
                        "Validation loss has not improved in {} steps. Finalizing model!".format(
                            self.config.training.patience
                        )
                    )
                    return best_model

                mean_train_loss = (
                    train_losses[: step + 1].mean()
                    if step < 30
                    else train_losses.mean()
                )
                logging.info(
                    "Step {}, Training loss: {:.2f}, validation loss: {:.2f}".format(
                        step, mean_train_loss, best_val_loss
                    )
                )
                tb_logger.add_scalar(
                    "train/train_loss_smoothed", mean_train_loss, global_step=step
                )
                tb_logger.add_scalar(
                    "train/best_val_loss", best_val_loss, global_step=step
                )
                tb_logger.add_scalar("train/train_loss", train_loss, global_step=step)
                tb_logger.add_scalar("train/val_loss", val_loss, global_step=step)

                if step % 20 == 0:
                    torch.cuda.synchronize()
                    new_time = time.time()
                    logging.info("#" * 80)
                    if step > 0:
                        secs_per_it.append((new_time - prev_time) / 20.0)
                    logging.info(
                        "Iterations per second: {:.3f}".format(
                            20.0 / (new_time - prev_time)
                        )
                    )
                    logging.info("Only Training Time: {:.3f}".format(time_dur))
                    time_dur = 0.0
                    tb_logger.add_scalar(
                        "train/its_per_sec",
                        20.0 / (new_time - prev_time),
                        global_step=step,
                    )

                    if step > 0:
                        total_time += new_time - prev_time

                    val_losses_exact = []
                    for val_step in range(num_val_iters):
                        val_iter, data_v = self.sample(val_iter, val_loader)
                        vle = exact_score_matching(energy_net_val, data_v, train=False)
                        val_losses_exact.append(vle.mean())

                    val_loss_exact = sum(val_losses_exact) / len(val_losses_exact)
                    logging.info(
                        "Exact score matching loss on val: {:.2f}".format(
                            val_loss_exact.mean()
                        )
                    )
                    tb_logger.add_scalar(
                        "eval/exact_score_matching",
                        val_loss_exact.mean(),
                        global_step=step,
                    )

                    logging.info("#" * 80)
                    torch.cuda.synchronize()
                    prev_time = time.time()
                step += 1

        logging.info("Completed training")
        self.results["secs_per_it"] = sum(secs_per_it) / len(secs_per_it)
        self.results["its_per_sec"] = 1.0 / self.results["secs_per_it"]

        return best_model
    def train(self):
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=transform)
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        else:
            raise ValueError('Unknown config dataset {}'.format(self.config.data.dataset))

        # hold out 10% of the training set for validation; the split uses a fixed seed (2019) so it
        # is reproducible, and the global numpy RNG state is restored afterwards
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, val_indices = indices[:int(num_items * 0.9)], indices[int(num_items * 0.9):]
        val_dataset = Subset(dataset, val_indices)
        dataset = Subset(dataset, train_indices)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=2)

        val_iter = iter(val_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)
        model_path = os.path.join(self.args.run, 'results', self.args.doc)
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
        os.makedirs(model_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

        flow = NICE(self.config.input_dim, self.config.model.hidden_size, self.config.model.num_layers).to(
            self.config.device)

        optimizer = self.get_optimizer(flow.parameters())

        # noise level for optional input perturbation, and the global step counter
        noise_sigma = self.config.data.noise_sigma
        step = 0

        def energy_net(inputs):
            energy, _ = flow(inputs, inv=False)
            return -energy

        def grad_net_kingma(inputs):
            energy, _ = flow(inputs, inv=False)
            grad1, grad2 = flow.grads_backward(inv=False)
            return -grad1, -grad2

        def grad_net_UT(inputs):
            energy, _ = flow(inputs, inv=False)
            grad1, T, U = flow.grads_backward_TU(inv=False)
            grad2 = T * U / 2.
            return -grad1, -grad2

        def grad_net_S(inputs):
            energy, _ = flow(inputs, inv=False)
            grad1, S_r, S_i = flow.grads_backward_S(inv=False)
            grad2 = (S_r ** 2 - S_i ** 2)
            return -grad1, -grad2
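        # The three grad_net_* closures return (first-derivative, second-derivative-estimate) pairs
        # consumed by approx_backprop_score_matching; 'kingma', 'UT' (T * U / 2) and 'S'
        # (S_r**2 - S_i**2) appear to be different estimators of the second-derivative term obtained
        # from quantities the flow exposes during the backward pass.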

        def sample_net(z):
            samples, _ = flow(z, inv=True)
            samples, _ = Logit()(samples, mode='inverse')
            return samples
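        # Sampling path: draw z ~ N(0, I) in latent space, run the NICE flow in the inverse
        # direction, then undo the logit preprocessing so samples land back in [0, 1] pixel space
        # (used for the TensorBoard image grids logged every 100 steps below).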

        # Use this to select the sigma for DSM losses
        if self.config.training.algo == 'dsm':
            sigma = self.args.dsm_sigma
            # if noise_sigma is None:
            #     sigma = select_sigma(iter(dataloader), iter(val_loader))
            # else:
            #     sigma = select_sigma(iter(dataloader), iter(val_loader), noise_sigma=noise_sigma)

        if self.args.load_path != "":
            flow.load_state_dict(torch.load(self.args.load_path))

        best_model = {"val": None, "ll": None, "esm": None}
        best_val_loss = {"val": 1e+10, "ll": -1e+10, "esm": 1e+10}
        best_val_iter = {"val": 0, "ll": 0, "esm": 0}

        for _ in range(self.config.training.n_epochs):
            for _, (X, y) in enumerate(dataloader):
                X = X + (torch.rand_like(X) - 0.5) / 256.
                flattened_X = X.type(torch.float32).to(self.config.device).view(X.shape[0], -1)
                flattened_X.clamp_(1e-3, 1-1e-3)
                flattened_X, _ = Logit()(flattened_X, mode='direct')
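                # Preprocessing: uniform dequantization above, clamping away from 0/1 to avoid the
                # logit's singularities, then a logit transform so the flow models data with
                # unbounded support; optional Gaussian noise of scale noise_sigma is added next.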

                if noise_sigma is not None:
                    flattened_X += torch.randn_like(flattened_X) * noise_sigma

                flattened_X.requires_grad_(True)

                logp = -energy_net(flattened_X)

                logp = logp.mean()

                if self.config.training.algo == 'kingma':
                    loss = approx_backprop_score_matching(grad_net_kingma, flattened_X)
                elif self.config.training.algo == 'UT':
                    loss = approx_backprop_score_matching(grad_net_UT, flattened_X)
                elif self.config.training.algo == 'S':
                    loss = approx_backprop_score_matching(grad_net_S, flattened_X)
                elif self.config.training.algo == 'mle':
                    loss = -logp
                elif self.config.training.algo == 'ssm':
                    loss, *_ = single_sliced_score_matching(energy_net, flattened_X, noise_type=self.config.training.noise_type)
                elif self.config.training.algo == 'ssm_vr':
                    loss, *_ = sliced_VR_score_matching(energy_net, flattened_X, noise_type=self.config.training.noise_type)
                elif self.config.training.algo == 'dsm':
                    loss = dsm(energy_net, flattened_X, sigma=sigma)
                elif self.config.training.algo == "exact":
                    loss = exact_score_matching(energy_net, flattened_X, train=True).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()


                if step % 10 == 0:
                    try:
                        val_X, _ = next(val_iter)
                    except StopIteration:
                        val_iter = iter(val_loader)
                        val_X, _ = next(val_iter)

                    val_X = val_X + (torch.rand_like(val_X) - 0.5) / 256.
                    val_X = val_X.type(torch.float32).to(self.config.device)
                    val_X.clamp_(1e-3, 1-1e-3)
                    val_X, _ = Logit()(val_X, mode='direct')
                    val_X = val_X.view(val_X.shape[0], -1)
                    if noise_sigma is not None:
                        val_X += torch.randn_like(val_X) * noise_sigma

                    val_logp = -energy_net(val_X).mean()
                    if self.config.training.algo == 'kingma':
                        val_loss = approx_backprop_score_matching(grad_net_kingma, val_X)
                    elif self.config.training.algo == 'UT':
                        val_loss = approx_backprop_score_matching(grad_net_UT, val_X)
                    elif self.config.training.algo == 'S':
                        val_loss = approx_backprop_score_matching(grad_net_S, val_X)
                    elif self.config.training.algo == 'ssm':
                        val_loss, *_ = single_sliced_score_matching(energy_net, val_X, noise_type=self.config.training.noise_type)
                    elif self.config.training.algo == 'ssm_vr':
                        val_loss, *_ = sliced_VR_score_matching(energy_net, val_X, noise_type=self.config.training.noise_type)
                    elif self.config.training.algo == 'dsm':
                        val_loss = dsm(energy_net, val_X, sigma=sigma)
                    elif self.config.training.algo == 'mle':
                        val_loss = -val_logp
                    elif self.config.training.algo == "exact":
                        val_loss = exact_score_matching(energy_net, val_X, train=False).mean()

                    logging.info("logp: {:.3f}, val_logp: {:.3f}, loss: {:.3f}, val_loss: {:.3f}".format(logp.item(),
                                                                                           val_logp.item(),
                                                                                           loss.item(),
                                                                                           val_loss.item()))
                    tb_logger.add_scalar('logp', logp, global_step=step)
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    tb_logger.add_scalar('val_logp', val_logp, global_step=step)
                    tb_logger.add_scalar('val_loss', val_loss, global_step=step)

                    if val_loss < best_val_loss['val']:
                        best_val_loss['val'] = val_loss
                        best_val_iter['val'] = step
                        best_model['val'] = copy.deepcopy(flow.state_dict())
                    if val_logp > best_val_loss['ll']:
                        best_val_loss['ll'] = val_logp
                        best_val_iter['ll'] = step
                        best_model['ll'] = copy.deepcopy(flow.state_dict())

                if step % 100 == 0:
                    with torch.no_grad():
                        z = torch.normal(torch.zeros(100, flattened_X.shape[1], device=self.config.device))
                        samples = sample_net(z)
                        samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                               self.config.data.image_size)
                        samples = torch.clamp(samples, 0.0, 1.0)
                        image_grid = make_grid(samples, 10)
                        tb_logger.add_image('samples', image_grid, global_step=step)
                        data = X
                        data_grid = make_grid(data[:100], 10)
                        tb_logger.add_image('data', data_grid, global_step=step)

                    logging.info("Computing exact score matching....")
                    try:
                        val_X, _ = next(val_iter)
                    except StopIteration:
                        val_iter = iter(val_loader)
                        val_X, _ = next(val_iter)

                    val_X = val_X + (torch.rand_like(val_X) - 0.5) / 256.
                    val_X = val_X.type(torch.float32).to(self.config.device)
                    val_X.clamp_(1e-3, 1-1e-3)
                    val_X, _ = Logit()(val_X, mode='direct')
                    val_X = val_X.view(val_X.shape[0], -1)
                    if noise_sigma is not None:
                        val_X += torch.randn_like(val_X) * noise_sigma

                    sm_loss = exact_score_matching(energy_net, val_X, train=False).mean()
                    if sm_loss < best_val_loss['esm']:
                        best_val_loss['esm'] = sm_loss
                        best_val_iter['esm'] = step
                        best_model['esm'] = copy.deepcopy(flow.state_dict())

                    logging.info('step: {}, exact score matching loss: {}'.format(step, sm_loss.item()))
                    tb_logger.add_scalar('exact_score_matching_loss', sm_loss, global_step=step)

                if step % 500 == 0:
                    torch.save(flow.state_dict(), os.path.join(model_path, 'nice.pth'))

                step += 1

        self.results = {}
        self.evaluate_model(flow.state_dict(), "final", val_loader, test_loader, model_path)
        self.evaluate_model(best_model['val'], "best_on_val", val_loader, test_loader, model_path)
        self.evaluate_model(best_model['ll'], "best_on_ll", val_loader, test_loader, model_path)
        self.evaluate_model(best_model['esm'], "best_on_esm", val_loader, test_loader, model_path)
        self.results['final']['num_iters'] = step
        self.results['best_on_val']['num_iters'] = best_val_iter['val']
        self.results['best_on_ll']['num_iters'] = best_val_iter['ll']
        self.results['best_on_esm']['num_iters'] = best_val_iter['esm']

        pickle_out = open(model_path + "/results.pkl", "wb")
        pickle.dump(self.results, pickle_out)
        pickle_out.close()
def transfer(args, config):
    """
    Once an ICE-BeeM has been pretrained on some labels (0-7), we train only the secondary network
    (g in our manuscript) on the unseen labels 8-9, which form new datasets.
    """
    conditional = args.subset_size != 0
    # load data
    dataloader, dataset, cond_size = get_dataset(args,
                                                 config,
                                                 test=False,
                                                 rev=True,
                                                 one_hot=True,
                                                 subset=True)
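    # Note (assumption): the rev=True / subset=True flags presumably restrict the loader to the
    # held-out labels (8-9 in the docstring) that the pretrained f network never saw; cond_size is
    # the number of conditioning classes used to size the new g network below.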
    # load the feature network f
    ckpt_path = os.path.join(args.checkpoints, 'checkpoint.pth')
    print('loading weights from: {}'.format(ckpt_path))
    states = torch.load(ckpt_path, map_location=config.device)
    f = feature_net(config).to(config.device)
    f.load_state_dict(states[0])
    if conditional:
        # define the feature network g
        g = SimpleLinear(cond_size, f.output_size,
                         bias=False).to(config.device)
        energy_net = ModularUnnormalizedConditionalEBM(
            f, g, augment=config.model.augment, positive=config.model.positive)
        # define the optimizer
        parameters = energy_net.g.parameters()
        optimizer = get_optimizer(config, parameters)
    else:
        # no learning is involved: just evaluate f on the new labels, with g = 1
        energy_net = ModularUnnormalizedEBM(f)
        optimizer = None
    # start optimizing!
    eCount = 10
    loss_track_epochs = []
    for epoch in range(eCount):
        print('epoch: ' + str(epoch))
        loss_track = []
        for i, (X, y) in enumerate(dataloader):
            X = X.to(config.device)
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if conditional:
                loss = cdsm(energy_net, X, y, sigma=0.01)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                # just evaluate the DSM loss using the pretrained f --- no learning
                loss = dsm(energy_net, X, sigma=0.01)
                loss.backward()  # strangely, without this line, the script requires twice as much GPU memory
            loss_track.append(loss.item())
            loss_track_epochs.append(loss.item())

        pickle.dump(
            loss_track,
            open(
                os.path.join(
                    args.output,
                    'size{}_seed{}.p'.format(args.subset_size, args.seed)),
                'wb'))
    print('saving loss track under: {}'.format(args.output))
    pickle.dump(
        loss_track_epochs,
        open(
            os.path.join(
                args.output,
                'all_epochs_SIZE{}_SEED{}.p'.format(args.subset_size,
                                                    args.seed)), 'wb'))
def train(args, config, conditional=True):
    save_weights = 'baseline' not in config.data.dataset.lower()  # no need to save weights for transfer baseline runs
    if args.subset_size == 0:
        conditional = False
    # load dataset
    dataloader, dataset, cond_size = get_dataset(args, config, one_hot=True)
    # define the energy model
    if conditional:
        f = feature_net(config).to(config.device)
        g = SimpleLinear(cond_size, f.output_size,
                         bias=False).to(config.device)
        energy_net = ModularUnnormalizedConditionalEBM(
            f, g, augment=config.model.augment, positive=config.model.positive)
    else:
        f = feature_net(config).to(config.device)
        energy_net = ModularUnnormalizedEBM(f)
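    # In the conditional case the EBM combines the image features f(x) with a learned label
    # embedding g(y) (a bias-free linear map of the one-hot label); following the ICE-BeeM setup this
    # is presumably an inner product f(x)^T g(y), with the augment/positive flags selecting variants
    # of that parameterization.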
    # get optimizer
    optimizer = get_optimizer(config, energy_net.parameters())
    # train
    step = 0
    loss_track_epochs = []
    for epoch in range(config.training.n_epochs):
        loss_track = []
        for i, (X, y) in enumerate(dataloader):
            step += 1
            energy_net.train()
            X = X.to(config.device)
            X = X / 256. * 255. + torch.rand_like(X) / 256.
            if config.data.logit_transform:
                X = logit_transform(X)
            # compute loss
            if conditional:
                loss = cdsm(energy_net, X, y, sigma=0.01)
            else:
                loss = dsm(energy_net, X, sigma=0.01)
            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_track.append(loss.item())
            loss_track_epochs.append(loss.item())

            if step >= config.training.n_iters and save_weights:
                enet, energy_net_finalLayer = energy_net.f, energy_net.g
                # save final checkpoints for distribution!
                states = [
                    enet.state_dict(),
                    optimizer.state_dict(),
                ]
                print('saving weights under: {}'.format(args.checkpoints))
                # torch.save(states, os.path.join(args.checkpoints, 'checkpoint_{}.pth'.format(step)))
                torch.save(states,
                           os.path.join(args.checkpoints, 'checkpoint.pth'))
                torch.save([energy_net_finalLayer],
                           os.path.join(args.checkpoints,
                                        'finalLayerweights_.pth'))
                pickle.dump(
                    energy_net_finalLayer,
                    open(os.path.join(args.checkpoints, 'finalLayerweights.p'),
                         'wb'))
                return 0

            if step % config.training.snapshot_freq == 0:
                enet, energy_net_finalLayer = energy_net.f, energy_net.g
                print('checkpoint at step: {}'.format(step))
                # save checkpoint for transfer learning!
                # torch.save([energy_net_finalLayer], os.path.join(args.log, 'finalLayerweights_.pth'))
                # pickle.dump(energy_net_finalLayer,
                #             open(os.path.join(args.log, 'finalLayerweights.p'), 'wb'))
                # states = [
                #     enet.state_dict(),
                #     optimizer.state_dict(),
                # ]
                # torch.save(states, os.path.join(args.log, 'checkpoint_{}.pth'.format(step)))
                # torch.save(states, os.path.join(args.log, 'checkpoint.pth'))

        if config.data.dataset.lower() in [
                'mnist_transferbaseline', 'cifar10_transferbaseline',
                'fashionmnist_transferbaseline', 'cifar100_transferbaseline'
        ]:
            # save loss track during epoch for transfer baseline
            pickle.dump(
                loss_track,
                open(
                    os.path.join(
                        args.output,
                        'size{}_seed{}.p'.format(args.subset_size, args.seed)),
                    'wb'))

    if config.data.dataset.lower() in [
            'mnist_transferbaseline', 'cifar10_transferbaseline',
            'fashionmnist_transferbaseline', 'cifar100_transferbaseline'
    ]:
        # save loss track during epoch for transfer baseline
        print('saving loss track under: {}'.format(args.output))
        pickle.dump(
            loss_track_epochs,
            open(
                os.path.join(
                    args.output, 'all_epochs_SIZE{}_SEED{}.p'.format(
                        args.subset_size, args.seed)), 'wb'))

    # save final checkpoints for distribution!
    if save_weights:
        enet, energy_net_finalLayer = energy_net.f, energy_net.g
        states = [
            enet.state_dict(),
            optimizer.state_dict(),
        ]
        print('saving weights under: {}'.format(args.checkpoints))
        # torch.save(states, os.path.join(args.checkpoints, 'checkpoint_{}.pth'.format(step)))
        torch.save(states, os.path.join(args.checkpoints, 'checkpoint.pth'))
        torch.save([energy_net_finalLayer],
                   os.path.join(args.checkpoints, 'finalLayerweights_.pth'))
        pickle.dump(
            energy_net_finalLayer,
            open(os.path.join(args.checkpoints, 'finalLayerweights.p'), 'wb'))