Пример #1
0
def scatter(X, **kwargs):
    """Scatter-plot 2-D points, optionally coloured by label.

    Args:
        X: array or ``torch.Tensor`` of points; columns 0 and 1 are used as
            the x and y coordinates.
        **kwargs: plotting options. The following keys are consumed here;
            everything else is forwarded to ``Axes.scatter``:
            ``labels`` -- per-point labels (array or tensor). When given,
                each unique label is drawn with its own colour.
            ``ax`` -- target Axes (default: ``plt.gca()``).
            ``no_ticks`` -- if true, hide the x/y ticks.
            ``color`` / ``edgecolor`` -- used when ``labels`` is None.
            ``colors`` / ``edgecolors`` -- per-label face/edge colours used
                when ``labels`` is given (default: evenly spaced
                ``cm.rainbow`` colours, edges at 0.6x brightness).

    Returns:
        ``(ulabels, colors)`` -- the sorted unique labels and the colours
        used for them -- when ``labels`` is given, otherwise ``None``.
    """
    if isinstance(X, torch.Tensor):
        X = to_numpy(X)
    labels = kwargs.pop('labels', None)
    if isinstance(labels, torch.Tensor):
        labels = to_numpy(labels)
    # Resolve the default lazily: using plt.gca() as the pop() default would
    # create a figure/axes as a side effect even when ``ax`` is supplied.
    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = plt.gca()
    if kwargs.pop('no_ticks', False):
        ax.set_xticks([])
        ax.set_yticks([])

    if labels is None:
        c = kwargs.pop('color', 'k')
        ec = kwargs.pop('edgecolor', [0.2, 0.2, 0.2])
        ax.scatter(X[:, 0], X[:, 1], color=c, edgecolor=ec, **kwargs)
        return None

    # np.unique already returns its result sorted.
    ulabels = np.unique(labels)
    colors = kwargs.pop('colors', None)
    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(ulabels)))
    edgecolors = kwargs.pop('edgecolors', None)
    if edgecolors is None:
        # asarray so a caller-supplied list of RGB(A) tuples can be scaled too.
        edgecolors = 0.6 * np.asarray(colors)
    for lab, c, ec in zip(ulabels, colors, edgecolors):
        sel = labels == lab
        ax.scatter(X[sel, 0], X[sel, 1], color=c, edgecolor=ec, **kwargs)
    return ulabels, colors
Пример #2
0
 def plot_clustering(self, X, results):
     """Scatter each set and overlay a density contour for every cluster.

     Args:
         X: batch of 2-D point sets, indexed as ``X[b]``.
         results: mapping with ``'labels'`` (per-set cluster labels,
             ``labels[b]``) and ``'theta'`` (flow context per set and
             cluster, ``theta[b, l]``).
     """
     labels = results.get('labels')
     theta = results.get('theta')
     B = X.shape[0]
     nx, ny = 50, 50
     # ceil(B / 2) columns so every set gets a panel (B // 2 dropped the last
     # set whenever B was odd). squeeze=False keeps ``axes`` a 2-D array for
     # every B, including B == 1.
     _, axes = plt.subplots(min(B, 2),
                            max((B + 1) // 2, 1),
                            figsize=(2.5 * B if B > 1 else 10, 10),
                            squeeze=False)
     axes = axes.flatten()
     for ax in axes[B:]:
         ax.axis('off')  # spare panel when B is odd
     # Density evaluation only; no autograd graph is needed.
     with torch.no_grad():
         for b in range(B):
             ax = axes[b]
             ulabels, _ = scatter(X[b], labels=labels[b], ax=ax)
             for l in ulabels:
                 Xbl = X[b][labels[b] == l]
                 Z, x, y = meshgrid_around(Xbl, nx, ny, margin=0.1)
                 ll = self.net.flow.log_prob(
                     Z, context=theta[b, l]).reshape(nx, ny)
                 ax.contour(to_numpy(x),
                            to_numpy(y),
                            to_numpy(ll.exp()),
                            zorder=10,
                            alpha=0.5)
             ax.set_xticks([])
             ax.set_yticks([])
Пример #3
0
    def plot_step(self, X):
        """Show in-cluster, generated and out-of-cluster images side by side.

        Only a batch containing a single set is supported; ``X[0]`` is the
        set that is visualised.

        Args:
            X: batch of sets whose elements are passed to ``make_grid``.

        Raises:
            ValueError: if the batch holds more than one set.
        """
        if X.shape[0] > 1:
            raise ValueError('No support for visualization when B > 1')

        def show(ax, images, title):
            # Tile the images, convert (C, H, W) -> (H, W, C) and draw them.
            grid = to_numpy(make_grid(images)).transpose(1, 2, 0)
            ax.imshow(grid)
            ax.axis('off')
            ax.set_title(title)

        net = self.net
        net.eval()
        with torch.no_grad():
            X = X.cuda()
            params, _, logits = net(X)
            membership = (logits > 0.0).squeeze()
            members = X[0]

            _, (ax_in, ax_gen, ax_out) = plt.subplots(1, 3, figsize=(21, 7))
            show(ax_in, members[membership == 1], 'In cluster')

            # Sample latents conditioned on the cluster parameters and decode.
            z, _ = net.prior.sample(num_samples=X.shape[1],
                                    context=params,
                                    device=X.device)
            hidden = net.decoder(torch.cat([z, params], -1))
            generated, _ = net.likel.sample(context=hidden)
            show(ax_gen, generated, 'Generated')

            show(ax_out, members[membership == 0], 'Not in cluster')

            plt.tight_layout()
    def evaluate_batch(self, output):
        """Score one batch of predictions with every configured metric.

        Args:
            output: mapping with ``"y_pred"`` and ``"y"`` entries; each is
                converted with ``tensor_utils.to_numpy`` before scoring.

        Returns:
            dict mapping each name in ``self.metrics`` to
            ``metric(y_pred, y)``.
        """
        predictions, targets = (tensor_utils.to_numpy(output[key])
                                for key in ("y_pred", "y"))
        return {
            name: score_fn(predictions, targets)
            for name, score_fn in self.metrics.items()
        }
Пример #5
0
    def cluster(self, X, max_iter=50, verbose=True, check=False):
        """Label points by repeatedly asking the network to claim a cluster.

        The first network call sees the whole input; each later call also
        receives the mask of points already claimed. Points newly marked
        (logit > 0) get the current round index as their label. Iteration
        stops after ``max_iter`` rounds in total or once every set is fully
        claimed.

        Args:
            X: batch of sets; ``X.shape[1]`` is the number of points per set.
            max_iter: upper bound on network calls, including the first.
            verbose: print the per-set count of claimed points every round.
            check: also return the failure flag.

        Returns:
            ``(None, labels, torch.zeros(1))``, with ``fail`` (true when some
            set still had unclaimed points) appended if ``check`` is set.
        """
        num_sets, set_size = X.shape[:2]
        net = self.net
        net.eval()

        with torch.no_grad():
            logits = net(X)
            claimed = logits > 0.0
            done = claimed.sum((1, 2)) == set_size
            labels = torch.zeros_like(logits, dtype=torch.int32).squeeze(-1)
            for round_idx in range(1, max_iter):
                newly = net(X, mask=claimed) > 0.0
                labels[(newly & ~claimed).squeeze(-1)] = round_idx
                claimed[newly] = True

                num_processed = claimed.sum((1, 2))
                done = num_processed == set_size
                if verbose:
                    print(to_numpy(num_processed))
                if done.sum() == num_sets:
                    break

        fail = done.sum() < num_sets
        result = (None, labels, torch.zeros(1))
        return result + (fail,) if check else result
Пример #6
0
Файл: mmaf.py Проект: mlzxy/dac
 def plot_step(self, X):
     """Scatter each set with its predicted membership and density contours.

     Args:
         X: batch of 2-D point sets, indexed as ``X[b]``.
     """
     B = X.shape[0]
     self.net.eval()
     nx, ny = 50, 50
     # Inference only: skip autograd for the forward pass and density grid.
     with torch.no_grad():
         params, logits = self.net(X)
         labels = (logits > 0.0).int().squeeze(-1)
         # ceil(B / 2) columns: B // 2 gave zero columns for B == 1 (subplots
         # raises) and dropped the last set for odd B. squeeze=False keeps
         # ``axes`` a 2-D array in every case.
         _, axes = plt.subplots(min(B, 2),
                                max((B + 1) // 2, 1),
                                figsize=(7 * B / 5, 5),
                                squeeze=False)
         axes = axes.flatten()
         for ax in axes[B:]:
             ax.axis('off')  # spare panel when B is odd
         for b in range(B):
             ax = axes[b]
             scatter(X[b], labels=labels[b], ax=ax)
             Z, x, y = meshgrid_around(X[b], nx, ny, margin=0.1)
             ll = self.net.flow.log_prob(Z, context=params[b]).reshape(nx, ny)
             ax.contour(to_numpy(x),
                        to_numpy(y),
                        to_numpy(ll.exp()),
                        zorder=10,
                        alpha=0.3)
Пример #7
0
def draw_ellipse(pos, cov, ax=None, **kwargs):
    """Draw nested 1- to 5-sigma ellipses for a 2-D Gaussian.

    Args:
        pos: centre of the ellipses (ndarray or tensor, length 2).
        cov: 2x2 covariance matrix (ndarray or tensor).
        ax: target Axes; defaults to ``plt.gca()``.
        **kwargs: forwarded to ``matplotlib.patches.Ellipse``.
    """
    if not isinstance(pos, np.ndarray):
        pos = to_numpy(pos)
    if not isinstance(cov, np.ndarray):
        cov = to_numpy(cov)
    if ax is None:
        ax = plt.gca()
    # Principal axes of the covariance: the first column of U gives the
    # orientation, the singular values the variances along each axis.
    U, s, _ = np.linalg.svd(cov)
    angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
    width, height = 2 * np.sqrt(s)
    for nsig in range(1, 6):
        ax.add_patch(
            Ellipse(pos,
                    nsig * width,
                    nsig * height,
                    # keyword-only in recent Matplotlib; positional use was
                    # deprecated in 3.6.
                    angle=angle,
                    alpha=0.5 / nsig,
                    **kwargs))
Пример #8
0
 def plot_filtering(self, batch):
     """Scatter each set with its filtering labels and density contours.

     Args:
         batch: mapping with ``'X'`` (batch of 2-D point sets, ``X[b]``);
             the whole mapping is passed to ``self.compute_loss``.
     """
     X = batch['X']
     B = X.shape[0]
     nx, ny = 50, 50
     # Keep the density evaluation under no_grad too, not only the loss.
     with torch.no_grad():
         outputs = self.compute_loss(batch, train=False)
         labels = (outputs['logits'] > 0.0).long()
         theta = outputs['theta']
         # ceil(B / 2) columns so an odd B does not drop the last set;
         # squeeze=False keeps ``axes`` a 2-D array for every B.
         _, axes = plt.subplots(min(B, 2),
                                max((B + 1) // 2, 1),
                                figsize=(2.5 * B if B > 1 else 10, 10),
                                squeeze=False)
         axes = axes.flatten()
         for ax in axes[B:]:
             ax.axis('off')  # spare panel when B is odd
         for b in range(B):
             ax = axes[b]
             scatter(X[b], labels=labels[b], ax=ax)
             Z, x, y = meshgrid_around(X[b], nx, ny, margin=0.1)
             ll = self.net.flow.log_prob(Z, context=theta[b]).reshape(nx, ny)
             ax.contour(to_numpy(x),
                        to_numpy(y),
                        to_numpy(ll.exp()),
                        zorder=10,
                        alpha=0.3)
Пример #9
0
Файл: mmaf.py Проект: mlzxy/dac
 def plot_clustering(self, X, params, labels):
     """Scatter each set by cluster and overlay per-cluster density contours.

     Args:
         X: batch of 2-D point sets, indexed as ``X[b]``.
         params: flow context per cluster and set, indexed ``params[l][b]``.
         labels: per-set cluster labels, indexed ``labels[b]``.

     With more than one set a new figure with two rows of panels is created;
     a single set is drawn into the current axes (``plt.gca()``).
     """
     B = X.shape[0]
     nx, ny = 50, 50
     if B > 1:
         # ceil(B / 2) columns so an odd B does not drop the last set.
         _, axes = plt.subplots(2, (B + 1) // 2,
                                figsize=(2.5 * B, 10),
                                squeeze=False)
         axes = axes.flatten()
         for ax in axes[B:]:
             ax.axis('off')  # spare panel when B is odd
     else:
         axes = [plt.gca()]
     # Density evaluation only; no autograd graph is needed.
     with torch.no_grad():
         for b in range(B):
             ax = axes[b]
             ulabels, _ = scatter(X[b], labels=labels[b], ax=ax)
             for l in ulabels:
                 Xbl = X[b][labels[b] == l]
                 Z, x, y = meshgrid_around(Xbl, nx, ny, margin=0.1)
                 ll = self.net.flow.log_prob(
                     Z, context=params[l][b]).reshape(nx, ny)
                 ax.contour(to_numpy(x),
                            to_numpy(y),
                            to_numpy(ll.exp()),
                            zorder=10,
                            alpha=0.3)
    def evaluate(self):
        """Run the model over ``self.data_loader`` and report metric values.

        Each batch is moved to ``self.device`` and the predictions are fed to
        every metric in ``self.metrics`` (return values ignored). The final
        values are read back with ``metric.current_value()``.

        Returns:
            dict mapping each metric name to ``metric.current_value()``.
        """
        with torch.no_grad():
            self.model = self.model.to(self.device)
            device = self.device
            for ids_a, mask_a, ids_b, mask_b, target in self.data_loader:
                model_inputs = [t.to(device) for t in (ids_a, mask_a, ids_b, mask_b)]
                target = target.to(device)

                prediction = self.model(*model_inputs)

                y_true = tensor_utils.to_numpy(target)
                y_hat = tensor_utils.to_numpy(prediction)
                for metric in self.metrics.values():
                    metric(y_hat, y_true)

            return {name: metric.current_value()
                    for name, metric in self.metrics.items()}
Пример #11
0
 def plot_clustering(self, X, results):
     """Show the images of every cluster of set 0 as 4-column grids.

     Args:
         X: batch passed through ``self.combine_digits``; only the first
             set of the result is shown.
         results: mapping whose ``'labels'`` entry holds per-set labels.
     """
     X = self.combine_digits(X)[0]
     labels = results['labels'][0]
     ulabels = torch.unique(labels)
     K = len(ulabels)
     # squeeze=False: with a single cluster plt.subplots would otherwise
     # return a bare Axes and indexing it would fail.
     _, axes = plt.subplots(1, K, figsize=(50, 50), squeeze=False)
     for k, (ax, l) in enumerate(zip(axes[0], ulabels)):
         Xk = X[labels == l]
         n = Xk.shape[0]
         # Trim to whole rows of 4 for a tidy grid, but never down to zero
         # images: make_grid cannot tile an empty batch.
         if n >= 4:
             Xk = Xk[:n - n % 4]
         I = to_numpy(make_grid(1 - Xk, nrow=4,
                                pad_value=0)).transpose(1, 2, 0)
         ax.set_title('cluster {}'.format(k + 1), fontsize=100)
         ax.imshow(I)
         ax.axis('off')
     plt.tight_layout()
Пример #12
0
    def plot_clustering(self, X, params, labels):
        """Show the images of each cluster of a single set, one row each.

        Args:
            X: batch holding exactly one set; its elements go to
                ``make_grid``.
            params: unused here; kept for interface compatibility.
            labels: per-set cluster labels, indexed ``labels[0]``.

        Raises:
            ValueError: if the batch holds more than one set.
        """
        if X.shape[0] > 1:
            raise ValueError('No support for visualization when B > 1')

        X = X[0]
        labels = labels[0]
        unique_labels = torch.unique(labels)

        # squeeze=False keeps ``axes`` indexable even with a single cluster.
        _, axes = plt.subplots(len(unique_labels), 1, figsize=(20, 20),
                               squeeze=False)
        for ax, l in zip(axes[:, 0], unique_labels):
            img = make_grid(X[labels == l], nrow=10)
            ax.imshow(to_numpy(img).transpose(1, 2, 0))
            ax.axis('off')
        plt.tight_layout()
Пример #13
0
Файл: base.py Проект: mlzxy/dac
    def cluster(self, X, max_iter=50, verbose=True, check=False):
        """Cluster by iterative claiming and score the resulting mixture.

        Each network call returns ``(params, ll, logits)``. The first call
        sees the whole input; later calls also receive the mask of claimed
        points. Newly claimed points (logit > 0) get the round index as their
        label. Per-round ``ll`` outputs are concatenated along the last axis
        and combined with maximum-likelihood mixing proportions into a single
        mean log-likelihood.

        Args:
            X: batch of sets; ``X.shape[1]`` is the number of points per set.
            max_iter: upper bound on network calls, including the first.
            verbose: print the per-set count of claimed points every round.
            check: also return the failure flag.

        Returns:
            ``(params, labels, ll)`` where ``params`` lists the per-round
            network params, with ``fail`` appended if ``check`` is set.
        """
        num_sets, set_size = X.shape[:2]
        net = self.net
        net.eval()

        with torch.no_grad():
            first_params, first_ll, logits = net(X)
            params = [first_params]
            ll_parts = [first_ll]

            labels = torch.zeros_like(logits, dtype=torch.int32).squeeze(-1)
            claimed = logits > 0.0
            done = claimed.sum((1, 2)) == set_size
            for round_idx in range(1, max_iter):
                round_params, round_ll, logits = net(X, mask=claimed)
                ll_parts.append(round_ll)
                params.append(round_params)

                newly = logits > 0.0
                labels[(newly & ~claimed).squeeze(-1)] = round_idx
                claimed[newly] = True

                num_processed = claimed.sum((1, 2))
                done = num_processed == set_size
                if verbose:
                    print(to_numpy(num_processed))
                if done.sum() == num_sets:
                    break

            fail = done.sum() < num_sets

            ll = torch.cat(ll_parts, -1)
            # Maximum-likelihood mixing proportions: fraction of each set's
            # points assigned to each cluster.
            assignment = F.one_hot(labels.long(), len(params)).float()
            pi = assignment.sum(1, keepdim=True) / assignment.shape[1]
            mixture_ll = (ll + (pi + 1e-10).log()).logsumexp(-1).mean()

            result = (params, labels, mixture_ll)
            return result + (fail,) if check else result
Пример #14
0
    def cluster(self, X, max_iter=50, verbose=True):
        """Cluster by anchor-driven iterative claiming.

        Every round samples anchors with ``sample_anchors`` (restricted to
        unclaimed points after the first round), runs the network, and gives
        the newly claimed points (logit > 0) the round index as their label.
        Stops after ``max_iter`` rounds in total or once every set is fully
        claimed.

        Args:
            X: batch of sets; ``X.shape[1]`` is the number of points per set.
            max_iter: upper bound on network calls, including the first.
            verbose: print the per-set count of claimed points every round.

        Returns:
            ``{'labels': labels}`` when the network reports no ``'ll'``;
            otherwise also the stacked ``'theta'`` and the mean mixture
            log-likelihood ``'ll'`` under ML mixing proportions.
        """
        num_sets, set_size = X.shape[:2]
        net = self.net
        net.eval()
        with torch.no_grad():
            out = net(X, sample_anchors(num_sets, set_size))
            thetas = [out.get('theta', None)]
            lls = [out.get('ll', None)]
            labels = torch.zeros_like(out['logits'], dtype=torch.long)
            claimed = out['logits'] > 0.0
            for round_idx in range(1, max_iter):
                anchors = sample_anchors(num_sets, set_size, mask=claimed)
                out = net(X, anchors, mask=claimed)
                thetas.append(out.get('theta', None))
                lls.append(out.get('ll', None))

                newly = out['logits'] > 0.0
                labels[newly & ~claimed] = round_idx
                claimed[newly] = True

                num_processed = claimed.sum(-1)
                if verbose:
                    print(to_numpy(num_processed))
                if (num_processed == set_size).sum() == num_sets:
                    break

            if lls[0] is None:
                return {'labels': labels}

            ll = torch.stack(lls, -1)
            theta = torch.stack(thetas, -2)

            # ML mixing proportions: fraction of points in each cluster.
            assignment = F.one_hot(labels, ll.shape[-1]).float()
            pi = assignment.sum(1, keepdim=True) / assignment.shape[1]
            ll = (ll + (pi + 1e-10).log()).logsumexp(-1).mean()

            return {'theta': theta, 'll': ll, 'labels': labels}
Пример #15
0
from utils.plots import scatter, scatter_mog
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument('--benchmarkfile', type=str, default='mog_10_1000_4.tar')
parser.add_argument('--filename', type=str, default='test.log')

args, _ = parser.parse_known_args()
print(str(args))

# Spectral-clustering baseline: cluster every set with the true number of
# clusters and accumulate ARI, NMI and the clustering time ('et').
benchmark = torch.load(os.path.join(benchmarks_path, args.benchmarkfile))
accm = Accumulator('ari', 'nmi', 'et')
for batch in tqdm(benchmark):
    B = batch['X'].shape[0]
    for b in range(B):
        X = to_numpy(batch['X'][b])
        true_labels = to_numpy(batch['labels'][b].argmax(-1))
        true_K = len(np.unique(true_labels))

        # Stop the clock right after fitting so 'et' measures clustering
        # only, not the ARI/NMI evaluation. perf_counter is monotonic and
        # high-resolution, unlike time.time.
        tick = time.perf_counter()
        spec = SpectralClustering(n_clusters=true_K,
                                  affinity='nearest_neighbors',
                                  n_neighbors=10).fit(X)
        elapsed = time.perf_counter() - tick
        labels = spec.labels_

        accm.update([
            ARI(true_labels, labels),
            NMI(true_labels, labels, average_method='arithmetic'),
            elapsed,
        ])
Пример #16
0
import time

parser = argparse.ArgumentParser()
parser.add_argument('--benchmarkfile', type=str, default='mog_10_1000_4.tar')
parser.add_argument('--k_max', type=int, default=6)
parser.add_argument('--filename', type=str, default='test.log')

args, _ = parser.parse_known_args()
print(str(args))

benchmark = torch.load(os.path.join(benchmarks_path, args.benchmarkfile))

# VBMOG baseline, constructed with args.k_max.
vbmog = VBMOG(args.k_max)
accm = Accumulator('model ll', 'oracle ll', 'ARI', 'NMI', 'k-MAE', 'et')
for dataset in tqdm(benchmark):
    true_labels = to_numpy(dataset['labels'].argmax(-1))
    X = to_numpy(dataset['X'])
    # Per-batch sums over sets; presumably averaged and pushed into accm
    # after this loop (not visible here).
    ll = 0
    ari = 0
    nmi = 0
    mae = 0
    et = 0
    for b in range(len(X)):
        # perf_counter is monotonic and high-resolution, so it is the right
        # clock for elapsed time (time.time follows wall-clock adjustments).
        tick = time.perf_counter()
        vbmog.run(X[b], verbose=False)
        et += time.perf_counter() - tick
        ll += vbmog.loglikel(X[b])
        labels = vbmog.labels()
        ari += ARI(true_labels[b], labels)
        nmi += NMI(true_labels[b], labels, average_method='arithmetic')
        mae += abs(len(np.unique(true_labels[b])) - len(np.unique(labels)))
Пример #17
0
    def plot_filtering(self, batch):
        """Visualise filtering, conditional generation and reconstruction.

        Produces three figures for set 0 of the batch:
        1. images labelled 1 ("Filtered out") next to those labelled 0
           ("Remaining");
        2. up to 32 images generated from the prior conditioned on theta;
        3. reconstructions (decoded from the network's ``z``), split by
           label as in figure 1.

        Args:
            batch: mapping with ``'X'`` of shape (B, N, C, H, W).
        """
        X = batch['X'].cuda()
        B, N, C, H, W = X.shape
        net = self.net
        net.eval()
        with torch.no_grad():
            outputs = net(X, return_z=True)
            theta = outputs['theta']
            # One copy of each set's theta per element, flattened to B * N rows.
            theta_ = theta.repeat(1, N, 1).view(B * N, -1)
            labels = (outputs['logits'] > 0.0).long()

            # Conditional generation: sample z from the prior given theta.
            z, _ = net.prior.sample(B * N, device=X.device, context=theta_)
            h_dec = net.dec(torch.cat([z, theta_], -1))
            gX, _ = net.likel.sample(context=h_dec)
            gX = gX.view(B, N, C, H, W)

            # Reconstruction: decode the network's own z for each input.
            h_dec = net.dec(torch.cat([outputs['z'], theta_], -1))
            rX, _ = net.likel.sample(context=h_dec)
            rX = rX.view(B, N, C, H, W)

        def tile(images, trim=True):
            """Return ``1 - images`` tiled 8 per row as an (H, W, C) array."""
            n = images.shape[0]
            # Trim to whole rows for a tidy grid, but never down to zero
            # images: make_grid cannot tile an empty batch.
            if trim and n >= 8:
                images = images[:n - n % 8]
            grid = make_grid(1 - images, nrow=8, pad_value=0)
            return to_numpy(grid).transpose(1, 2, 0)

        def split_figure(images, set_labels, titles):
            """Two panels: images labelled 1 (left) and labelled 0 (right)."""
            _, axes = plt.subplots(1, 2, figsize=(40, 40))
            for ax, value, title in zip(axes, (1, 0), titles):
                selected = images[set_labels == value]
                # An empty selection leaves the panel blank instead of crashing.
                if selected.shape[0] > 0:
                    ax.imshow(tile(selected))
                ax.set_title(title, fontsize=60, pad=20)
                ax.axis('off')
            plt.tight_layout()

        labels0 = labels[0]
        split_figure(self.combine_digits(X)[0], labels0,
                     ('Filtered out images', 'Remaining images'))

        plt.figure()
        plt.imshow(tile(self.combine_digits(gX)[0][:32], trim=False))
        plt.title('Generated images', fontsize=15, pad=5)
        plt.axis('off')

        split_figure(self.combine_digits(rX)[0], labels0,
                     ('Reconstructions of filtered out images',
                      'Reconstructions of remaining images'))
Пример #18
0
from data.kkanji import KKanji
from data.clustered_dataset import get_saved_cluster_loader
import torchvision.transforms as tvt

# Classical clustering baseline (KMeans by default) on the saved KKanji
# clustering benchmark, given the true number of clusters per set.
# NOTE(review): this snippet is truncated -- the final accm.update( call is
# cut off, so what is accumulated cannot be confirmed from here.
# NOTE(review): mean/std presumably are the KKanji pixel statistics -- confirm.
transform = tvt.Normalize(mean=[0.2170], std=[0.3787])
dataset = KKanji(os.path.join(datasets_path, 'kkanji'),
                 train=False,
                 transform=transform)
filename = os.path.join(benchmarks_path, 'kkanji_10_300_12.tar')
# Only classes 700..812 are used (presumably held out from training -- confirm).
loader = get_saved_cluster_loader(dataset, filename, classes=range(700, 813))
accm = Accumulator('ari', 'k-mae')

for batch in tqdm(loader):
    B = batch['X'].shape[0]
    for b in range(B):
        # Flatten each element to 784 features (assumes 28x28 single-channel
        # images -- TODO confirm).
        X = to_numpy(batch['X'][b]).reshape(-1, 784)
        true_labels = to_numpy(batch['labels'][b].argmax(-1))
        true_K = len(np.unique(true_labels))

        # KMeans (no random_state given, so results may vary between runs)
        kmeans = KMeans(n_clusters=true_K).fit(X)
        labels = kmeans.labels_

        # Spectral (disabled alternative baseline)
        #spec = SpectralClustering(n_clusters=true_K).fit(X)
        #labels = spec.labels_

        # Gaussian mixture (disabled alternative baseline)
        #gmm = GaussianMixture(n_components=true_K).fit(X)
        #labels = gmm.predict(X)

        accm.update(
Пример #19
0
# Evaluate the fully trained DAC checkpoint on the test clustering benchmark.
# NOTE(review): net, model, save_dir, args and module_name are defined before
# this point (not visible here).
net.load_state_dict(
    torch.load(os.path.join(save_dir, 'originalDAC_fullytrained.tar')))
net.eval()
test_loader = model.get_test_loader(filename=model.clusterfile)
accm = Accumulator('model ll', 'oracle ll', 'ARI', 'NMI', 'k-MAE')
# Number of batches for which model.cluster reported a failure.
num_failure = 0
logger = get_logger('{}_{}'.format(module_name, args.run_name),
                    os.path.join(save_dir, args.filename))
# NOTE(review): these two lists are not used in the visible code.
all_correct_counts = []
all_distances = []
for batch in tqdm(test_loader):
    # check=True makes model.cluster also return the failure flag.
    params, labels, ll, fail = model.cluster(batch['X'].cuda(),
                                             max_iter=args.max_iter,
                                             verbose=False,
                                             check=True)
    # argmax over the last axis turns batch['labels'] (presumably one-hot)
    # into integer cluster ids.
    true_labels = to_numpy(batch['labels'].argmax(-1))
    ari = 0
    nmi = 0
    mae = 0
    # Score each set, then average over the sets in the batch.
    for b in range(len(labels)):
        labels_b = to_numpy(labels[b])
        ari += ARI(true_labels[b], labels_b)
        nmi += NMI(true_labels[b], labels_b, average_method='arithmetic')
        # k-MAE: absolute error in the number of distinct clusters.
        mae += abs(len(np.unique(true_labels[b])) - len(np.unique(labels_b)))
    ari /= len(labels)
    nmi /= len(labels)
    mae /= len(labels)

    # Oracle log-likelihood is reported as 0.0 when the batch provides none.
    oracle_ll = 0.0 if batch.get('ll') is None else batch['ll']
    accm.update([ll.item(), oracle_ll, ari, nmi, mae])
    num_failure += int(fail)