def scatter(X, **kwargs):
    if type(X) == torch.Tensor:
        X = to_numpy(X)
    labels = kwargs.pop('labels', None)
    if labels is not None and type(labels) == torch.Tensor:
        labels = to_numpy(labels)
    ax = kwargs.pop('ax', plt.gca())
    if kwargs.pop('no_ticks', False):
        ax.set_xticks([])
        ax.set_yticks([])
    if labels is None:
        c = kwargs.pop('color', 'k')
        ec = kwargs.pop('edgecolor', [0.2, 0.2, 0.2])
        ax.scatter(X[:, 0], X[:, 1], color=c, edgecolor=ec, **kwargs)
    else:
        ulabels = np.sort(np.unique(labels))
        colors = kwargs.pop('colors',
                            cm.rainbow(np.linspace(0, 1, len(ulabels))))
        edgecolors = kwargs.pop('edgecolors', 0.6 * colors)
        for (l, c, ec) in zip(ulabels, colors, edgecolors):
            ax.scatter(X[labels == l, 0], X[labels == l, 1],
                       color=c, edgecolor=ec, **kwargs)
        # only the labelled branch has per-cluster colors to hand back
        return ulabels, colors

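# A minimal usage sketch for scatter(); the toy data, figure size, and the
# helper name _scatter_demo are illustrative assumptions, not part of the
# original module.
def _scatter_demo():
    import numpy as np
    import matplotlib.pyplot as plt
    rng = np.random.RandomState(0)
    X_demo = np.concatenate([rng.randn(50, 2), rng.randn(50, 2) + 4.0])
    labels_demo = np.repeat([0, 1], 50)
    plt.figure(figsize=(4, 4))
    ulabels, colors = scatter(X_demo, labels=labels_demo, no_ticks=True)
    plt.show()
    return ulabels, colors
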
def plot_clustering(self, X, results):
    labels = results.get('labels')
    theta = results.get('theta')
    B = X.shape[0]
    K = theta.shape[1]
    nx, ny = 50, 50
    fig, axes = plt.subplots(min(B, 2), max(B // 2, 1),
                             figsize=(2.5 * B if B > 1 else 10, 10))
    axes = [axes] if B == 1 else axes.flatten()
    for b, ax in enumerate(axes):
        ulabels, colors = scatter(X[b], labels=labels[b], ax=ax)
        for l, c in zip(ulabels, colors):
            Xbl = X[b][labels[b] == l]
            Z, x, y = meshgrid_around(Xbl, nx, ny, margin=0.1)
            ll = self.net.flow.log_prob(Z, context=theta[b, l]).reshape(nx, ny)
            ax.contour(to_numpy(x), to_numpy(y), to_numpy(ll.exp()),
                       zorder=10, alpha=0.5)
        ax.set_xticks([])
        ax.set_yticks([])

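# Hedged sketch of the meshgrid_around() helper used by the plotting code in
# this section. Its actual implementation lives elsewhere in the project; this
# guess is reconstructed from the call sites only: a flat (nx*ny, 2) grid of
# query points covering the bounding box of X expanded by a relative margin,
# plus the 2-D coordinate grids handed to contour().
def meshgrid_around_sketch(X, nx, ny, margin=0.1):
    import torch
    x_min, x_max = X[:, 0].min().item(), X[:, 0].max().item()
    y_min, y_max = X[:, 1].min().item(), X[:, 1].max().item()
    dx, dy = margin * (x_max - x_min), margin * (y_max - y_min)
    xs = torch.linspace(x_min - dx, x_max + dx, nx, device=X.device)
    ys = torch.linspace(y_min - dy, y_max + dy, ny, device=X.device)
    x, y = torch.meshgrid(xs, ys, indexing='ij')
    Z = torch.stack([x.reshape(-1), y.reshape(-1)], -1)
    return Z, x, y
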
def plot_step(self, X):
    if X.shape[0] > 1:
        raise ValueError('No support for visualization when B > 1')
    self.net.eval()
    with torch.no_grad():
        X = X.cuda()
        params, ll, logits = self.net(X)
        labels = (logits > 0.0).squeeze()
        fig, axes = plt.subplots(1, 3, figsize=(21, 7))
        img = make_grid(X[0][labels == 1])
        axes[0].imshow(to_numpy(img).transpose(1, 2, 0))
        axes[0].axis('off')
        axes[0].set_title('In cluster')
        # conditional generation from the inferred cluster parameters
        z, _ = self.net.prior.sample(num_samples=X.shape[1],
                                     context=params, device=X.device)
        h = self.net.decoder(torch.cat([z, params], -1))
        x_gen, _ = self.net.likel.sample(context=h)
        img = make_grid(x_gen)
        axes[1].imshow(to_numpy(img).transpose(1, 2, 0))
        axes[1].axis('off')
        axes[1].set_title('Generated')
        img = make_grid(X[0][labels == 0])
        axes[2].imshow(to_numpy(img).transpose(1, 2, 0))
        axes[2].axis('off')
        axes[2].set_title('Not in cluster')
        plt.tight_layout()

def evaluate_batch(self, output):
    y_pred = tensor_utils.to_numpy(output["y_pred"])
    y = tensor_utils.to_numpy(output["y"])
    metric_values = {}
    for name, metric in self.metrics.items():
        value = metric(y_pred, y)
        metric_values[name] = value
    return metric_values

def cluster(self, X, max_iter=50, verbose=True, check=False):
    B, N = X.shape[0], X.shape[1]
    self.net.eval()
    with torch.no_grad():
        logits = self.net(X)
        mask = (logits > 0.0)
        done = mask.sum((1, 2)) == N
        labels = torch.zeros_like(logits).squeeze(-1).int()
        for i in range(1, max_iter):
            # extract the next cluster from the points not yet assigned
            logits = self.net(X, mask=mask)
            ind = logits > 0.0
            labels[(ind * mask.bitwise_not()).squeeze(-1)] = i
            mask[ind] = True
            num_processed = mask.sum((1, 2))
            done = num_processed == N
            if verbose:
                print(to_numpy(num_processed))
            if done.sum() == B:
                break
    fail = done.sum() < B
    if check:
        return None, labels, torch.zeros(1), fail
    else:
        return None, labels, torch.zeros(1)

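# A self-contained toy illustration (made-up tensors, hypothetical helper
# name) of the label/mask bookkeeping inside cluster(): points whose logits
# first cross 0 at iteration i receive label i and are then masked out so
# later iterations ignore them.
def _mask_update_demo():
    import torch
    mask = torch.tensor([[[True], [True], [False], [False], [False], [False]]])
    labels = torch.zeros(1, 6, dtype=torch.int32)
    logits = torch.tensor([[[1.0], [2.0], [0.5], [-1.0], [3.0], [-2.0]]])
    i = 1
    ind = logits > 0.0
    labels[(ind * mask.bitwise_not()).squeeze(-1)] = i  # newly claimed points
    mask[ind] = True                                    # now marked processed
    return labels  # tensor([[0, 0, 1, 0, 1, 0]], dtype=torch.int32)
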
def plot_step(self, X):
    B = X.shape[0]
    self.net.eval()
    params, logits = self.net(X)
    labels = (logits > 0.0).int().squeeze(-1)
    nx, ny = 50, 50
    fig, axes = plt.subplots(2, B // 2, figsize=(7 * B / 5, 5))
    for b, ax in enumerate(axes.flatten()):
        scatter(X[b], labels=labels[b], ax=ax)
        Z, x, y = meshgrid_around(X[b], nx, ny, margin=0.1)
        ll = self.net.flow.log_prob(Z, context=params[b]).reshape(nx, ny)
        ax.contour(to_numpy(x), to_numpy(y), to_numpy(ll.exp()),
                   zorder=10, alpha=0.3)

def draw_ellipse(pos, cov, ax=None, **kwargs):
    if type(pos) != np.ndarray:
        pos = to_numpy(pos)
    if type(cov) != np.ndarray:
        cov = to_numpy(cov)
    ax = ax or plt.gca()
    U, s, Vt = np.linalg.svd(cov)
    angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
    width, height = 2 * np.sqrt(s)
    # nested 1..5 sigma ellipses with decreasing opacity
    for nsig in range(1, 6):
        ax.add_patch(Ellipse(pos, nsig * width, nsig * height, angle=angle,
                             alpha=0.5 / nsig, **kwargs))

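# Minimal usage sketch for draw_ellipse(); the mean/covariance values, color,
# axis limits, and the helper name _draw_ellipse_demo are illustrative
# assumptions.
def _draw_ellipse_demo():
    import numpy as np
    import matplotlib.pyplot as plt
    mu = np.array([0.0, 0.0])
    Sigma = np.array([[2.0, 0.6], [0.6, 1.0]])
    fig, ax = plt.subplots(figsize=(4, 4))
    draw_ellipse(mu, Sigma, ax=ax, color='C0')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    plt.show()
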
def plot_filtering(self, batch):
    X = batch['X']
    B, N = X.shape[0], X.shape[1]
    with torch.no_grad():
        outputs = self.compute_loss(batch, train=False)
        labels = (outputs['logits'] > 0.0).long()
        theta = outputs['theta']
        nx, ny = 50, 50
        fig, axes = plt.subplots(min(B, 2), max(B // 2, 1),
                                 figsize=(2.5 * B if B > 1 else 10, 10))
        axes = [axes] if B == 1 else axes.flatten()
        for b, ax in enumerate(axes):
            scatter(X[b], labels=labels[b], ax=ax)
            Z, x, y = meshgrid_around(X[b], nx, ny, margin=0.1)
            ll = self.net.flow.log_prob(Z, context=theta[b]).reshape(nx, ny)
            ax.contour(to_numpy(x), to_numpy(y), to_numpy(ll.exp()),
                       zorder=10, alpha=0.3)

def plot_clustering(self, X, params, labels):
    B = X.shape[0]
    K = len(params)
    nx, ny = 50, 50
    if B > 1:
        fig, axes = plt.subplots(2, B // 2, figsize=(2.5 * B, 10))
        for b, ax in enumerate(axes.flatten()):
            ulabels, colors = scatter(X[b], labels=labels[b], ax=ax)
            for l, c in zip(ulabels, colors):
                Xbl = X[b][labels[b] == l]
                Z, x, y = meshgrid_around(Xbl, nx, ny, margin=0.1)
                ll = self.net.flow.log_prob(Z, context=params[l][b]).reshape(
                    nx, ny)
                ax.contour(to_numpy(x), to_numpy(y), to_numpy(ll.exp()),
                           zorder=10, alpha=0.3)
    else:
        ulabels, colors = scatter(X[0], labels=labels[0])
        for l, c in zip(ulabels, colors):
            Xbl = X[0][labels[0] == l]
            Z, x, y = meshgrid_around(Xbl, nx, ny, margin=0.1)
            ll = self.net.flow.log_prob(Z, context=params[l][0]).reshape(
                nx, ny)
            plt.contour(to_numpy(x), to_numpy(y), to_numpy(ll.exp()),
                        zorder=10, alpha=0.3)

def evaluate(self):
    with torch.no_grad():
        self.model = self.model.to(self.device)
        for (first_input_ids, first_input_mask,
             second_input_ids, second_input_mask, y) in self.data_loader:
            first_input_ids = first_input_ids.to(self.device)
            first_input_mask = first_input_mask.to(self.device)
            second_input_ids = second_input_ids.to(self.device)
            second_input_mask = second_input_mask.to(self.device)
            y = y.to(self.device)
            y_pred = self.model(first_input_ids, first_input_mask,
                                second_input_ids, second_input_mask)
            y = tensor_utils.to_numpy(y)
            y_pred = tensor_utils.to_numpy(y_pred)
            for name, metric in self.metrics.items():
                metric(y_pred, y)
    eval_metric_values = {
        name: metric.current_value()
        for name, metric in self.metrics.items()
    }
    return eval_metric_values

def plot_clustering(self, X, results):
    X = self.combine_digits(X)[0]
    labels = results['labels'][0]
    ulabels = torch.unique(labels)
    K = len(ulabels)
    fig, axes = plt.subplots(1, K, figsize=(50, 50))
    for k, l in enumerate(ulabels):
        Xk = X[labels == l]
        # drop the remainder so the grid width of 4 divides the image count
        Xk = Xk[:Xk.shape[0] - Xk.shape[0] % 4]
        I = to_numpy(make_grid(1 - Xk, nrow=4, pad_value=0)).transpose(1, 2, 0)
        axes[k].set_title('cluster {}'.format(k + 1), fontsize=100)
        axes[k].imshow(I)
        axes[k].axis('off')
    plt.tight_layout()

def plot_clustering(self, X, params, labels):
    if X.shape[0] > 1:
        raise ValueError('No support for visualization when B > 1')
    X = X[0]
    labels = labels[0]
    unique_labels = torch.unique(labels)
    fig, axes = plt.subplots(len(unique_labels), 1, figsize=(20, 20))
    for i, l in enumerate(unique_labels):
        Xl = X[labels == l]
        img = make_grid(Xl, nrow=10)
        axes[i].imshow(to_numpy(img).transpose(1, 2, 0))
        axes[i].axis('off')
    plt.tight_layout()

def cluster(self, X, max_iter=50, verbose=True, check=False):
    B, N = X.shape[0], X.shape[1]
    self.net.eval()
    with torch.no_grad():
        params, ll, logits = self.net(X)
        params = [params]
        labels = torch.zeros_like(logits).squeeze(-1).int()
        mask = (logits > 0.0)
        done = mask.sum((1, 2)) == N
        for i in range(1, max_iter):
            params_, ll_, logits = self.net(X, mask=mask)
            ll = torch.cat([ll, ll_], -1)
            params.append(params_)
            ind = logits > 0.0
            labels[(ind * mask.bitwise_not()).squeeze(-1)] = i
            mask[ind] = True
            num_processed = mask.sum((1, 2))
            done = num_processed == N
            if verbose:
                print(to_numpy(num_processed))
            if done.sum() == B:
                break
        fail = done.sum() < B

        # ML estimate of mixing proportion pi
        pi = F.one_hot(labels.long(), len(params)).float()
        pi = pi.sum(1, keepdim=True) / pi.shape[1]
        ll = ll + (pi + 1e-10).log()
        ll = ll.logsumexp(-1).mean()

    if check:
        return params, labels, ll, fail
    else:
        return params, labels, ll

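# Standalone sketch (toy shapes, random values, hypothetical helper name) of
# the mixture-likelihood bookkeeping at the end of cluster(): given per-cluster
# log-likelihoods ll of shape (B, N, K) and hard labels of shape (B, N), the
# mixing proportions pi are estimated by counting and combined with ll through
# logsumexp.
def _mixture_ll_demo():
    import torch
    import torch.nn.functional as F
    B, N, K = 2, 5, 3
    ll = torch.randn(B, N, K)                   # log p(x_n | theta_k)
    labels = torch.randint(0, K, (B, N))        # hard cluster assignments
    pi = F.one_hot(labels, K).float()
    pi = pi.sum(1, keepdim=True) / pi.shape[1]  # (B, 1, K) ML estimate of pi
    return (ll + (pi + 1e-10).log()).logsumexp(-1).mean()
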
def cluster(self, X, max_iter=50, verbose=True):
    B, N = X.shape[0], X.shape[1]
    self.net.eval()
    with torch.no_grad():
        anc_idxs = sample_anchors(B, N)
        outputs = self.net(X, anc_idxs)
        theta = [outputs.get('theta', None)]
        ll = [outputs.get('ll', None)]
        labels = torch.zeros_like(outputs['logits']).long()
        mask = outputs['logits'] > 0.0
        done = mask.sum(-1) == N
        for i in range(1, max_iter):
            anc_idxs = sample_anchors(B, N, mask=mask)
            outputs = self.net(X, anc_idxs, mask=mask)
            theta.append(outputs.get('theta', None))
            ll.append(outputs.get('ll', None))
            ind = outputs['logits'] > 0.0
            labels[ind * mask.bitwise_not()] = i
            mask[ind] = True
            num_processed = mask.sum(-1)
            done = num_processed == N
            if verbose:
                print(to_numpy(num_processed))
            if done.sum() == B:
                break

        if ll[0] is not None:
            ll = torch.stack(ll, -1)
            theta = torch.stack(theta, -2)
            # ML estimate of mixing proportion pi
            pi = F.one_hot(labels, ll.shape[-1]).float()
            pi = pi.sum(1, keepdim=True) / pi.shape[1]
            ll = ll + (pi + 1e-10).log()
            ll = ll.logsumexp(-1).mean()
            return {'theta': theta, 'll': ll, 'labels': labels}
        else:
            return {'labels': labels}

import argparse
import os
import time

import numpy as np
import torch
from tqdm import tqdm
from sklearn.cluster import SpectralClustering
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI

from utils.plots import scatter, scatter_mog
import matplotlib.pyplot as plt
# Accumulator, to_numpy and benchmarks_path are project utilities assumed to
# be imported from elsewhere in the repository.

parser = argparse.ArgumentParser()
parser.add_argument('--benchmarkfile', type=str, default='mog_10_1000_4.tar')
parser.add_argument('--filename', type=str, default='test.log')
args, _ = parser.parse_known_args()
print(str(args))

benchmark = torch.load(os.path.join(benchmarks_path, args.benchmarkfile))
accm = Accumulator('ari', 'nmi', 'et')
for batch in tqdm(benchmark):
    B = batch['X'].shape[0]
    for b in range(B):
        X = to_numpy(batch['X'][b])
        true_labels = to_numpy(batch['labels'][b].argmax(-1))
        true_K = len(np.unique(true_labels))
        tick = time.time()
        spec = SpectralClustering(n_clusters=true_K,
                                  affinity='nearest_neighbors',
                                  n_neighbors=10).fit(X)
        labels = spec.labels_
        accm.update([
            ARI(true_labels, labels),
            NMI(true_labels, labels, average_method='arithmetic'),
            time.time() - tick
        ])

import time

parser = argparse.ArgumentParser()
parser.add_argument('--benchmarkfile', type=str, default='mog_10_1000_4.tar')
parser.add_argument('--k_max', type=int, default=6)
parser.add_argument('--filename', type=str, default='test.log')
args, _ = parser.parse_known_args()
print(str(args))

benchmark = torch.load(os.path.join(benchmarks_path, args.benchmarkfile))
vbmog = VBMOG(args.k_max)
accm = Accumulator('model ll', 'oracle ll', 'ARI', 'NMI', 'k-MAE', 'et')
for dataset in tqdm(benchmark):
    true_labels = to_numpy(dataset['labels'].argmax(-1))
    X = to_numpy(dataset['X'])
    ll = 0
    ari = 0
    nmi = 0
    mae = 0
    et = 0
    for b in range(len(X)):
        tick = time.time()
        vbmog.run(X[b], verbose=False)
        et += time.time() - tick
        ll += vbmog.loglikel(X[b])
        labels = vbmog.labels()
        ari += ARI(true_labels[b], labels)
        nmi += NMI(true_labels[b], labels, average_method='arithmetic')
        mae += abs(len(np.unique(true_labels[b])) - len(np.unique(labels)))

def plot_filtering(self, batch):
    X = batch['X'].cuda()
    B, N, C, H, W = X.shape
    net = self.net
    net.eval()
    with torch.no_grad():
        outputs = net(X, return_z=True)
        theta = outputs['theta']
        theta_ = theta.repeat(1, N, 1).view(B * N, -1)
        labels = (outputs['logits'] > 0.0).long()

        # conditional generation
        z, _ = net.prior.sample(B * N, device='cuda', context=theta_)
        h_dec = net.dec(torch.cat([z, theta_], -1))
        gX, _ = net.likel.sample(context=h_dec)
        gX = gX.view(B, N, C, H, W)

        z = outputs['z']
        h_dec = net.dec(torch.cat([z, theta_], -1))
        rX, _ = net.likel.sample(context=h_dec)
        rX = rX.view(B, N, C, H, W)

    fig, axes = plt.subplots(1, 2, figsize=(40, 40))
    X = self.combine_digits(X)[0]
    labels = labels[0]
    X1 = X[labels == 1]
    X1 = X1[:X1.shape[0] - X1.shape[0] % 8]
    I = to_numpy(make_grid(1 - X1, nrow=8, pad_value=0)).transpose(1, 2, 0)
    axes[0].imshow(I)
    axes[0].set_title('Filtered out images', fontsize=60, pad=20)
    axes[0].axis('off')
    X0 = X[labels == 0]
    X0 = X0[:X0.shape[0] - X0.shape[0] % 8]
    I = to_numpy(make_grid(1 - X0, nrow=8, pad_value=0)).transpose(1, 2, 0)
    axes[1].imshow(I)
    axes[1].set_title('Remaining images', fontsize=60, pad=20)
    axes[1].axis('off')
    plt.tight_layout()
    #plt.savefig('figures/emnist_filtering.png', bbox_inches='tight')

    gX = self.combine_digits(gX)[0][:32]
    plt.figure()
    I = to_numpy(make_grid(1 - gX, nrow=8, pad_value=0)).transpose(1, 2, 0)
    plt.imshow(I)
    plt.title('Generated images', fontsize=15, pad=5)
    plt.axis('off')
    #plt.savefig('figures/emnist_gen.png', bbox_inches='tight')

    fig, axes = plt.subplots(1, 2, figsize=(40, 40))
    rX = self.combine_digits(rX)[0]
    X1 = rX[labels == 1]
    X1 = X1[:X1.shape[0] - X1.shape[0] % 8]
    I = to_numpy(make_grid(1 - X1, nrow=8, pad_value=0)).transpose(1, 2, 0)
    axes[0].imshow(I)
    axes[0].set_title('Reconstructions of filtered out images',
                      fontsize=60, pad=20)
    axes[0].axis('off')
    X0 = rX[labels == 0]
    X0 = X0[:X0.shape[0] - X0.shape[0] % 8]
    I = to_numpy(make_grid(1 - X0, nrow=8, pad_value=0)).transpose(1, 2, 0)
    axes[1].imshow(I)
    axes[1].set_title('Reconstructions of remaining images',
                      fontsize=60, pad=20)
    axes[1].axis('off')
    plt.tight_layout()

from data.kkanji import KKanji
from data.clustered_dataset import get_saved_cluster_loader
import torchvision.transforms as tvt

transform = tvt.Normalize(mean=[0.2170], std=[0.3787])
dataset = KKanji(os.path.join(datasets_path, 'kkanji'), train=False,
                 transform=transform)
filename = os.path.join(benchmarks_path, 'kkanji_10_300_12.tar')
loader = get_saved_cluster_loader(dataset, filename, classes=range(700, 813))

accm = Accumulator('ari', 'k-mae')
for batch in tqdm(loader):
    B = batch['X'].shape[0]
    for b in range(B):
        X = to_numpy(batch['X'][b]).reshape(-1, 784)
        true_labels = to_numpy(batch['labels'][b].argmax(-1))
        true_K = len(np.unique(true_labels))

        # KMeans
        kmeans = KMeans(n_clusters=true_K).fit(X)
        labels = kmeans.labels_

        # Spectral
        #spec = SpectralClustering(n_clusters=true_K).fit(X)
        #labels = spec.labels_

        #gmm = GaussianMixture(n_components=true_K).fit(X)
        #labels = gmm.predict(X)

        # update the 'ari' and 'k-mae' fields declared in the accumulator above
        accm.update([
            ARI(true_labels, labels),
            abs(true_K - len(np.unique(labels)))
        ])

net.load_state_dict(
    torch.load(os.path.join(save_dir, 'originalDAC_fullytrained.tar')))
net.eval()
test_loader = model.get_test_loader(filename=model.clusterfile)
accm = Accumulator('model ll', 'oracle ll', 'ARI', 'NMI', 'k-MAE')
num_failure = 0
logger = get_logger('{}_{}'.format(module_name, args.run_name),
                    os.path.join(save_dir, args.filename))
all_correct_counts = []
all_distances = []
for batch in tqdm(test_loader):
    params, labels, ll, fail = model.cluster(batch['X'].cuda(),
                                             max_iter=args.max_iter,
                                             verbose=False, check=True)
    true_labels = to_numpy(batch['labels'].argmax(-1))
    ari = 0
    nmi = 0
    mae = 0
    for b in range(len(labels)):
        labels_b = to_numpy(labels[b])
        ari += ARI(true_labels[b], labels_b)
        nmi += NMI(true_labels[b], labels_b, average_method='arithmetic')
        mae += abs(len(np.unique(true_labels[b])) - len(np.unique(labels_b)))
    ari /= len(labels)
    nmi /= len(labels)
    mae /= len(labels)
    oracle_ll = 0.0 if batch.get('ll') is None else batch['ll']
    accm.update([ll.item(), oracle_ll, ari, nmi, mae])
    num_failure += int(fail)
