def visualize(inputs, outputs=None, size=256):
    """Blend per-sample outputs over the last input frame and return a list
    of images (one per batch element), each resized to ``size`` x ``size``.

    When ``outputs`` is None, an all-zero overlay is used, so the result is
    just the (rescaled) inputs.
    """
    # Unwrap the batch: a 2-element list presumably carries [frames, extras]
    # and we take the last frame of the first entry -- TODO confirm upstream.
    if isinstance(inputs, list) and len(inputs) == 2:
        inputs = inputs[0][-1]
    else:
        inputs = inputs[-1]

    if outputs is None:
        outputs = np.zeros(inputs.size())

    images = []
    for inp, out in zip(inputs, outputs):
        inp, out = to_np(inp), to_np(out)
        # Match the input's spatial size to the output's before blending.
        inp = resize_image(inp, out.shape[-1], channel_first=True)
        blended = inp + out / 128.
        # Clamp into the displayable [0, 1] range.
        blended = np.clip(blended, 0, 1)
        images.append(resize_image(blended, size, channel_first=True))
    return images
def analyze_reprs(max_dims=16, threshold=.5, bound=8., step=.2):
    """Visualize the latent representation.

    Plots per-dimension statistics of the (module-level) `means` and
    `log_vars` arrays, then sweeps each dominant latent dimension over
    [-bound, bound] in increments of `step`, decoding at every value and
    saving the resulting frames as GIFs plus an `index.html` gallery under
    exp/<exp>/reprs/.

    max_dims:  keep at most this many dominant dimensions.
    threshold: a dimension is "dominant" if its max |mean| exceeds this.
    """
    reprs_path = os.path.join('exp', args.exp, 'reprs')
    mkdir(reprs_path, clean=True)
    images_path = os.path.join(reprs_path, 'images')
    mkdir(images_path, clean=True)

    # statistics: for every latent dimension, bar-plot the min and max of
    # the collected means and log-variances (two bars per dimension).
    x, ym, yv = [], [], []
    for k in range(means.shape[1]):
        x.extend([k, k])
        ym.extend([np.min(means[:, k]), np.max(means[:, k])])
        yv.extend([np.min(log_vars[:, k]), np.max(log_vars[:, k])])

    plt.figure()
    plt.bar(x, ym, .5, color='b')
    plt.xlabel('dimension')
    plt.ylabel('mean')
    plt.savefig(os.path.join(images_path, 'means.png'), bbox_inches='tight')

    plt.figure()
    plt.bar(x, yv, .5, color='b')
    plt.xlabel('dimension')
    plt.ylabel('log(var)')
    plt.savefig(os.path.join(images_path, 'vars.png'), bbox_inches='tight')

    # dimensions: rank latent dimensions by the largest |mean| observed and
    # keep the ones above `threshold` (at most `max_dims` of them).
    values = np.arange(-bound, bound + step, step)
    magnitudes = np.max(np.abs(means), axis=0)
    indices = np.argsort(-magnitudes)
    dimensions = [k for k in indices if magnitudes[k] > threshold][:max_dims]
    print('==> dominated dimensions = {0}'.format(dimensions))

    for split in ['train', 'test']:
        # NOTE(review): `.next()` is the Python-2 iterator protocol; under
        # Python 3 this would be `next(iter(loaders[split]))`.
        inputs, targets = iter(loaders[split]).next()
        inputs, targets = to_var(inputs, volatile=True), to_var(targets, volatile=True)
        outputs, z = model.forward(inputs, returns='z')
        for dim in tqdm(dimensions):
            # NOTE(review): `repr` shadows the builtin of the same name.
            repr = to_np(z).copy()
            samples = []
            for val in tqdm(values, leave=False):
                # Override one latent dimension for the whole batch and decode.
                repr[:, dim] = val
                sample = model.forward(inputs, z=to_var(repr, volatile=True))
                samples.append(visualize(inputs, sample))
            for k in range(args.batch):
                # One GIF per batch element: frames sweep `dim` over `values`.
                images = [sample[k] for sample in samples]
                image_path = os.path.join(
                    images_path, '{0}-{1}-{2}.gif'.format(split, k, dim))
                save_images(images, image_path, duration=.1, channel_first=True)
        # Also save the raw inputs as reference PNGs for this split.
        inputs = visualize(inputs)
        for k in range(args.batch):
            image_path = os.path.join(images_path, '{0}-{1}.png'.format(split, k))
            save_image(inputs[k], image_path, channel_first=True)

    # visualization: emit a static HTML gallery referencing the saved images.
    with open(os.path.join(reprs_path, 'index.html'), 'w') as fp:
        print('<h3>statistics</h3>', file=fp)
        print('<img src="{0}">'.format(os.path.join('images', 'means.png')), file=fp)
        print('<img src="{0}">'.format(os.path.join('images', 'vars.png')), file=fp)
        print('<h3>inputs</h3>', file=fp)
        print('<table border="1" style="table-layout: fixed;">', file=fp)
        for split in ['train', 'test']:
            print('<tr>', file=fp)
            for k in range(args.batch):
                image_path = os.path.join('images', '{0}-{1}.png'.format(split, k))
                print(
                    '<td halign="center" style="word-wrap: break-word;" valign="top">',
                    file=fp)
                print(
                    '<img src="{0}" style="width:128px;">'.format(image_path),
                    file=fp)
                print('</td>', file=fp)
            print('</tr>', file=fp)
        print('</table>', file=fp)
        for dim in dimensions:
            print('<h3>dimension [{0}]</h3>'.format(dim), file=fp)
            print('<table border="1" style="table-layout: fixed;">', file=fp)
            for split in ['train', 'test']:
                print('<tr>', file=fp)
                for k in range(args.batch):
                    image_path = os.path.join(
                        'images', '{0}-{1}-{2}.gif'.format(split, k, dim))
                    print(
                        '<td halign="center" style="word-wrap: break-word;" valign="top">',
                        file=fp)
                    print('<img src="{0}" style="width:128px;">'.format(
                        image_path), file=fp)
                    print('</td>', file=fp)
                print('</tr>', file=fp)
            print('</table>', file=fp)
def analyze_fmaps(size=256):
    """Visualize intermediate feature maps.

    For one batch from each split, runs the model once, then saves a
    two-frame GIF per (scale, channel, batch-element) pairing the input
    image with its min-max-normalized feature map, plus an `index.html`
    gallery, all under exp/<exp>/fmaps/.

    size: side length (pixels) the images and feature maps are resized to.
    """
    fmaps_path = os.path.join('exp', args.exp, 'fmaps')
    mkdir(fmaps_path, clean=True)
    images_path = os.path.join(fmaps_path, 'images')
    mkdir(images_path, clean=True)

    # feature maps
    for split in ['train', 'test']:
        # NOTE(review): `.next()` is the Python-2 iterator protocol; under
        # Python 3 this would be `next(iter(loaders[split]))`.
        inputs, targets = iter(loaders[split]).next()
        inputs, targets = to_var(inputs, volatile=True), to_var(targets, volatile=True)
        outputs, features = model.forward(inputs, returns='features')
        # `features` is a list of per-scale tensors; assumes every scale has
        # the same channel count as features[0] -- TODO confirm.
        num_scales, num_channels = len(features), features[0].size(1)
        for s in trange(num_scales):
            # NOTE(review): `input` shadows the builtin; the last frame of
            # inputs[0] is paired against scale `s`'s features.
            input, feature = inputs[0][-1], features[s]
            for b in trange(args.batch, leave=False):
                image = resize_image(to_np(input[b]), size=size, channel_first=True)
                for c in trange(num_channels, leave=False):
                    fmap = resize_image(to_np(feature[b, c]), size=size, channel_first=True)
                    # Min-max normalize to [0, 1]; skipped when the map is
                    # constant to avoid dividing by zero.
                    if np.min(fmap) < np.max(fmap):
                        fmap = (fmap - np.min(fmap)) / (np.max(fmap) - np.min(fmap))
                    image_path = os.path.join(
                        images_path, '{0}-{1}-{2}-{3}.gif'.format(split, s, c, b))
                    save_images([image, fmap], image_path, channel_first=True)

    # visualization: emit a static HTML gallery referencing the saved GIFs.
    # NOTE(review): num_scales/num_channels leak out of the loop above, so
    # this reflects the last split processed.
    with open(os.path.join(fmaps_path, 'index.html'), 'w') as fp:
        for s in range(num_scales):
            for c in range(num_channels):
                print('<h3>scale [{0}] - channel [{1}]</h3>'.format(
                    s + 1, c + 1), file=fp)
                print('<table border="1" style="table-layout: fixed;">', file=fp)
                for split in ['train', 'test']:
                    print('<tr>', file=fp)
                    for b in range(args.batch):
                        image_path = os.path.join(
                            'images', '{0}-{1}-{2}-{3}.gif'.format(split, s, c, b))
                        print(
                            '<td halign="center" style="word-wrap: break-word;" valign="top">',
                            file=fp)
                        print('<img src="{0}" style="width:128px;">'.format(
                            image_path), file=fp)
                        print('</td>', file=fp)
                    print('</tr>', file=fp)
                print('</table>', file=fp)
# means & log_vars num_dists = 1024 means, log_vars = [], [] for inputs, targets in loaders['train']: inputs, targets = to_var(inputs, volatile=True), to_var(targets, volatile=True) # forward outputs, (mean, log_var) = model.forward(inputs, returns=['mean', 'log_var']) means.extend(to_np(mean)) log_vars.extend(to_np(log_var)) if len(means) >= num_dists and len(log_vars) >= num_dists: break means = np.array(means[:num_dists]) log_vars = np.array(log_vars[:num_dists]) # visualization num_samples = 4 for split in ['train', 'test']: inputs, targets = iter(loaders[split]).next() inputs, targets = to_var(inputs, volatile=True), to_var(targets,
def forward(self, features, labels=None):
    """Decode a sequence of (x, c) tokens from `features` with an LSTM.

    With `labels` given, runs teacher-forced decoding over the whole
    sequence in one LSTM call; otherwise decodes step by step, feeding
    each prediction (with `c` hardened to a one-hot via argmax) back in.

    features: per-sample conditioning vector of size (batch, feat_dim) --
        presumed; only size(0) is read directly.
    labels:   ground-truth token sequence for teacher forcing, or None.
    Returns (outputs, (hiddens, cells)); each output step is the 2-dim `x`
    head concatenated with the 3-dim `c` head (5 values per step).
    """
    batch_size = features.size(0)
    # One start token per batch element.
    start_tokens = torch.stack([to_var(hp.start_token)] * batch_size)

    # hidden & cell: a single linear layer produces both initial states,
    # split in half along dim 1, then replicated across all LSTM layers.
    hidden, cell = torch.split(F.tanh(self.linear(features)), self.hidden_size, 1)
    hidden = torch.stack([hidden.contiguous()] * self.num_layers)
    cell = torch.stack([cell.contiguous()] * self.num_layers)

    if labels is not None:
        # Teacher forcing: prepend the start token, then pair every time
        # step with a copy of `features` along the feature axis.
        inputs = torch.cat([start_tokens.unsqueeze(1), labels], 1)
        inputs = torch.cat(
            [torch.stack([features] * (hp.max_length + 1), 1), inputs], 2)
        outputs, (hiddens, cells) = self.lstm.forward(inputs, (hidden, cell))
        # Flatten (batch, time) so the estimators run on every step at once.
        outputs = outputs.contiguous().view(-1, self.hidden_size)
        x = self.x_estimator.forward(outputs).view(-1, hp.max_length + 1, 2)
        c = self.c_estimator.forward(outputs).view(-1, hp.max_length + 1, 3)
        # output
        outputs = torch.cat([x, c], 2)
    else:
        # Free-running decoding: one LSTM step at a time, feeding the
        # previous prediction back as the next input.
        outputs, hiddens, cells = [], [], []
        output = start_tokens
        for k in range(hp.max_length + 1):
            input = torch.cat([features.unsqueeze(1), output.unsqueeze(1)], 2)
            output, (hidden, cell) = self.lstm.forward(input, (hidden, cell))
            output = output.contiguous().squeeze(1)
            x = self.x_estimator.forward(output).view(-1, 2)
            c = self.c_estimator.forward(output).view(-1, 3)
            # sample: harden `c` into a one-hot of its argmax class.
            # NOTE(review): a NumPy round-trip, so gradients do not flow
            # through `c` on this path -- presumably intentional; confirm.
            c = to_np(c)
            indices = np.argmax(c, 1)
            c = np.zeros_like(c)
            for i, index in enumerate(indices):
                c[i, index] = 1
            c = to_var(c)
            # output
            output = torch.cat([x, c], 1)
            # save
            # NOTE(review): squeeze(0) only drops the layer axis when
            # num_layers == 1 -- verify for multi-layer configs.
            outputs.append(output)
            hiddens.append(hidden.squeeze(0))
            cells.append(cell.squeeze(0))
        # stack per-step results into (batch, time, ...) tensors.
        outputs = torch.stack(outputs, 1)
        hiddens = torch.stack(hiddens, 1)
        cells = torch.stack(cells, 1)
    return outputs, (hiddens, cells)