def compute_update_dir(self, p_d):
        """Estimate which way the pseudo-discrepancy schedule should move.

        Accumulates the critic gap (-E[critic(true)] + E[critic(gen)] - 2*alpha)
        over 10 mini-batches with all networks in eval mode, then returns the
        negated sign of the accumulated value.
        """
        self.gen.eval()
        self.dis.eval()
        self.classifier.eval()

        accumulated = 0.0
        with torch.no_grad():
            for _ in range(10):
                true_data, _ = p_d.next()
                true_data = tensor2Var(true_data)

                z = tensor2Var(create_noise(true_data.size(0),
                                            self.args.noise_size))
                fake_data = self.gen(z).detach()

                true_out = self.dis(self.classifier(true_data, 'feat'),
                                    'critic')
                gen_out = self.dis(self.classifier(fake_data, 'feat'),
                                   'critic')

                accumulated += (-true_out.mean() + gen_out.mean()
                                - 1 * 2 * self.ps.alpha).item()

            print(true_out.mean().item(), gen_out.mean().item())

        self.gen.train()
        self.dis.train()
        self.classifier.train()

        return -1 * np.sign(accumulated)
# 示例#2 ("Example #2" — pasted-snippet separator left over from extraction)
# 0
    def construct_p_d(self, data_1, data_2):
        """Build a perturbed or interpolated sample for the auxiliary stream.

        Depending on ``self._type``:
          - 'normal':      data_1 + Gaussian noise with std ``beta_2``
          - 'uniform':     data_1 + U(-beta_2, beta_2) noise
          - 'inter':       per-sample convex mix of data_1 and data_2 with a
                           mixing coefficient drawn from U(beta_1, beta_2)
          - 'huge_normal': data_1 + standard normal noise scaled by (1 + beta_2)

        Raises:
            ValueError: for an unknown ``self._type`` or, for 'inter', inputs
                that are neither 2-D nor 4-D.  (Previously these paths crashed
                with NameError/UnboundLocalError on the unbound local.)
        """
        beta = self.beta_2
        beta_1 = self.beta_1
        beta_2 = self.beta_2

        if self._type == 'normal':
            noise = torch.FloatTensor(data_1.size()).normal_(0, beta)

        elif self._type == 'uniform':
            noise = torch.FloatTensor(data_1.size()).uniform_(-beta, beta)

        elif self._type == 'inter':
            # One mixing coefficient per example, broadcast over the remaining
            # (feature or CxHxW) dimensions.
            if len(data_1.size()) == 2:
                _beta = torch.FloatTensor(data_1.size(0),
                                          1).uniform_(beta_1, beta_2)
            elif len(data_1.size()) == 4:
                _beta = torch.FloatTensor(data_1.size(0), 1, 1,
                                          1).uniform_(beta_1, beta_2)
            else:
                raise ValueError(
                    "'inter' expects 2-D or 4-D inputs, got %d-D"
                    % len(data_1.size()))
            _beta = tensor2Var(_beta)

            out = _beta * data_1 + (1 - _beta) * data_2

            return out

        elif self._type == 'huge_normal':
            noise = torch.randn(data_1.size()) * (1 + beta)

        else:
            raise ValueError('unknown perturbation type: %s' % self._type)

        return data_1 + tensor2Var(noise)
    def train_c(self, train_loader, semi_weight):
        """One classifier step: labeled cross-entropy plus an entropy term
        on generated samples, weighted by ``semi_weight``.

        Returns (labeled loss, entropy, total loss) as Python floats.
        """
        args = self.args

        # Labeled mini-batch.
        lab_data, lab_labels = train_loader.next()
        lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(lab_labels)

        # Generated mini-batch; detach() keeps the generator frozen.
        z = tensor2Var(create_noise(lab_data.size(0), args.noise_size))
        gen_data = self.gen(z).detach()

        lab_logits = self.classifier(lab_data, 'class')
        gen_logits = self.classifier(gen_data, 'class')

        lab_loss = F.cross_entropy(lab_logits, lab_labels)

        # Entropy of the predictive distribution on generated data.
        gen_prob = F.softmax(gen_logits, dim=1)
        entropy = -(gen_prob * torch.log(gen_prob + 1e-8)).sum(1).mean()

        c_loss = lab_loss + semi_weight * entropy

        self.classifier_opt.zero_grad()
        c_loss.backward()
        self.classifier_opt.step()

        return lab_loss.cpu().item(), entropy.cpu().item(), c_loss.cpu().item()
    def eval(self, data_loader):
        """Evaluate the classifier on a loader.

        Returns (mean CE loss, #incorrect, #samples,
        mean fraction of real samples with positive max logit,
        mean fraction of generated samples with negative max logit).
        """
        self.gen.eval()
        self.dis.eval()
        self.classifier.eval()

        loss = 0
        incorrect = 0
        cnt = 0
        total_num = 0
        max_unl_acc = 0
        max_gen_acc = 0

        with torch.no_grad():
            for data, labels in data_loader.get_iter():
                data, labels = tensor2Var(data), tensor2Var(labels)
                batch = data.size(0)

                z = tensor2Var(create_noise(batch, self.args.noise_size))
                gen_logits = self.classifier(self.gen(z).detach(), 'class')

                pred_logits = self.classifier(data, 'class')
                labels = labels.view(-1)

                loss += F.cross_entropy(pred_logits, labels).item() * batch
                cnt += 1
                total_num += batch
                incorrect += torch.ne(torch.max(pred_logits, 1)[1],
                                      labels).float().sum().item()

                max_unl_acc += torch.sum(
                    pred_logits.max(1)[0].detach().gt(0.0).float()).item()
                max_gen_acc += torch.sum(
                    gen_logits.max(1)[0].detach().lt(0.0).float()).item()

        return (loss / total_num, incorrect, total_num,
                max_unl_acc / total_num, max_gen_acc / total_num)
def calc_gradient_penalty(net, real_data, fake_data):
    """WGAN-GP gradient penalty on random interpolates of real/fake batches.

    Returns E[(||∇_x critic(x)||_2 - 1)^2] over per-sample interpolates.
    """
    # One interpolation coefficient per sample, broadcast over C/H/W.
    mix = tensor2Var(
        torch.FloatTensor(real_data.size(0), 1, 1, 1).uniform_(0, 1))

    interpolates = mix * real_data + ((1 - mix) * fake_data)
    interpolates.requires_grad_(True)

    critic_out = net(interpolates, 'critic')

    ones = tensor2Var(torch.ones(critic_out.size()))

    gradients = grad(outputs=critic_out,
                     inputs=interpolates,
                     grad_outputs=ones,
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True)[0]

    # Collapse trailing dimensions by repeated L2 norms; this equals the
    # per-sample L2 norm of the full gradient.
    while gradients.dim() > 1:
        gradients = gradients.norm(2, dim=gradients.dim() - 1)

    return ((gradients - 1.0) ** 2).mean()
# 示例#6 ("Example #6" — pasted-snippet separator left over from extraction)
# 0
    def get_data(self):
        """Fetch one mini-batch from each of the two data streams."""
        batch_1, _ = self.p_d_1.next()
        batch_2, _ = self.p_d_2.next()
        return tensor2Var(batch_1), tensor2Var(batch_2)
    def train_c(self, labeled_loader, unlabeled_loader):
        """One classifier update: supervised CE on labeled data, a
        logsumexp-based GAN cost on unlabeled vs. generated data, an entropy
        penalty on unlabeled predictions, and (optionally) a consistency
        term between two forward passes on the same unlabeled batch.

        Returns (labeled loss, unsupervised GAN loss, entropy) as floats.
        """
        args = self.args
        set_require_grad(self.classifier, requires_grad=True)
        # standard classification loss
        lab_data, lab_labels = labeled_loader.next()
        lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(lab_labels)

        lab_labels = lab_labels.view(-1)
        unl_data, _ = unlabeled_loader.next()
        unl_data = tensor2Var(unl_data)

        noise = create_noise(unl_data.size(0), args.noise_size)
        noise = tensor2Var(noise)

        # Generator is frozen during the classifier step.
        gen_data = self.gen(noise).detach()

        lab_logits = self.classifier(lab_data, 'class')
        unl_logits = self.classifier(unl_data, 'class')
        gen_logits = self.classifier(gen_data, 'class')

        lab_loss = F.cross_entropy(lab_logits, lab_labels)

        unl_logsumexp = log_sum_exp(unl_logits)
        gen_logsumexp = log_sum_exp(gen_logits)

        # Diagnostic real/fake accuracies of the implicit discriminator
        # (computed but not returned or logged here).
        unl_acc = torch.mean(torch.sigmoid(unl_logsumexp.detach()).gt(0.5).float())
        gen_acc = torch.mean(torch.sigmoid(gen_logsumexp.detach()).lt(0.5).float())

        # This is the typical GAN cost, where sumexp(logits) is seen as the input to the sigmoid
        true_loss = - 0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(F.softplus(unl_logsumexp))
        fake_loss = 0.5 * torch.mean(F.softplus(gen_logsumexp))

        # max_unl_acc = torch.mean(unl_logits.max(1)[0].detach().gt(0.0).float())
        # max_gen_acc = torch.mean(gen_logits.max(1)[0].detach().gt(0.0).float())

        unl_prob = F.softmax(unl_logits, dim=1)

        # Entropy of predictions on unlabeled data; 1e-8 guards log(0).
        entropy = -(unl_prob * torch.log(unl_prob + 1e-8)).sum(1).mean()

        unl_loss = true_loss + fake_loss

        c_loss = lab_loss + args.lambda_gan * unl_loss + args.lambda_e * entropy

        if args.lambda_consistency > 0:
            # Second forward pass on the same unlabeled batch (stochastic
            # layers make it differ); penalize the squared gap between the
            # two predictive distributions.
            unl_logits_2 = self.classifier(unl_data, 'class')
            unl_prob_2 = F.softmax(unl_logits_2, dim=1)
            consistency_loss = ((unl_prob - unl_prob_2) ** 2).mean()

            c_loss += args.lambda_consistency * consistency_loss

            if self.total_iter % 1000 == 0:
                print(consistency_loss)

        self.classifier_opt.zero_grad()
        c_loss.backward()
        self.classifier_opt.step()

        return lab_loss.cpu().item(), unl_loss.cpu().item(), entropy.cpu().item()
 def reparameterize(self, mean_input, logvar_input):
     """VAE reparameterization trick; returns the mean deterministically
     when the module is in eval mode."""
     if not self.training:
         return mean_input
     sigma = torch.exp(0.5 * logvar_input)
     eps = tensor2Var(torch.randn(sigma.size()))
     return mean_input + sigma * eps
    def visualize(self, train_loader):
        """Save a grid of generated images (decoded generator features) and
        the corresponding original images for the current iteration."""
        self.gen.eval()
        self.dis.eval()
        self.vae.eval()

        vis_size = 100  # kept for parity with the other visualize(); unused

        # Use only the first training batch.
        for data, _ in train_loader.get_iter():
            data = tensor2Var(data)
            with torch.no_grad():
                feat = self.vae.get_features(data)
                gen_images = self.vae.decode(self.gen(feat))
            break

        save_path = os.path.join(self.args.log_folder,
                                 '%d_gen_images.png' % self.total_iter)

        if self.args.dataset in ('mnist', 'svhn', 'cifar'):
            # Undo the [-1, 1] input normalization before saving.
            gen_images = gen_images * 0.5 + 0.5
        else:
            raise NotImplementedError

        vutils.save_image(gen_images.data.cpu(), save_path, nrow=10)

        save_path = os.path.join(self.args.log_folder,
                                 '%d_ori_images.png' % self.total_iter)

        vutils.save_image(data.data.cpu() * 0.5 + 0.5, save_path, nrow=10)
        # NOTE(review): only the VAE is returned to train mode here; gen/dis
        # stay in eval — confirm callers re-enable them.
        self.vae.train()
    def train_g(self):
        """One generator step: critic loss on generated features plus the
        pullaway diversity regularizer (weighted by args.lambda_p).

        Returns the critic part of the loss as a float.
        """
        args = self.args
        set_require_grad(self.dis, False)

        z = tensor2Var(create_noise(args.train_batch_size, args.noise_size))

        # Map generated images into classifier feature space.
        gen_feat = self.classifier(self.gen(z), 'feat')

        pullaway = pullaway_loss(gen_feat)
        gen_loss = -self.dis(gen_feat, 'critic').mean()

        g_loss = gen_loss + args.lambda_p * pullaway

        self.gen_opt.zero_grad()
        g_loss.backward()
        self.gen_opt.step()

        return gen_loss.cpu().item()
    def train_g(self, p_d_2):
        """One generator step in VAE feature space: critic loss plus a
        feature-matching anchor (weighted by args.lambda_feat).

        Returns (generator loss, feature loss) as floats.
        """
        args = self.args
        set_require_grad(self.dis, False)

        real_batch, _ = p_d_2.next()
        real_feat = self.vae.get_features(tensor2Var(real_batch))

        gen_feat = self.gen(real_feat)

        gen_out = self.dis(gen_feat, 'critic')

        # Keep generated features close to the conditioning features.
        feat_loss = ((gen_feat - real_feat).view(gen_feat.shape[0],
                                                 -1) ** 2).mean()

        gen_loss = -gen_out.mean() + args.lambda_feat * feat_loss

        self.gen_opt.zero_grad()
        gen_loss.backward()
        self.gen_opt.step()

        return gen_loss.cpu().item(), feat_loss.cpu().item()
    def visualize_generation(self, _iter):
        """Decode random latent vectors with the VAE decoder and save the
        resulting image grid to the log folder."""
        self.vae.eval()
        latent = torch.randn(self.sample_size, self.feature_size, 1, 1)

        with torch.no_grad():
            decoded = self.vae.decode(tensor2Var(latent))

        tv.utils.save_image(
            decoded.data * 0.5 + 0.5,
            os.path.join(self.args.log_folder, 'generation_%d.png' % _iter))
    def eval(self, data_loader):
        """Compute average cross-entropy and the error count over a loader.

        Returns (mean CE loss, #incorrect, #samples).
        """
        self.gen.eval()
        self.dis.eval()
        self.classifier.eval()

        loss = 0
        incorrect = 0
        cnt = 0
        total_num = 0

        with torch.no_grad():
            for data, labels in data_loader.get_iter():
                data, labels = tensor2Var(data), tensor2Var(labels)
                batch = data.size(0)

                logits = self.classifier(data, 'class')
                loss += F.cross_entropy(logits, labels).item() * batch
                cnt += 1
                total_num += batch
                incorrect += torch.ne(torch.max(logits, 1)[1],
                                      labels).float().sum().item()

        return loss / total_num, incorrect, total_num
def pullaway_loss(x1):
    """Pullaway regularizer: mean squared cosine similarity between distinct
    rows of x1, encouraging diverse (mutually non-parallel) features."""
    unit_rows = F.normalize(x1)

    n = x1.size(0)
    cos_sim = torch.matmul(unit_rows, unit_rows.t())

    # Zero the diagonal so self-similarity is excluded from the average.
    off_diag = torch.ones(cos_sim.size()) - torch.diag(torch.ones(n))
    cos_sim = (cos_sim * tensor2Var(off_diag)) ** 2

    return cos_sim.sum() / (n * (n - 1))
    def param_init_cnn(self, labeled_loader):
        """Data-dependent initialization for generator and classifier: flip
        each module's ``init_mode``, run one forward pass on ~500 labeled
        images (and matching noise), then flip it back."""
        def set_init_mode(flag):
            def apply_fn(m):
                if hasattr(m, 'init_mode'):
                    setattr(m, 'init_mode', flag)
            return apply_fn

        # Collect roughly 500 labeled images.
        batches = []
        for _ in range(ceil(500 / self.args.train_batch_size)):
            lab_images, _ = labeled_loader.next()
            batches.append(lab_images)
        images = torch.cat(batches, 0)

        self.gen.apply(set_init_mode(True))
        noise = tensor2Var(
            torch.Tensor(images.size(0), self.args.noise_size).uniform_())
        gen_images = self.gen(noise)
        self.gen.apply(set_init_mode(False))

        self.classifier.apply(set_init_mode(True))
        logits = self.classifier(tensor2Var(images))
        self.classifier.apply(set_init_mode(False))
def pullaway_loss_lp(x1, p=2):
    """Pullaway-style penalty on max-normalized pairwise Lp distances
    between rows of x1 (diagonal excluded)."""
    pairwise = torch.norm(x1[:, None] - x1, dim=2, p=p)

    # Normalize so the largest pairwise distance is 1.
    pairwise = pairwise / pairwise.max()

    n = x1.size(0)

    # Mask the (zero) diagonal before averaging over ordered pairs.
    off_diag = torch.ones(pairwise.size()) - torch.diag(torch.ones(n))
    pairwise = pairwise * tensor2Var(off_diag)

    return pairwise.sum() / (n * (n - 1))
    def train_d(self, p_d, p_d_bar):
        """Train the critic for ``args.iter_c`` steps per generator step.

        The "real" side is a perturbed feature batch from ``p_d_bar``; the
        "fake" side is a mixture of true VAE features and generator outputs
        in a ratio controlled by ``self.ps.alpha``.  A gradient penalty
        (weighted by ``args.lambda_g``) is added.

        Returns the negated critic loss of the last step (an estimate of
        the critic distance).
        """
        args = self.args

        set_require_grad(self.dis, requires_grad=True)

        j = 0

        # train discriminator multiples times per generator iteration
        while j < args.iter_c:
            j += 1

            true_data_bar = p_d_bar.sample_feat(self.vae.get_features)
            true_data, _ = p_d.next()
            true_data = tensor2Var(true_data)

            true_data = self.vae.get_features(true_data)

            # noise = create_noise(true_data.size(0), args.noise_size)
            # noise = tensor2Var(noise)

            # Generator output on true features; detached so only the critic
            # receives gradients here.
            gen_data = self.gen(true_data).detach()

            # self.vae.eval()

            # gen_data = self.vae.get_features(gen_data)

            # self.vae.train()

            # alpha-fraction of the batch stays true, the rest is generated.
            true_data_size = int(true_data.size(0) * self.ps.alpha)

            gen_size = true_data.size(0) - true_data_size

            # concatenate true and gen data
            true_gen_data = torch.cat(
                [true_data[:true_data_size], gen_data[:gen_size]], dim=0)

            true_data_bar_out = self.dis(true_data_bar, 'critic')
            true_gen_data_out = self.dis(true_gen_data, 'critic')

            dis_loss = -true_data_bar_out.mean() + true_gen_data_out.mean()

            d_loss = dis_loss + \
                args.lambda_g * calc_gradient_penalty(self.dis, true_data_bar, true_gen_data)

            self.dis_opt.zero_grad()
            d_loss.backward()
            self.dis_opt.step()

        return -dis_loss.cpu().item()
    def param_init_dnn(self, unlabeled_loader):
        """Data-dependent initialization for the classifier: flip each
        module's ``init_mode``, run one forward pass on ~500 unlabeled
        images, then flip it back."""
        def set_init_mode(flag):
            def apply_fn(m):
                if hasattr(m, 'init_mode'):
                    setattr(m, 'init_mode', flag)
            return apply_fn

        # Collect roughly 500 unlabeled images.
        batches = []
        for _ in range(ceil(500 / self.args.train_batch_size)):
            unl_images, _ = unlabeled_loader.next()
            batches.append(unl_images)
        images = torch.cat(batches, 0)

        self.classifier.apply(set_init_mode(True))
        logits = self.classifier(tensor2Var(images))
        self.classifier.apply(set_init_mode(False))
    def visualize(self):
        """Sample a fixed grid of generated images and save it."""
        self.gen.eval()
        self.dis.eval()

        vis_size = 100
        z = create_noise(vis_size, self.args.noise_size)
        with torch.no_grad():
            gen_images = self.gen(tensor2Var(z))

        save_path = os.path.join(self.args.log_folder,
                                 'gen_images_%d.png' % self.total_iter)

        if self.args.dataset == 'mnist':
            # MNIST generator emits flat vectors; reshape to NCHW.
            gen_images = gen_images.view(-1, self.args.n_channels,
                                         self.args.image_size,
                                         self.args.image_size)
        elif self.args.dataset in ('svhn', 'cifar'):
            # Undo the [-1, 1] input normalization.
            gen_images = gen_images * 0.5 + 0.5
        else:
            raise NotImplementedError

        vutils.save_image(gen_images.data.cpu(), save_path, nrow=10)
    def eval_gen(self, gen_num):
        """Average per-sample prediction entropy of the classifier over
        exactly ``gen_num`` generated samples."""
        self.gen.eval()
        self.dis.eval()
        self.classifier.eval()

        entropy_sum = 0
        produced = 0
        batch_size = self.args.train_batch_size

        with torch.no_grad():
            while produced < gen_num:
                # Shrink the final batch so exactly gen_num samples are used.
                if produced + batch_size > gen_num:
                    batch_size = gen_num - produced

                z = tensor2Var(create_noise(batch_size, self.args.noise_size))
                gen_logits = self.classifier(self.gen(z), 'class')
                gen_prob = F.softmax(gen_logits, dim=1)

                entropy_sum += -(gen_prob *
                                 torch.log(gen_prob + 1e-8)).sum().item()

                produced += batch_size
        return entropy_sum / produced
    def visualize_reconstruction(self, train_loader, _iter):
        """Save one batch of originals and their VAE reconstructions as
        image grids in the log folder."""
        self.vae.eval()

        # Reconstruct only the first batch.
        with torch.no_grad():
            for data, _ in train_loader.get_iter():
                reconstruct, mean, _ = self.vae(tensor2Var(data))
                break

        tv.utils.save_image(
            data[:self.sample_size] * 0.5 + 0.5,
            os.path.join(self.args.log_folder, '%d_origin.png' % _iter))

        tv.utils.save_image(
            reconstruct.data[:self.sample_size] * 0.5 + 0.5,
            os.path.join(self.args.log_folder, '%d_reconstruct.png' % _iter))
    def finetune(self, tr_data_dict):
        """Fine-tune the VAE decoder: keep reconstructions of real features
        accurate while pushing reconstructions of generator outputs away via
        a hinge on the reconstruction error (threshold ``args.threshold``).

        Expects ``tr_data_dict`` with keys 'train_loader' and 'p_d'.
        """
        args = self.args

        train_loader = tr_data_dict['train_loader']
        p_d = tr_data_dict['p_d']
        ######################################################################
        ### start training

        total_iter = 0

        best_loss = 1e8
        best_err = 1e8
        best_err_per = 1e8
        begin_time = time()

        stop = 0

        # for p in self.vae.conv_feature.parameters():
        #     p.requires_grad = False
        # for p in self.vae.mean_layer.parameters():
        #     p.requires_grad = False
        # for p in self.vae.std_layer.parameters():
        #     p.requires_grad = False

        # Decay the LR by 10x at epochs 300, 400 and 500.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.vae_opt, [300, 400, 500], 0.1)

        for epoch in range(args.max_epochs):
            epoch_ratio = float(epoch) / float(args.max_epochs)

            self.vae.train()
            for i, (lab_data, _) in enumerate(train_loader.get_iter()):
                lab_data = tensor2Var(lab_data)

                # Encoder output is detached: only the decoder is trained.
                feat = self.vae.get_features(lab_data).detach()

                reconstruct = self.vae.decode(feat)

                gen_feat = self.gen(feat).detach()

                reconstruct_gen = self.vae.decode(gen_feat)

                # Hinge: penalize decoded generator features only while their
                # reconstruction error is below args.threshold.
                gen_reconstruction_loss = torch.max(
                    tensor2Var(torch.zeros(gen_feat.shape[0])),
                    args.threshold - ((reconstruct_gen - lab_data)**2).view(
                        gen_feat.shape[0], -1).mean(1)).mean()

                # gen_reconstruction_loss = ((reconstruct_gen - lab_data) ** 2).mean()
                # reconstruction_loss = F.binary_cross_entropy(reconstruct, data_v, size_average=True)
                reconstruction_loss = ((reconstruct - lab_data)**2).mean()

                loss = reconstruction_loss + args.lambda_out * gen_reconstruction_loss

                self.vae_opt.zero_grad()
                loss.backward()
                self.vae_opt.step()

            scheduler.step()

            if args.save_vae:
                save_dict = {
                    'total_iter': total_iter,
                    'vae_state_dict': self.vae.state_dict(),
                    'vae_opt': self.vae_opt.state_dict()
                }

                torch.save(save_dict, self.vae_checkpoint)

            self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                             (epoch, total_iter, time() - begin_time))
            self.logger.info(
                '[train] loss: %.4f, reconst loss: %.4f, gen_reconst_loss: %.4f'
                % (loss.cpu().item(), reconstruction_loss.cpu().item(),
                   gen_reconstruction_loss.cpu().item()))

            self.logger.info('--------')

            if epoch % 10 == 0:
                self.visualize_reconstruction(train_loader, epoch)
                self.visualize_generation(epoch)

            begin_time = time()

            total_iter += 1
            self.total_iter += 1
    def train_classifier(self, tr_data_dict):
        """Train a binary classifier to separate VAE features of real data
        (label 1) from features of generated samples (label 0); gen, dis
        and vae stay frozen throughout.

        Expects ``tr_data_dict`` with keys 'train_loader' and 'p_d_2'.
        """
        args = self.args
        set_require_grad(self.dis, requires_grad=False)
        set_require_grad(self.gen, requires_grad=False)
        set_require_grad(self.vae, requires_grad=False)

        self.gen.eval()
        self.dis.eval()
        self.vae.eval()

        train_loader = tr_data_dict['train_loader']
        p_d_2 = tr_data_dict['p_d_2']
        # valid_loader = tr_data_dict['valid_loader']

        total_iter = 0

        best_loss = 1e8
        best_err = 1e8
        best_err_per = 1e8
        begin_time = time()

        stop = 0

        self.visualize_embedding(p_d_2)

        for epoch in range(args.max_epochs):
            epoch_ratio = float(epoch) / float(args.max_epochs)
            # self.classifier_opt.param_groups[0]['lr'] = \
            #     max(args.min_lr, args.classifier_lr * max(0., min(3. * (1. - epoch_ratio), 1.)))

            self.classifier.train()

            for i, (lab_data,
                    lab_labels) in enumerate(train_loader.get_iter()):
                lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(
                    lab_labels)

                noise = create_noise(lab_data.size(0), args.noise_size)
                noise = tensor2Var(noise)

                gen_data = self.gen(noise).detach()

                # Both real and generated data are mapped into VAE feature
                # space before classification.
                lab_data = self.vae.get_features(lab_data)
                gen_data = self.vae.get_features(gen_data)

                lab_logits = self.classifier(lab_data)
                gen_logits = self.classifier(gen_data)

                # Real -> 1, generated -> 0.
                label_true = tensor2Var(torch.ones(lab_data.shape[0]))
                label_gen = tensor2Var(torch.zeros(gen_data.shape[0]))

                pred = torch.cat([lab_logits, gen_logits], dim=0)
                label = torch.cat([label_true, label_gen], dim=0)

                lab_loss = F.binary_cross_entropy(pred, label)

                self.classifier_opt.zero_grad()
                lab_loss.backward()
                self.classifier_opt.step()

            # if epoch % 10:
            #     print(pred.shape)

            if args.save_classifier:
                save_dict = {
                    'total_iter': total_iter,
                    'classifier_state_dict': self.classifier.state_dict(),
                    'classifier_opt': self.classifier_opt.state_dict()
                }
                torch.save(save_dict, self.classifier_checkpoint)

            self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                             (epoch, total_iter, time() - begin_time))
            self.logger.info('[train] loss: %.4f' % (lab_loss.cpu().item()))

            self.logger.info('--------')

            begin_time = time()

            total_iter += 1
            self.total_iter += 1
    def visualize_embedding(self, p_d_2):
        """t-SNE scatter plot of true VAE features vs. the generator's
        outputs on those features, saved to the log folder.

        NOTE(review): gen/dis are put in eval mode but only the VAE is
        switched back to train mode at the end — confirm callers reset them.
        """
        self.gen.eval()
        self.dis.eval()
        self.vae.eval()

        # Collect at least this many samples before embedding.
        vis_size = 200

        true_emb = []

        gen_emb = []

        cum_size = 0

        with torch.no_grad():
            for i, (data, _) in enumerate(p_d_2.get_iter()):
                data = tensor2Var(data)

                feat = self.vae.get_features(data)

                true_emb.append(feat.squeeze().cpu().numpy())

                gen_feat = self.gen(feat)

                gen_emb.append(gen_feat.squeeze().cpu().numpy())

                # gen_emb.append(gen_data.squeeze().cpu().numpy())

                cum_size += data.shape[0]

                if cum_size >= vis_size:
                    break

        true_emb = np.vstack(true_emb)
        gen_emb = np.vstack(gen_emb)

        # print(true_emb.shape, gen_emb.shape)

        # Embed true and generated features jointly so both live in the same
        # 2-D t-SNE space; split them again afterwards.
        tsne = sklearn.manifold.TSNE(2)

        all_emb = np.vstack([true_emb, gen_emb])

        # print(all_emb.shape)

        all_emb = tsne.fit_transform(all_emb)

        size = all_emb.shape[0] // 2

        true_emb = all_emb[:size]
        gen_emb = all_emb[size:]

        plt.clf()

        t = plt.scatter(true_emb[:, 0], true_emb[:, 1], label='true data')
        g = plt.scatter(gen_emb[:, 0], gen_emb[:, 1], label='gen data')

        plt.legend([t, g], ['true data', 'gen data'])

        save_path = os.path.join(self.args.log_folder,
                                 'embedding_%d.png' % self.total_iter)

        plt.savefig(save_path)

        self.vae.train()
    def train_ae(self, tr_data_dict):
        """Pre-train the autoencoder/VAE: MSE reconstruction plus an L2
        penalty on the latent means (weighted by ``self.args.LAMBDA``), for
        ``args.max_epochs`` epochs.

        Expects ``tr_data_dict`` with key 'train_loader'.
        """
        args = self.args

        train_loader = tr_data_dict['train_loader']

        ######################################################################
        ### start training

        total_iter = 0

        best_loss = 1e8
        best_err = 1e8
        best_err_per = 1e8
        begin_time = time()

        stop = 0

        for epoch in range(args.max_epochs):
            epoch_ratio = float(epoch) / float(args.max_epochs)

            self.vae.train()
            for i, (lab_data, _) in enumerate(train_loader.get_iter()):
                lab_data = tensor2Var(lab_data)

                reconstruct, mean, _ = self.vae(lab_data)

                # reconstruction_loss = F.binary_cross_entropy(reconstruct, data_v, size_average=True)
                reconstruction_loss = ((reconstruct - lab_data)**2).mean()
                # Pull latent means toward zero (weight-decay-like term).
                feature_loss = ((mean - 0)**2).mean()

                loss = reconstruction_loss + self.args.LAMBDA * feature_loss

                self.vae_opt.zero_grad()
                loss.backward()
                self.vae_opt.step()

            if args.save_vae:
                save_dict = {
                    'total_iter': total_iter,
                    'vae_state_dict': self.vae.state_dict(),
                    'vae_opt': self.vae_opt.state_dict()
                }

                torch.save(save_dict, self.vae_checkpoint)

            self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                             (epoch, total_iter, time() - begin_time))
            self.logger.info(
                '[train] loss: %.4f, reconst loss: %.4f, feature_loss: %.4f' %
                (loss.cpu().item(), reconstruction_loss.cpu().item(),
                 feature_loss.cpu().item()))

            self.logger.info('--------')

            if epoch % 10 == 0:
                self.visualize_reconstruction(train_loader, epoch)
                # self.visualize_generation(epoch)

            begin_time = time()

            total_iter += 1
            self.total_iter += 1
    def train(self, tr_data_dict):
        """Main VAE training loop: MSE reconstruction, an L2 penalty on the
        latent means, and a hinge term that forces reconstructions from
        strongly perturbed latents (U(-beta_2, beta_2)) to stay at least 0.1
        MSE away from the input.

        Expects ``tr_data_dict`` with keys 'train_loader', 'p_d', 'p_d_bar'
        and 'p_d_2' (the latter three are currently unused in the active
        code path).
        """
        args = self.args

        train_loader = tr_data_dict['train_loader']

        p_d = tr_data_dict['p_d']
        p_d_bar = tr_data_dict['p_d_bar']
        p_d_2 = tr_data_dict['p_d_2']

        ######################################################################
        ### start training

        # if args.gan_checkpoint == "":
        #     self.param_init_cnn(p_d_2)

        total_iter = 0

        best_loss = 1e8
        best_err = 1e8
        best_err_per = 1e8
        begin_time = time()

        stop = 0

        for epoch in range(args.max_epochs):
            epoch_ratio = float(epoch) / float(args.max_epochs)

            self.dis.train()
            self.gen.train()

            self.vae.train()
            self.classifier.train()

            for i, (lab_data, _) in enumerate(train_loader.get_iter()):
                lab_data = tensor2Var(lab_data)

                # noise = create_noise(args.train_batch_size, args.noise_size)
                # noise = tensor2Var(noise)

                # gen_data = self.gen(noise).detach()

                # lab_feat = self.vae.get_features(lab_data)
                # # gen_feat = self.vae.get_features(gen_data)
                # gen_feat = p_d_bar.sample_feat(self.vae.get_features)

                # lab_logits = self.classifier(lab_feat)
                # gen_logits = self.classifier(gen_feat)

                # label_true = tensor2Var(torch.ones(lab_feat.shape[0]))
                # label_gen = tensor2Var(torch.zeros(gen_feat.shape[0]))

                # pred = torch.cat([lab_logits, gen_logits], dim=0)
                # label = torch.cat([label_true, label_gen], dim=0)

                # lab_loss = F.binary_cross_entropy(pred, label)

                # self.classifier_opt.zero_grad()
                # lab_loss.backward()
                # self.classifier_opt.step()

                reconstruct, mean, logvar = self.vae(lab_data)

                # Small ("in") and large ("out") uniform latent perturbations.
                noise_in = tensor2Var(
                    torch.FloatTensor(mean.size()).uniform_(
                        -self.args.beta_1, self.args.beta_1))
                noise_out = tensor2Var(
                    torch.FloatTensor(mean.size()).uniform_(
                        -self.args.beta_2, self.args.beta_2))

                # Decode perturbed latents, clamped to the valid range.
                reconstruct = self.vae.decode(
                    torch.clamp(mean + noise_in, -1.0, 1.0))

                reconstruct_out = self.vae.decode(
                    torch.clamp(mean + noise_out, -1.0, 1.0))

                # reconstruct_gen, *_ = self.vae(gen_data)

                # reconstruction_loss = F.binary_cross_entropy(reconstruct, data_v, size_average=True)
                reconstruction_loss = ((reconstruct - lab_data)**2).mean()

                # gen_reconstruction_loss = torch.max(
                #     tensor2Var(torch.zeros(gen_data.shape[0])),
                #     1.0 - ((reconstruct_gen - gen_data) ** 2).view(gen_data.shape[0], -1).mean(1)).mean()

                # Hinge: large perturbations should NOT reconstruct well
                # (per-sample MSE pushed above 0.1).
                gen_reconstruction_loss = torch.max(
                    tensor2Var(torch.zeros(lab_data.shape[0])),
                    0.1 - ((reconstruct_out - lab_data)**2).view(
                        lab_data.shape[0], -1).mean(1)).mean()

                # gen_reconstruction_loss = ((reconstruct_gen - gen_data) ** 2).mean()
                # gen_reconstruction_loss = 0*((reconstruct - lab_data) ** 2).mean()

                # Pull latent means toward zero.
                feature_loss = ((mean - 0)**2).mean()

                loss = reconstruction_loss + self.args.LAMBDA * feature_loss + gen_reconstruction_loss

                # reconstruction_loss = ((reconstruct - lab_data) ** 2).mean()
                # kl_div = (-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1)).mean()

                # loss = reconstruction_loss + self.args.LAMBDA * kl_div + gen_reconstruction_loss

                self.vae_opt.zero_grad()
                loss.backward()
                self.vae_opt.step()

                ##################
                ## train the model

                # train all the networks
                # dis_dist = self.train_d(p_d, p_d_bar)

                # gen_critic = self.train_g()

            if args.save_vae:
                save_dict = {
                    'total_iter': total_iter,
                    'vae_state_dict': self.vae.state_dict(),
                    'vae_opt': self.vae_opt.state_dict()
                }

                torch.save(save_dict, self.vae_checkpoint)

                # save_dict = {'total_iter': total_iter,
                #         'classifier_state_dict': self.classifier.state_dict(),
                #         'classifier_opt': self.classifier_opt.state_dict()}
                # torch.save(save_dict, self.classifier_checkpoint)

            self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                             (epoch, total_iter, time() - begin_time))
            self.logger.info(
                '[train] loss: %.4f, reconst loss: %.4f, feature_loss: %.4f, \
                gen_reconstruction_loss: %.4f' %
                (loss.cpu().item(), reconstruction_loss.cpu().item(),
                 feature_loss.cpu().item(),
                 gen_reconstruction_loss.cpu().item()))

            # self.logger.info('[train] loss: %.4f, reconst loss: %.4f, kl_div: %.4f, \
            #     gen_reconstruction_loss: %.4f' % (loss.cpu().item(),
            #     reconstruction_loss.cpu().item(), kl_div.cpu().item(), gen_reconstruction_loss.cpu().item()))

            # self.logger.info('[train] loss: %.4f, reconst loss: %.4f, feature_loss: %.4f' % (loss.cpu().item(),
            #     reconstruction_loss.cpu().item(), feature_loss.cpu().item()))

            # self.logger.info('%s: %.4f' % ('dis_dist', dis_dist))

            # self.logger.info('classifier loss: %.4f' % lab_loss.cpu().item())

            self.logger.info('--------')

            if epoch % 10 == 0:
                self.visualize_reconstruction(train_loader, epoch)
                # self.visualize()
                # self.visualize_embedding(p_d_2)
                # self.visualize_generation(epoch)

            begin_time = time()

            total_iter += 1
            self.total_iter += 1
# 示例#27 ("Example #27" — pasted-snippet separator left over from extraction)
# 0
def _build_models(args):
    """Construct and return ``(classifier, vae)`` for the configured dataset.

    Raises:
        ValueError: if ``args.dataset`` is not 'mnist' or 'cifar' (the
            original code left the models unbound and crashed later with a
            NameError).
    """
    feature_size = args.feature_size

    if args.dataset == 'mnist':
        import mnist_model as model_lib
    elif args.dataset == 'cifar':
        import cnn_model as model_lib
    else:
        raise ValueError('unknown dataset: %s' % args.dataset)

    classifier = model_lib.Classifier(args, feature_size)
    if args.feature_extractor == 'ae':
        vae = model_lib.AE(args, feature_size)
    else:
        vae = model_lib.VAE(args, feature_size)
    return classifier, vae


def _save_image_grids(args, data_list, r_list, label_list):
    """Save a grid of 90 label-0 + 10 label-1 test samples and their
    reconstructions to ``test_origin.png`` / ``test_reconstruct.png``."""
    test_data = []
    test_r = []
    for label_value, count in ((0, 90), (1, 10)):
        pos = (label_list == label_value)
        test_data.append(data_list[pos][:count])
        test_r.append(r_list[pos][:count])

    test_data = torch.from_numpy(np.vstack(test_data))
    test_r = torch.from_numpy(np.vstack(test_r))

    # Images were normalized to [-1, 1]; map back to [0, 1] before saving.
    tv.utils.save_image(test_data * 0.5 + 0.5,
                        os.path.join(args.log_folder, 'test_origin.png'),
                        nrow=10)
    tv.utils.save_image(test_r * 0.5 + 0.5,
                        os.path.join(args.log_folder, 'test_reconstruct.png'),
                        nrow=10)


def _report_auc(args, y, scores):
    """Compute ROC-AUC of the novelty scores, print it, and write it to
    ``auc.txt`` in the log folder."""
    fpr, tpr, _ = sklearn.metrics.roc_curve(y, scores, pos_label=1)
    auc = sklearn.metrics.auc(fpr, tpr)
    print(auc)
    with open(os.path.join(args.log_folder, 'auc.txt'), 'w') as f:
        f.write('%.4f\n' % auc)


def _plot_embedding(args, emb, y):
    """t-SNE plot of up to 1000 feature embeddings, colored by label, saved
    to ``test_embedding.png``."""
    tsne = sklearn.manifold.TSNE(2)
    emb = tsne.fit_transform(emb[:1000])
    y = y[:1000]

    plt.clf()
    # NOTE(review): label 0 is plotted first and legended as 'novelty data'
    # while roc_curve uses pos_label=1 — kept as in the original; verify the
    # label convention against the data loader.
    handles = [plt.scatter(emb[y == i, 0], emb[y == i, 1]) for i in range(2)]
    plt.legend(handles, ['novelty data', 'train data'])
    plt.savefig(os.path.join(args.log_folder, 'test_embedding.png'))


def test(args, vae_checkpoint, classifier_checkpoint, test_loader):
    """Evaluate a trained VAE/AE for novelty detection on the test set.

    Loads the feature extractor from ``vae_checkpoint``, scores each test
    sample by its negative per-sample reconstruction MSE, reports ROC-AUC
    (printed and written to ``auc.txt``), saves image grids of originals and
    reconstructions, and plots a t-SNE embedding of the extracted features.

    Args:
        args: config namespace (uses use_gpu, dataset, feature_extractor,
            feature_size, log_folder).
        vae_checkpoint: path to the VAE/AE state-dict checkpoint.
        classifier_checkpoint: unused — classifier loading is disabled
            upstream; kept for interface compatibility.
        test_loader: project data loader exposing ``get_iter(shuffle=...)``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.use_gpu

    classifier, vae = _build_models(args)

    checkpoint = torch.load(vae_checkpoint,
                            map_location=lambda storage, loc: storage)
    vae.load_state_dict(checkpoint['vae_state_dict'])

    if torch.cuda.is_available():
        print('CUDA enabled.')  # fixed typo: was 'ensabled'
        classifier.cuda()
        vae.cuda()

    # Inference only: freeze all parameters and switch to eval mode.
    for p in classifier.parameters():
        p.requires_grad = False
    for p in vae.parameters():
        p.requires_grad = False
    classifier.eval()
    vae.eval()

    x = []          # per-sample novelty scores (negative reconstruction MSE)
    y = []          # per-sample labels
    emb = []        # extracted features, for the t-SNE plot
    data_list = []  # raw input batches
    r_list = []     # reconstructed batches
    label_list = []

    for data, label in test_loader.get_iter(shuffle=False):
        data = tensor2Var(data)

        feat = vae.get_features(data)
        emb.append(feat.squeeze().cpu().numpy())

        r, _, _ = vae(data)

        data_list.append(data.cpu().numpy())
        r_list.append(r.cpu().numpy())
        label_list.append(label.cpu().numpy())

        # Higher score = better reconstruction = more likely in-distribution.
        pred = -((r - data) ** 2).reshape(data.shape[0], -1).mean(1)
        x.append(pred.cpu().numpy())
        y.append(label.cpu().numpy())

    data_list = np.vstack(data_list)
    r_list = np.vstack(r_list)
    label_list = np.hstack(label_list)

    _save_image_grids(args, data_list, r_list, label_list)

    x = np.hstack(x)
    y = np.hstack(y)
    emb = np.vstack(emb)

    print(x)
    print(y)
    print((x < 0.9).sum())
    print(x.shape, y.shape)

    _report_auc(args, y, x)

    _plot_embedding(args, emb, y)