def train(self):
    self.model.train()

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    meter_loss = AverageMeter('Loss', ':.4e')
    meter_loss_constr = AverageMeter('Constr', ':6.2f')
    meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
    progress = ProgressMeter(
        self.training_loader.epoch_size()['__Video_0'],
        [batch_time, data_time, meter_loss, meter_loss_constr, meter_loss_perp],
        prefix="Steps: [{}]".format(self.num_steps))

    data_iter = DALIGenericIterator(self.training_loader, ['data'], auto_reset=True)

    end = time.time()
    for i in range(self.start_steps, self.num_steps):
        # measure output loading time
        data_time.update(time.time() - end)

        try:
            images = next(data_iter)[0]['data']
        except StopIteration:
            data_iter.reset()
            images = next(data_iter)[0]['data']

        images = images.to('cuda')
        b, d, _, _, c = images.size()
        images = rearrange(images, 'b d h w c -> (b d) c h w')
        images = self.normalize(images.float() / 255.)
        # images = rearrange(images, '(b d) c h w -> b (d c) h w', b=b, d=d, c=c)

        self.optimizer.zero_grad()
        vq_loss, images_recon, perplexity = self.model(images)
        recon_error = F.mse_loss(images_recon, images)
        loss = recon_error + vq_loss
        loss.backward()
        self.optimizer.step()

        meter_loss_constr.update(recon_error.item(), 1)
        meter_loss_perp.update(perplexity.item(), 1)
        meter_loss.update(loss.item(), 1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 20 == 0:
            progress.display(i)

        if i % 1000 == 0:
            print('saving ...')
            save_checkpoint(
                self.folder_name, {
                    'steps': i,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, 'checkpoint%s.pth.tar' % i)

            self.scheduler.step()

            images, images_recon = map(
                lambda t: rearrange(t, '(b d) c h w -> b d c h w', b=b, d=d, c=c),
                [images, images_recon])
            images_orig, images_recs = train_visualize(
                unnormalize=self.unnormalize,
                images=images[0, :self.n_images_save],
                n_images=self.n_images_save,
                image_recs=images_recon[0, :self.n_images_save])

            save_images(file_name=os.path.join(self.path_img_orig, f'image_{i}.png'), image=images_orig)
            save_images(file_name=os.path.join(self.path_img_recs, f'image_{i}.png'), image=images_recs)

        if self.run_wandb:
            logs = {
                'iter': i,
                'loss_recs': meter_loss_constr.val,
                'loss': meter_loss.val,
                'lr': self.scheduler.get_last_lr()[0]
            }
            self.run_wandb.log(logs)

    print('saving ...')
    save_checkpoint(
        self.folder_name, {
            'steps': self.num_steps,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, 'checkpoint%s.pth.tar' % self.num_steps)
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop

for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seq, coords, mask = batch.seqs, batch.crds, batch.msks
        b, l, _ = seq.shape

        # prepare mask, labels

        seq, coords, mask = seq.argmax(dim = -1).to(DEVICE), coords.to(DEVICE), mask.to(DEVICE).bool()
        coords = rearrange(coords, 'b (l c) d -> b l c d', l = l)
        discretized_distances = get_bucketed_distance_matrix(coords[:, :, 0], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX)

        # predict

        distogram = model(seq, mask = mask)
        distogram = rearrange(distogram, 'b i j c -> b c i j')

        # loss

        loss = F.cross_entropy(
            distogram,
            discretized_distances,
            ignore_index = IGNORE_INDEX
        )
def __init__(self,
             path: str,
             fields: Tuple[RawField, RawField, Field, Field, Field, RawField],
             max_samples: int = None,
             use_mask: bool = True,
             use_ground_truth=False,
             **kwargs):
    """Create a SignTranslationDataset given paths and fields.

    Arguments:
        path: Common prefix of paths to the data files for both languages.
        exts: A tuple containing the extension to path for each language.
        fields: A tuple containing the fields that will be used for data
            in each language.
        Remaining keyword arguments: Passed to the constructor of data.Dataset.
    """
    print(f"Using ground truth: {use_ground_truth}", flush=True)

    if not isinstance(fields[0], (tuple, list)):
        fields = [("sequence", fields[0]), ("signer", fields[1]),
                  ("sgn", fields[2]), ("gls", fields[3]),
                  ("txt", fields[4]), ("use_mask", fields[5])]

    _t = time.time()
    samples = {}
    sample_count = 0
    longest_sequence = 0
    acc = 0
    frames = 0
    for annotation_file, annotation_type in path:
        tmp = load_dataset_file(annotation_file)
        for s in tmp:
            unlistify(s)
            if max_samples is not None and sample_count >= max_samples:
                # For debugging to make a smaller dataset
                break
            seq_id = s["name"]

            # Decompress and augment
            #if use_ground_truth:  # Make one hot encoding TODO replace
            if use_ground_truth or annotation_type == "logits_max":
                feature_length = fields[2][-1].pad_token.shape[0]
                if isinstance(s['alignments']['pami0'], list):
                    s['alignments']['pami0'] = s['alignments']['pami0'][0]
                labels = [
                    int(item) if feature_length > 1500 else ((int(item) + 2) // 3)
                    for item in s['alignments']['pami0'].split(" ")
                ]
                chunk_length = 16
                collapsed_labels = []
                for start in range((len(labels) - chunk_length) + 1):
                    label = int(labels[start + (chunk_length // 2)])
                    collapsed_labels.append(label)
                #collapsed_labels = [labels[0]]
                #for label, prev_label in zip(labels[1:], labels[:-1]):
                #    if label != prev_label:
                #        collapsed_labels.append(labels)
                labels = collapsed_labels
                _feat = torch.zeros((len(labels), feature_length), dtype=torch.float32)
                for index, label in enumerate(labels):
                    _feat[index, label] = 1

            if use_ground_truth:
                _features = _feat
            else:
                try:
                    features = pickle.loads(lzma.decompress(s['sign']))
                except TypeError:
                    features = s['sign'].numpy()
                _features = load_augment(annotation_type, features)
                if len(_features.shape) == 3:
                    _features = rearrange(_features, "bs t f -> (bs t) f")
                if annotation_type == "logits_max":
                    frames += len(labels)
                    assert _features.shape == _feat.shape, f"(logit) {_features.shape} =/= {_feat.shape} (gt)"
                    acc += (_features * _feat).sum()

            s["sign"] = _features
            longest_sequence = max(_features.shape[0], longest_sequence)

            if seq_id in samples:
                if samples[seq_id]["sign"].shape[0] > s["sign"].shape[0]:
                    # If there are fewer symbols, pad with zero-filled features
                    feature_size = s["sign"].shape[1]
                    difference = samples[seq_id]["sign"].shape[0] - s["sign"].shape[0]
                    s["sign"] = torch.cat([
                        torch.zeros(math.ceil(difference / 2), feature_size),
                        s["sign"],
                        torch.zeros(math.floor(difference / 2), feature_size)
                    ], axis=0)
                samples[seq_id]["loaded"] += 1
                assert samples[seq_id]["name"] == s["name"]
                assert samples[seq_id]["signer"] == s["signer"]
                assert samples[seq_id]["gloss"] == s["gloss"]
                assert samples[seq_id]["text"] == s["text"]
                samples[seq_id]["sign"] = torch.cat(
                    [samples[seq_id]["sign"], s["sign"]], axis=1)
            else:
                samples[seq_id] = {
                    "name": s["name"],
                    "signer": s["signer"],
                    "gloss": s["gloss"],
                    "text": s["text"],
                    "sign": s["sign"],
                    "loaded": 1,
                }
            sample_count += 1

    if annotation_type == "logits_max":
        print(f"GT acc: {acc / frames}")
    print(f"Longest sequence: {longest_sequence}")
    print(f"Done loading samples in {time.time()-_t}s")

    examples = []
    for s in samples:
        sample = samples[s]
        if sample['loaded'] == len(path):
            examples.append(
                data.Example.fromlist(
                    [
                        sample["name"],
                        sample["signer"],
                        # This is for numerical stability
                        sample["sign"] + 1e-8,
                        str(sample["gloss"]).strip(),
                        str(sample["text"]).strip(),
                        use_mask,
                    ],
                    fields,
                ))
        else:
            print(f"{s} only loaded {sample['loaded']} annotations so has been removed")
    super().__init__(examples, fields, **kwargs)
def forward(self,
            sequence: torch.Tensor,
            sequence_lengths: Optional[torch.Tensor] = None,
            may_deactivate_seq: bool = True) -> torch.Tensor:
    """
    Args:
        sequence (B, N, K, S): Chunked input sequence
        sequence_lengths (B): Sequence lengths along segment dimension (S)
        may_deactivate_seq: If set to `True`, the handling of sequence
            lengths is disabled when all examples in the batch have the
            same length
    """
    # The handling of sequence lengths can be disabled if all examples in a
    # batch have the same length and this length matches the size of the
    # time axis of the input sequence (i.e., the signal is not 0-padded).
    # This speeds up the computations
    if may_deactivate_seq and sequence_lengths is not None and (
            len(sequence_lengths) == 1
            or all(sequence_lengths[1:] == sequence_lengths[:-1])
    ) and sequence_lengths[0] == sequence.shape[-1]:
        sequence_lengths = None

    B, N, K, S = sequence.shape

    # LSTM only supports 3-dim input. Reshape according to given shape
    lstm_in = rearrange(sequence, f'b n k s -> {self.lstm_reshape_to}')

    # Call lstm
    if sequence_lengths is not None:
        # TODO: don't hardcode this
        if 's' in self.lstm_reshape_to[:4]:
            packed = pack(rearrange(sequence, 'b n k s -> b s k n'), sequence_lengths)
        else:
            assert self.lstm_reshape_to[1] == 'b'
            packed_sequence_lengths = rearrange(
                sequence_lengths.reshape(B, 1, 1, 1).expand(B, 1, K, 1),
                f'b n k s -> {self.lstm_reshape_to}').squeeze()
            packed = pack_padded_sequence(lstm_in, packed_sequence_lengths, batch_first=True)
    else:
        packed_sequence_lengths = None
        packed = lstm_in

    out = self.rnn(packed)
    if isinstance(out, tuple):
        out = out[0]

    if sequence_lengths is not None and 's' not in self.lstm_reshape_to[:4]:
        out, _ = pad_packed_sequence(out, batch_first=True, total_length=S)

    # FC projection layer
    out = self.fc(out)

    # Apply norm and rearrange back to BxNxKxS
    if sequence_lengths is not None and 's' in self.lstm_reshape_to[:4]:
        out = self.norm(out)
        out = rearrange(unpack(out, sequence_lengths), 'b s k n -> b n k s')
    else:
        out = apply_examplewise(self.norm, out, packed_sequence_lengths)
        out = rearrange(out, f'{self.lstm_reshape_to} -> b n k s', b=B, s=S, n=self.feat_size, k=K)

    # Residual connection
    out = out + sequence
    return out
def forward(self, x, mask=None):
    b, n, _, h, img_size, axis, seq_len = *x.shape, self.heads, self.image_size, self.axis, self.seq_len
    softmax = torch.softmax

    img_seq_len = img_size**2
    text_len = seq_len + 1 - img_seq_len

    # padding

    padding = seq_len - n + 1
    mask = default(mask, lambda: torch.ones(b, text_len).bool())

    x = F.pad(x, (0, 0, 0, padding), value=0)
    mask = mask[:, :text_len]

    # derive queries / keys / values

    qkv = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), qkv)

    q = q * self.scale

    ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
        lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

    # text attention

    dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
    mask_value = max_neg_value(self.fp16)

    i, j = dots_text.shape[-2:]
    text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
    dots_text.masked_fill_(text_causal_mask, mask_value)

    attn_text = softmax(dots_text, dim=-1)
    out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

    # image attention

    split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
    merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

    # split out axis

    q_img, k_img, v_img = map(
        lambda t: rearrange(t, split_axis_einops, h=img_size), (q_img, k_img, v_img))

    # similarity

    dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
    dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

    dots = torch.cat((dots_image_to_text, dots_image_to_image), dim=-1)

    # mask so image has full attention to text, but causal along axis

    bh, x, i, j = dots.shape
    causal_mask = torch.ones(i, img_size).triu_(img_size - i + 1).bool()
    causal_mask = repeat(causal_mask, 'i j -> b x i j', b=bh, x=x)

    mask = repeat(mask, 'b j -> (b h) x i j', h=h, x=x, i=i)
    mask = torch.cat((~mask, causal_mask), dim=-1)

    dots.masked_fill_(mask, mask_value)

    # attention.

    attn = softmax(dots, dim=-1)

    # aggregate

    attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

    out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
    out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

    out_image = out_image_to_image + out_image_to_text

    # merge back axis

    out_image = rearrange(out_image, merge_axis_einops, x=img_size)

    # combine attended values for both text and image

    out = torch.cat((out_text, out_image), dim=1)

    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    out = self.to_out(out)
    return out[:, :n]
def split_bl_embeddings(self, embeddings):
    R, L = self.img_size**2, self.nh_size**2
    embeddings = rearrange(embeddings, '(b l) f -> b l f', l=L)
    return embeddings
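# Hedged usage sketch for split_bl_embeddings (hypothetical sizes: nh_size=3, batch=2, feature dim 64).
# It only illustrates the rearrange above; names and shapes are assumptions, not taken from the original code.
import torch
from einops import rearrange

flat = torch.randn(2 * 3 ** 2, 64)                  # (b * L, f) with L = nh_size ** 2 = 9
grouped = rearrange(flat, '(b l) f -> b l f', l=9)  # -> (2, 9, 64)
assert grouped.shape == (2, 9, 64)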
def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
    """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicator vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient so scorer's weights
        won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keyword arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A representation vector for each image in the batch.
    """
    selection_method = SelectionMethod(selection_method)
    aggregation_method = AggregationMethod(aggregation_method)
    if selection_method_inference:
        selection_method_inference = SelectionMethod(selection_method_inference)

    selection_method_kwargs = selection_method_kwargs or {}
    aggregation_method_kwargs = aggregation_method_kwargs or {}

    stats = {}

    # Compute new dimension of the scoring image.
    b, h, w, c = x.shape
    scoring_shape = (b, h // downscale, w // downscale, c)

    # === Compute the scores with a small CNN.
    if selection_method == SelectionMethod.RANDOM:
        scores_h, scores_w = Scorer.compute_output_size(h // downscale, w // downscale)
        num_patches = scores_h * scores_w
    else:
        # Downscale input to run scorer on.
        scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
        scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se, name="scorer")
        flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
        num_patches = flatten_scores.shape[-1]
        scores_h, scores_w = scores.shape[1:3]

        # Compute entropy before normalization
        prob_scores = jax.nn.softmax(flatten_scores)
        stats["entropy_before_normalization"] = jax.scipy.special.entr(
            prob_scores).sum(axis=1).mean(axis=0)

        # Normalize the flatten scores
        normalization_fn = create_normalization_fn(normalization_str)
        flatten_scores = normalization_fn(flatten_scores)
        scores = flatten_scores.reshape(scores.shape)
        stats["scores"] = scores[Ellipsis, None]

    # Concatenate height and width position to the input channels.
    if append_position_to_input:
        coords = utils.create_grid([h, w], value_range=(0., 1.))
        x = jnp.concatenate([x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
        c += 2

    # Overwrite the selection method at inference
    if selection_method_inference and not train:
        selection_method = selection_method_inference

    # === Patch selection

    # Select the patches by sampling or top-k. Some methods return the indices
    # of the selected patches, other methods return indicator vectors.
    extract_by_indices = selection_method in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]
    if selection_method is SelectionMethod.SINKHORN_TOPK:
        indicators = select_patches_sinkhorn_topk(flatten_scores, k=k, **selection_method_kwargs)
    elif selection_method is SelectionMethod.PERTURBED_TOPK:
        sigma = selection_method_kwargs["sigma"]
        num_samples = selection_method_kwargs["num_samples"]
        sigma *= self.state("sigma_mutiplier", shape=(), initializer=nn.initializers.ones).value
        stats["sigma"] = sigma
        indicators = select_patches_perturbed_topk(flatten_scores, k=k, sigma=sigma, num_samples=num_samples)
    elif selection_method is SelectionMethod.HARD_TOPK:
        indices = select_patches_hard_topk(flatten_scores, k=k)
    elif selection_method is SelectionMethod.RANDOM:
        batch_random_indices_fn = jax.vmap(
            functools.partial(jax.random.choice, a=num_patches, shape=(k,), replace=False))
        indices = batch_random_indices_fn(jax.random.split(nn.make_rng(), b))

    # Compute scores entropy for regularization
    if selection_method not in [SelectionMethod.RANDOM]:
        prob_scores = flatten_scores
        # Normalize the scores if it is not already done.
        if "softmax" not in normalization_str:
            prob_scores = jax.nn.softmax(prob_scores)
        stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(axis=1).mean(axis=0)

    # Randomly use hard topk at training.
    if (train and hard_topk_probability > 0 and selection_method not in
            [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
        true_indices = select_patches_hard_topk(flatten_scores, k=k)
        random_values = jax.random.uniform(nn.make_rng(), (b,))
        use_hard = random_values < hard_topk_probability
        if extract_by_indices:
            indices = jnp.where(use_hard[:, None], true_indices, indices)
        else:
            true_indicators = make_indicators(true_indices, num_patches)
            indicators = jnp.where(use_hard[:, None, None], true_indicators, indicators)

    # Sample some random patches during training with random_patch_probability.
    if (train and random_patch_probability > 0 and
            selection_method is not SelectionMethod.RANDOM):
        single_random_patches = functools.partial(
            jax.random.choice, a=num_patches, shape=(k,), replace=False)
        random_indices = jax.vmap(single_random_patches)(jax.random.split(nn.make_rng(), b))
        random_values = jax.random.uniform(nn.make_rng(), (b, k))
        use_random = random_values < random_patch_probability
        if extract_by_indices:
            indices = jnp.where(use_random, random_indices, indices)
        else:
            random_indicators = make_indicators(random_indices, num_patches)
            indicators = jnp.where(use_random[:, None, :], random_indicators, indicators)

    # === Patch extraction
    if extract_by_indices:
        patches = extract_patches_from_indices(x, indices, patch_size=patch_size, grid_shape=(scores_h, scores_w))
        indicators = make_indicators(indices, num_patches)
    else:
        patches = extract_patches_from_indicators(
            x, indicators, patch_size, grid_shape=(scores_h, scores_w),
            iterative=use_iterative_extraction, patch_dropout=patch_dropout, train=train)

    chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

    stats["extracted_patches"] = einops.rearrange(patches, "b k i j c -> b i (k j) c")
    # Remove position channels for plotting.
    if append_position_to_input:
        stats["extracted_patches"] = (stats["extracted_patches"][Ellipsis, :-2])

    # === Compute patch features
    flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
    representations = feature_network(flatten_patches, train=train)
    if representations.ndim > 2:
        collapse_axis = tuple(range(1, representations.ndim - 1))
        representations = representations.mean(axis=collapse_axis)
    representations = einops.rearrange(representations, "(b k) d -> b k d", k=k)

    stats["patch_representations"] = representations

    # === Aggregate the k patches
    # - for sampling we are forced to take an expectation
    # - for topk we have multiple options: mean, max, transformer.
    if aggregation_method is AggregationMethod.TRANSFORMER:
        patch_pos_encoding = nn.Dense(
            einops.rearrange(indicators, "b d k -> b k d"),
            features=representations.shape[-1])
        chex.assert_equal_shape([representations, patch_pos_encoding])
        representations += patch_pos_encoding
        representations = transformer.Transformer(
            representations, **aggregation_method_kwargs, is_training=train)
    elif aggregation_method is AggregationMethod.MEANPOOLING:
        representations = representations.mean(axis=1)
    elif aggregation_method is AggregationMethod.MAXPOOLING:
        representations = representations.max(axis=1)
    elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
        representations = representations.sum(axis=1)
        representations = nn.LayerNorm(representations)

    representations = nn.Dense(representations, features=representations.shape[-1],
                               name="classification_dense1")
    representations = nn.swish(representations)

    return representations, stats
def get_codebook_indices(self, img):
    b = img.shape[0]
    img = (2 * img) - 1
    _, _, [_, _, indices] = self.model.encode(img)
    return rearrange(indices, '(b n) () -> b n', b=b)
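# Hedged usage sketch (assumes a VQGAN-style wrapper like the one above and 256x256 RGB input in [0, 1];
# object and size names are illustrative assumptions, not part of the original code):
# img = torch.rand(4, 3, 256, 256)
# indices = vae.get_codebook_indices(img)   # -> (4, n) codebook token ids, one row per image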
def forward(self, x):
    shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

    rand_flip_fn = lambda t: torch.flip(t, dims=(-1,))

    flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
    flip_image_one_fn = rand_flip_fn if flip_image_one else identity
    flip_image_two_fn = rand_flip_fn if flip_image_two else identity

    cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
    cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)

    image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode=self.cutout_interpolate_mode)
    image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode=self.cutout_interpolate_mode)

    image_one_cutout = flip_image_one_fn(image_one_cutout)
    image_two_cutout = flip_image_two_fn(image_two_cutout)

    image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)

    proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
    proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

    image_h, image_w = shape[2:]

    proj_image_shape = proj_pixel_one.shape[2:]
    proj_image_h, proj_image_w = proj_image_shape

    coordinates = torch.meshgrid(torch.arange(image_h, device=device), torch.arange(image_w, device=device))
    coordinates = torch.stack(coordinates).unsqueeze(0).float()
    coordinates /= math.sqrt(image_h**2 + image_w**2)
    coordinates[:, 0] *= proj_image_h
    coordinates[:, 1] *= proj_image_w

    proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size=proj_image_shape, mode=self.coord_cutout_interpolate_mode)
    proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size=proj_image_shape, mode=self.coord_cutout_interpolate_mode)

    proj_coors_one = flip_image_one_fn(proj_coors_one)
    proj_coors_two = flip_image_two_fn(proj_coors_two)

    proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
    pdist = nn.PairwiseDistance(p=2)

    num_pixels = proj_coors_one.shape[0]

    proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
    proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)

    distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
    distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)

    positive_mask_one_two = distance_matrix < self.distance_thres
    positive_mask_two_one = positive_mask_one_two.t()

    with torch.no_grad():
        target_encoder = self._get_target_encoder()
        target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
        target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)

    # flatten all the pixel projections

    flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
    target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

    # get total number of positive pixel pairs

    positive_pixel_pairs = positive_mask_one_two.sum()

    # get instance level loss

    pred_instance_one = self.online_predictor(proj_instance_one)
    pred_instance_two = self.online_predictor(proj_instance_two)

    loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
    loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())

    instance_loss = (loss_instance_one + loss_instance_two).mean()

    if positive_pixel_pairs == 0:
        return instance_loss, 0

    if not self.use_pixpro:
        # calculate pix contrast loss

        proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))

        similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim=1) / self.similarity_temperature
        similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim=1) / self.similarity_temperature

        loss_pix_one_two = -torch.log(
            similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
            similarity_one_two.exp().sum())
        loss_pix_two_one = -torch.log(
            similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
            similarity_two_one.exp().sum())

        pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
    else:
        # calculate pix pro loss

        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

        propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))

        propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim=1)
        propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim=1)

        loss_pixpro_one_two = -propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
        loss_pixpro_two_one = -propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()

        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

    # total loss

    loss = pix_loss * self.alpha + instance_loss
    return loss, positive_pixel_pairs
def wiener_filter_predict(observation, desired, filter_order, return_w=False):
    """
    Also known as projection of observation to desired (mir_eval.separation._project)

    w = argmin_w ( sum( |x * w - d|^2 ) )

    return x * w

    >>> from paderbox.utils.pretty import pprint
    >>> x = np.array([1, 2, 3, 4, 5])
    >>> y = np.array([1, 2, 1, 2, 1])
    >>> from mir_eval.separation import _project
    >>> _project(x[None], y, 2)
    array([ 0.41754386, 0.78596491, 1.15438596, 1.52280702, 1.89122807, -0.24561404])
    >>> wiener_filter_predict(x, y, 2)
    array([ 0.41754386, 0.78596491, 1.15438596, 1.52280702, 1.89122807, -0.24561404])
    >>> wiener_filter_predict(np.array([x]), y, 2)
    array([ 0.41754386, 0.78596491, 1.15438596, 1.52280702, 1.89122807, -0.24561404])
    >>> wiener_filter_predict(np.array([x, -x]), y, 2)
    array([ 0.41754386, 0.78596491, 1.15438596, 1.52280702, 1.89122807, -0.24561404])
    >>> _project(np.array([x, -x]), y, 2)
    array([ 0.41754386, 0.78596491, 1.15438596, 1.52280702, 1.89122807, -0.24561404])
    >>> pprint(wiener_filter_predict(np.array([x, y]), y, 2))
    array([1., 2., 1., 2., 1., 0.])
    >>> pprint(_project(np.array([x, y]), y, 2))
    array([1., 2., 1., 2., 1., 0.])
    """
    n_fft = int(2**np.ceil(np.log2(observation.shape[-1] + desired.shape[-1] - 1.)))

    if observation.ndim == 1:
        observation = observation[None, :]

    Observation = np.fft.rfft(observation, n=n_fft, axis=-1)
    Desired = np.fft.rfft(desired, n=n_fft, axis=-1)

    Autocorr = np.einsum('KT,kT->KkT', Observation.conj(), Observation)
    Crosscorr = np.einsum('KT,T->KT', Observation.conj(), Desired)

    autocorr = np.fft.irfft(Autocorr)
    crosscorr = np.fft.irfft(Crosscorr)

    R = np.array([[scipy.linalg.toeplitz(a[:filter_order]) for a in aa] for aa in autocorr])
    R = einops.rearrange(
        R, 'source1 source2 filter1 filter2 -> (source1 filter1) (source2 filter2)')
    p = crosscorr[..., :filter_order]
    p = einops.rearrange(p, 'source filter -> (source filter)')

    from paderbox.math.solve import stable_solve
    w = np.squeeze(stable_solve(R, p[..., None]), axis=-1)
    w = einops.rearrange(w, '(source filter) -> source filter', filter=filter_order)

    if return_w:
        return w
    else:
        return np.sum([
            scipy.signal.fftconvolve(o, filter, axes=(-1))
            for o, filter in zip(observation, w)
        ], axis=0)
def get_codebook_indices(self, img):
    img = map_pixels(img)
    z_logits = self.enc.blocks(img)
    z = torch.argmax(z_logits, dim=1)
    return rearrange(z, 'b h w -> b (h w)')
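# Hedged usage sketch (assumes the wrapper above around the OpenAI discrete VAE encoder and a
# 256x256 RGB input; names and the token-grid size are illustrative assumptions):
# img = torch.rand(1, 3, 256, 256)
# tokens = model.get_codebook_indices(img)   # -> (1, h * w) token ids after the 'b h w -> b (h w)' rearrange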
def forward(self, x):
    shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

    rand_flip_fn = lambda t: torch.flip(t, dims=(-1,))

    flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
    flip_image_one_fn = rand_flip_fn if flip_image_one else identity
    flip_image_two_fn = rand_flip_fn if flip_image_two else identity

    cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
    cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)

    image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode=self.cutout_interpolate_mode)
    image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode=self.cutout_interpolate_mode)

    image_one_cutout = flip_image_one_fn(image_one_cutout)
    image_two_cutout = flip_image_two_fn(image_two_cutout)

    image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)
    self.aug1 = image_one_cutout.detach().clone()
    self.aug2 = image_two_cutout.detach().clone()

    proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
    proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

    proj_pixel_one, proj_pixel_two = get_shared_region(
        proj_pixel_one, proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two,
        flip_image_one_fn, flip_image_two_fn, image_one_cutout.shape, self.cutout_interpolate_mode)
    if proj_pixel_one is None or proj_pixel_two is None:
        positive_pixel_pairs = 0
    else:
        positive_pixel_pairs = proj_pixel_one.shape[-1] * proj_pixel_one.shape[-2]

    with torch.no_grad():
        target_encoder = self._get_target_encoder()
        target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
        target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
        target_proj_pixel_one, target_proj_pixel_two = get_shared_region(
            target_proj_pixel_one, target_proj_pixel_two, cutout_coordinates_one, cutout_coordinates_two,
            flip_image_one_fn, flip_image_two_fn, image_one_cutout.shape, self.cutout_interpolate_mode)

    # If max_latent_dim is specified, stochastically extract latents from the shared areas.
    b, c, pp_h, pp_w = proj_pixel_one.shape
    if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
        prob = torch.full((self.max_latent_dim,), 1 / (self.max_latent_dim))
        latents = [proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two]
        extracted = []
        for l in latents:
            l = l.reshape(b, c, pp_h * pp_w)
            l = l[:, :, prob.multinomial(num_samples=self.max_latent_dim, replacement=False)]
            # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
            # Note that the actual structure no longer matters going forwards. Pixels are only compared to themselves
            # and others without regards to the original image structure.
            sqdim = int(math.sqrt(self.max_latent_dim))
            extracted.append(l.reshape(b, c, sqdim, sqdim))
        proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted

    # flatten all the pixel projections

    flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
    target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

    # get instance level loss

    pred_instance_one = self.online_predictor(proj_instance_one)
    pred_instance_two = self.online_predictor(proj_instance_two)

    loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
    loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())

    instance_loss = (loss_instance_one + loss_instance_two).mean()

    if positive_pixel_pairs == 0:
        return instance_loss, 0

    # calculate pix pro loss

    propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
    propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

    propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))

    propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim=1)
    propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim=1)

    loss_pixpro_one_two = -propagated_similarity_one_two.mean()
    loss_pixpro_two_one = -propagated_similarity_two_one.mean()

    pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

    return instance_loss, pix_loss, positive_pixel_pairs
def fwd_classification(self, x, seq_len=None):
    x = rearrange(x, 'b f t -> b t f')
    x = self._rnn_fwd(x, seq_len)
    x = rearrange(x, 'b t f -> b f t')
    y, seq_len_y = self._clf_fwd(x, seq_len)
    return nn.Sigmoid()(y), seq_len_y
    shuffle=True,
    num_workers=4)

# training
min_running_loss = np.inf
for epoch in range(EPOCHS):
    running_loss = 0.0

    for i, batch in tqdm(enumerate(dataloader)):
        # zero the parameters gradient
        optimizer.zero_grad()

        # forward pass
        inputs = batch['image'].to(device)
        with torch.no_grad():
            targets = rearrange(resnet18(inputs), 'b vec h w -> b (vec h w)')  # h=w=1
            #targets = torch.squeeze(resnet18(inputs))
        outputs = teacher(inputs)
        loss = distillation_loss(outputs, targets) + compactness_loss(outputs)

        # backward pass
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # print stats
    print(f"Epoch {epoch+1}, iter {i+1} \t loss: {running_loss}")

    if running_loss < min_running_loss:
        print(f"Loss decreased: {min_running_loss} -> {running_loss}.")
def forward(self, feats, coors, edges=None, mask=None, adj_mat=None):
    b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, self.num_nearest_neighbors, self.valid_radius, self.only_sparse_neighbors

    if exists(mask):
        num_nodes = mask.sum(dim=-1)

    use_nearest = num_nearest > 0 or only_sparse_neighbors

    rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
    rel_dist = (rel_coors**2).sum(dim=-1, keepdim=True)

    i = j = n

    if use_nearest:
        ranking = rel_dist[..., 0]

        if exists(mask):
            rank_mask = mask[:, None, :] * mask[:, None, :]
            ranking.masked_fill_(~rank_mask, 1e5)

        if exists(adj_mat):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

            if only_sparse_neighbors:
                num_nearest = int(adj_mat.float().sum(dim=-1).max().item())
                valid_radius = 0

            self_mask = rearrange(torch.eye(n, device=device, dtype=torch.bool), 'i j -> () i j')

            adj_mat = adj_mat.masked_fill(self_mask, False)
            ranking.masked_fill_(self_mask, -1.)
            ranking.masked_fill_(adj_mat, 0.)

        nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim=-1, largest=False)
        nbhd_mask = nbhd_ranking <= valid_radius

        rel_coors = batched_index_select(rel_coors, nbhd_indices, dim=2)
        rel_dist = batched_index_select(rel_dist, nbhd_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nbhd_indices, dim=2)

        j = num_nearest

    if fourier_features > 0:
        rel_dist = fourier_encode_dist(rel_dist, num_encodings=fourier_features)
        rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

    if use_nearest:
        feats_j = batched_index_select(feats, nbhd_indices, dim=1)
    else:
        feats_j = rearrange(feats, 'b j d -> b () j d')

    feats_i = rearrange(feats, 'b i d -> b i () d')
    feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

    edge_input = torch.cat((feats_i, feats_j, rel_dist), dim=-1)

    if exists(edges):
        edge_input = torch.cat((edge_input, edges), dim=-1)

    m_ij = self.edge_mlp(edge_input)

    if exists(mask):
        mask_i = rearrange(mask, 'b i -> b i ()')

        if use_nearest:
            mask_j = batched_index_select(mask, nbhd_indices, dim=1)
            mask = (mask_i * mask_j) & nbhd_mask
        else:
            mask_j = rearrange(mask, 'b j -> b () j')
            mask = mask_i * mask_j

    if exists(self.coors_mlp):
        coor_weights = self.coors_mlp(m_ij)
        coor_weights = rearrange(coor_weights, 'b i j () -> b i j')

        rel_coors = self.coors_norm(rel_coors)

        if exists(mask):
            coor_weights.masked_fill_(~mask, 0.)

        coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors
    else:
        coors_out = coors

    if exists(self.node_mlp):
        m_ij_mask = rearrange(mask, '... -> ... ()')
        m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

        if self.m_pool_method == 'mean':
            if exists(mask):
                # masked mean
                mask_sum = m_ij_mask.sum(dim=-2)
                m_i = safe_div(m_ij.sum(dim=-2), mask_sum)
            else:
                m_i = m_ij.mean(dim=-2)
        elif self.m_pool_method == 'sum':
            m_i = m_ij.sum(dim=-2)

        normed_feats = self.node_norm(feats)
        node_mlp_input = torch.cat((normed_feats, m_i), dim=-1)
        node_out = self.node_mlp(node_mlp_input) + feats
    else:
        node_out = feats

    return node_out, coors_out
def forward(self, x, B, T, W):
    num_spatial_tokens = (x.size(1) - 1) // T
    H = num_spatial_tokens // W

    if self.attention_type in ['space_only', 'joint_space_time']:
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    elif self.attention_type == 'divided_space_time':
        ## Temporal
        xt = x[:, 1:, :]
        xt = rearrange(xt, 'b (h w t) m -> (b h w) t m', b=B, h=H, w=W, t=T)
        res_temporal = self.drop_path(self.temporal_attn(self.temporal_norm1(xt)))
        res_temporal = rearrange(res_temporal, '(b h w) t m -> b (h w t) m', b=B, h=H, w=W, t=T)
        res_temporal = self.temporal_fc(res_temporal)
        xt = x[:, 1:, :] + res_temporal

        ## Spatial
        init_cls_token = x[:, 0, :].unsqueeze(1)
        cls_token = init_cls_token.repeat(1, T, 1)
        cls_token = rearrange(cls_token, 'b t m -> (b t) m', b=B, t=T).unsqueeze(1)
        xs = xt
        xs = rearrange(xs, 'b (h w t) m -> (b t) (h w) m', b=B, h=H, w=W, t=T)
        xs = torch.cat((cls_token, xs), 1)
        res_spatial = self.drop_path(self.attn(self.norm1(xs)))

        ### Taking care of CLS token
        cls_token = res_spatial[:, 0, :]
        cls_token = rearrange(cls_token, '(b t) m -> b t m', b=B, t=T)
        cls_token = torch.mean(cls_token, 1, True)  ## averaging for every frame
        res_spatial = res_spatial[:, 1:, :]
        res_spatial = rearrange(res_spatial, '(b t) (h w) m -> b (h w t) m', b=B, h=H, w=W, t=T)
        res = res_spatial
        x = xt

        ## Mlp
        x = torch.cat((init_cls_token, x), 1) + torch.cat((cls_token, res), 1)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
def form_input_patches(self, patches):
    patches = rearrange(patches, 'n b l c h w -> (n b l) c h w')
    return patches
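# Hedged shape sketch for form_input_patches (hypothetical sizes): the leading (n, b, l) axes are
# collapsed so a stack of patch crops can be pushed through a 2D backbone in a single call.
# patches = torch.randn(3, 2, 5, 3, 32, 32)                    # (n, b, l, c, h, w)
# rearrange(patches, 'n b l c h w -> (n b l) c h w').shape     # -> torch.Size([30, 3, 32, 32])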
def rs(array):
    h = int(np.sqrt(len(array)))
    return np.around(rearrange(array, '(h w) -> h w', h=h).cpu().numpy(), 3)
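# Hedged usage sketch for rs (assumes a torch tensor whose length is a perfect square):
# scores = torch.arange(9.) / 7
# rs(scores)   # -> 3x3 numpy array with values rounded to 3 decimals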
def extract_patches_from_indicators(x, indicators, patch_size, patch_dropout, grid_shape, train, iterative=False):
    """Extract patches from a batch of images.

    Args:
      x: The batch of images of shape (batch, height, width, channels).
      indicators: The one hot indicators of shape (batch, num_patches, k).
      patch_size: The size of the (squared) patches to extract.
      patch_dropout: Probability to replace a patch by 0 values.
      grid_shape: Pair of height, width of the disposition of the num_patches patches.
      train: If the model is being trained. Disable dropout if not.
      iterative: If True, extracts the patches with a for loop rather than
        instantiating the "all patches" tensor and extracting by dot product
        with the indicators. `iterative` is more memory efficient.

    Returns:
      The patches extracted from x with shape
      (batch, k, patch_size, patch_size, channels).
    """
    batch_size, height, width, channels = x.shape
    scores_h, scores_w = grid_shape
    k = indicators.shape[-1]
    indicators = einops.rearrange(indicators, "b (h w) k -> b k h w", h=scores_h, w=scores_w)

    scale_height = height // scores_h
    scale_width = width // scores_w
    padded_height = scale_height * scores_h + patch_size - 1
    padded_width = scale_width * scores_w + patch_size - 1
    top_pad = (patch_size - scale_height) // 2
    left_pad = (patch_size - scale_width) // 2
    bottom_pad = padded_height - top_pad - height
    right_pad = padded_width - left_pad - width

    # TODO(jbcdnr): assert padding is positive.

    padded_x = jnp.pad(x, [(0, 0), (top_pad, bottom_pad), (left_pad, right_pad), (0, 0)])

    # Extract the patches. Iterative fits better in memory as it does not
    # instantiate the "all patches" tensor but iterates over them to compute
    # the weighted sum with the indicator variables from topk.
    if not iterative:
        assert patch_dropout == 0., "Patch dropout not implemented."
        patches = utils.extract_images_patches(padded_x, window_size=(patch_size, patch_size), stride=(scale_height, scale_width))

        shape = (batch_size, scores_h, scores_w, patch_size, patch_size, channels)
        chex.assert_shape(patches, shape)

        patches = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators, patches)
    else:
        mask = jnp.ones((batch_size, scores_h, scores_w))
        mask = nn.dropout(mask, patch_dropout, deterministic=not train)

        def accumulate_patches(acc, index_i_j):
            i, j = index_i_j
            patch = jax.lax.dynamic_slice(
                padded_x, (0, i * scale_height, j * scale_width, 0),
                (batch_size, patch_size, patch_size, channels))
            weights = indicators[:, :, i, j]
            is_masked = mask[:, i, j]
            weighted_patch = jnp.einsum("b, bk, bijc -> bkijc", is_masked, weights, patch)
            chex.assert_equal_shape([acc, weighted_patch])
            acc += weighted_patch
            return acc, None

        indices = jnp.stack(jnp.meshgrid(jnp.arange(scores_h), jnp.arange(scores_w), indexing="ij"), axis=-1)
        indices = indices.reshape((-1, 2))
        init_patches = jnp.zeros((batch_size, k, patch_size, patch_size, channels))
        patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)

    return patches
def rs(array):
    h = int(np.sqrt(len(array)))
    return rearrange(array, '(h w) -> h w', h=h)
def segment(
        signal: torch.Tensor,
        hop_size: int,
        window_size: int,
        sequence_lengths: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Zero-pads and segments the input sequence `signal` along the time
    dimension `L` (-2).

    Examples:
        >>> import torch
        >>> hop_size = 10
        >>> segmented, _ = segment(torch.randn(1, 50, 3), hop_size, 2 * hop_size)  # Shape is BxNxKxS (batch x feat x win x frames)
        >>> segmented.shape
        torch.Size([1, 3, 20, 6])

        # The first block is zero-padded with hop_size
        >>> segmented[..., :hop_size, 0]
        tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

        # Last block as well
        >>> segmented[..., -hop_size:, -1]
        tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

        # Sequence lengths are computed
        >>> segmented, sequence_lengths = segment(torch.cat([torch.randn(1, 30, 3), torch.zeros(1, 10, 3)], dim=1),
        ...                                       hop_size, 2*hop_size, torch.tensor([30]))
        >>> sequence_lengths
        tensor([4])

        # All data outside of sequence_lengths is zero
        >>> segmented[0, ..., sequence_lengths[0]:].flatten()
        tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0.])

        # And the last segment within sequence_lengths contains data, but zero padded at the end
        # (Conversion to uint8 is to make the doctest compatible with all PyTorch versions)
        >>> (segmented[0, ..., sequence_lengths[0] - 1] == 0).type(torch.uint8)
        tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
               dtype=torch.uint8)

        # Test the corner-cases for computation of sequence lengths
        # One above exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(5))
        (tensor([[[[0, 1, 3, 5],
                   [0, 2, 4, 0],
                   [1, 3, 5, 0],
                   [2, 4, 0, 0]]]]), tensor(4))

        # Exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(4))
        (tensor([[[[0, 1, 3, 5],
                   [0, 2, 4, 0],
                   [1, 3, 5, 0],
                   [2, 4, 0, 0]]]]), tensor(3))
        >>> segment(1 + torch.arange(4)[None, :, None], 2, 4, torch.tensor(4))
        (tensor([[[[0, 1, 3],
                   [0, 2, 4],
                   [1, 3, 0],
                   [2, 4, 0]]]]), tensor(3))

        # One below exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(3))
        (tensor([[[[0, 1, 3, 5],
                   [0, 2, 4, 0],
                   [1, 3, 5, 0],
                   [2, 4, 0, 0]]]]), tensor(3))
        >>> segment(1 + torch.arange(3)[None, :, None], 2, 4, torch.tensor(3))
        (tensor([[[[0, 1, 3],
                   [0, 2, 0],
                   [1, 3, 0],
                   [2, 0, 0]]]]), tensor(3))

        # Shift != size // 2
        >>> segmented, seq_len = segment(torch.arange(5)[None, :, None], 3, 4, torch.tensor(5))
        >>> segmented.shape
        torch.Size([1, 1, 4, 2])
        >>> seq_len
        tensor(2)
        >>> segmented, seq_len = segment(torch.arange(5)[None, :, None], 1, 4, torch.tensor(5))
        >>> segmented.shape
        torch.Size([1, 1, 4, 8])
        >>> seq_len
        tensor(8)
        >>> segmented, seq_len = segment(torch.ones(1, 7912, 64), 50, 100, torch.tensor([7912]))
        >>> segmented.shape
        torch.Size([1, 64, 100, 160])
        >>> seq_len
        tensor([160])

    Args:
        signal ([Bx]LxN): 2D input signal with optional batch dimension
        hop_size: Hop size P
        window_size: Window size K
        sequence_lengths: These are not used for segmentation, but if
            provided, the resulting sequence lengths along the segment (S)
            dimension are returned in addition to the segmented signal. Then,
            the sequence length is the number of blocks that contain any part
            of the signal, and these might be 0-padded.

    Returns:
        [Bx]NxKxS
        S is the number of frames, K is the window size, N is the feature size
    """
    # Add padding for the first and last blocks. Should be each hop_size so
    # that the first half of the first block and the last half of the last
    # block are filled with 0s for the case of 50% overlap.
    padding = window_size - hop_size
    signal = F.pad(signal, [0, 0, padding, padding])
    segmented = pb.array.segment_axis(signal, window_size, hop_size, axis=-2, end='pad')
    segmented = rearrange(segmented, '... s k n -> ... n k s')

    if sequence_lengths is not None:
        sequence_lengths = sequence_lengths + 2 * padding
        sequence_lengths = (sequence_lengths - padding)
        sequence_lengths = (sequence_lengths - 1) // hop_size + 1

    return segmented, sequence_lengths
def __init__(
    self,
    *,
    dim,
    vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth,
    heads = 8,
    dim_head = 64,
    reversible = False,
    attn_dropout = 0.,
    ff_dropout = 0,
    sparse_attn = False,
    attn_types = None,
    loss_img_weight = 7,
    stable = False
):
    super().__init__()
    assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE'

    image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
    image_seq_len = image_fmap_size ** 2

    num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.image_emb = nn.Embedding(num_image_tokens, dim)

    self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim)  # +1 for <bos>
    self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size))

    self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens

    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens
    self.total_tokens = total_tokens
    self.total_seq_len = seq_len

    self.vae = vae
    set_requires_grad(self.vae, False)  # freeze VAE from being trained

    self.transformer = Transformer(
        dim = dim,
        causal = True,
        seq_len = seq_len,
        depth = depth,
        heads = heads,
        dim_head = dim_head,
        reversible = reversible,
        attn_dropout = attn_dropout,
        ff_dropout = ff_dropout,
        attn_types = attn_types,
        image_fmap_size = image_fmap_size,
        sparse_attn = sparse_attn,
        stable = stable
    )

    self.stable = stable

    if stable:
        self.norm_by_max = DivideMax(dim = -1)

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (
        ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
        ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
    )

    self.register_buffer('logits_mask', logits_mask, persistent=False)
    self.loss_img_weight = loss_img_weight
def forward(self, x, mask=None):
    b, n, _, h, img_size, kernel_size, dilation, seq_len = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len
    softmax = torch.softmax

    img_seq_len = img_size**2
    text_len = seq_len + 1 - img_seq_len

    # padding

    padding = seq_len - n + 1
    mask = default(mask, lambda: torch.ones(b, text_len).bool())

    x = F.pad(x, (0, 0, 0, padding), value=0)
    mask = mask[:, :text_len]

    # derive query / keys / values

    qkv = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), qkv)

    q = q * self.scale

    ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
        lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

    # text attention

    dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
    mask_value = max_neg_value(self.fp16)

    i, j = dots_text.shape[-2:]
    text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
    dots_text.masked_fill_(text_causal_mask, mask_value)

    attn_text = softmax(dots_text, dim=-1)
    out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

    # image attention

    effective_kernel_size = (kernel_size - 1) * dilation + 1
    padding = effective_kernel_size // 2

    k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h=img_size), (k_img, v_img))
    k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation), (k_img, v_img))
    k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j=kernel_size**2), (k_img, v_img))

    # let image attend to all of text

    dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
    dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

    # calculate causal attention for local convolution

    i, j = dots_image.shape[-2:]
    img_seq = torch.arange(img_seq_len)
    k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h=img_size)
    k_img_indices = F.pad(k_img_indices, (padding,) * 4, value=img_seq_len)  # padding set to be max, so it is never attended to
    k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
    k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

    # mask image attention

    q_img_indices = rearrange(img_seq, 'i -> () i ()')
    causal_mask = q_img_indices < k_img_indices

    # concat text mask with image causal mask

    causal_mask = repeat(causal_mask, '() i j -> b i j', b=b * h)
    mask = repeat(mask, 'b j -> (b h) i j', i=i, h=h)
    mask = torch.cat((~mask, causal_mask), dim=-1)

    # image can attend to all of text

    dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
    dots.masked_fill_(mask, mask_value)

    attn = softmax(dots, dim=-1)

    # aggregate

    attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

    out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
    out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)

    out_image = out_image_to_image + out_image_to_text

    # combine attended values for both text and image

    out = torch.cat((out_text, out_image), dim=1)

    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    out = self.to_out(out)
    return out[:, :n]
def forward(
    self,
    text,
    image = None,
    mask = None,
    return_loss = False
):
    assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
    device, total_seq_len = text.device, self.total_seq_len

    # make sure padding in text tokens get unique padding token id

    text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
    text = torch.where(text == 0, text_range, text)

    # add <bos>

    text = F.pad(text, (1, 0), value = 0)

    tokens = self.text_emb(text)
    tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

    seq_len = tokens.shape[1]

    if exists(image) and not is_empty(image):
        is_raw_image = len(image.shape) == 4

        if is_raw_image:
            image_size = self.vae.image_size
            assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

            image = self.vae.get_codebook_indices(image)

        image_len = image.shape[1]
        image_emb = self.image_emb(image)

        image_emb += self.image_pos_emb(image_emb)

        tokens = torch.cat((tokens, image_emb), dim = 1)

        seq_len += image_len

    # when training, if the length exceeds the total text + image length
    # remove the last token, since it needs not to be trained

    if tokens.shape[1] > total_seq_len:
        seq_len -= 1
        tokens = tokens[:, :-1]

    out = self.transformer(tokens)

    if self.stable:
        out = self.norm_by_max(out)

    logits = self.to_logits(out)

    # mask logits to make sure text predicts text (except last token), and image predicts image

    logits_mask = self.logits_mask[:, :seq_len]
    max_neg_value = -torch.finfo(logits.dtype).max
    logits.masked_fill_(logits_mask, max_neg_value)

    if not return_loss:
        return logits

    assert exists(image), 'when training, image must be supplied'

    offsetted_image = image + self.num_text_tokens
    labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

    logits = rearrange(logits, 'b n c -> b c n')

    loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
    loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

    loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
    return loss
def expand_emb(self, r, dim_size):
    # Decompose and unsqueeze dimension
    r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
    expand_index = [-1, -1, -1, dim_size, -1, -1]  # -1 indicates no expansion
    r = r.expand(expand_index)
    return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
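# Hedged shape sketch for expand_emb (hypothetical sizes, not from the original tests):
# r of shape (b, heads * dim_size, i, j) is rearranged to (b, heads, dim_size, 1, i, j),
# broadcast along the singleton axis to dim_size, then merged to (b, heads, dim_size * i, dim_size * j).
# r = torch.randn(2, 4 * 8, 8, 8)
# out = self.expand_emb(r, dim_size=8)   # -> (2, 4, 64, 64)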
def __init__(self,
             *,
             dim,
             vae,
             num_text_tokens=10000,
             text_seq_len=256,
             depth,
             heads=8,
             dim_head=64,
             reversible=False,
             attn_dropout=0.,
             ff_dropout=0):
    super().__init__()
    assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

    num_image_tokens = vae.num_tokens
    image_seq_len = (vae.image_size // (2**vae.num_layers))**2

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.image_emb = nn.Embedding(num_image_tokens, dim)

    self.text_pos_emb = nn.Embedding(text_seq_len, dim)
    self.image_pos_emb = nn.Embedding(image_seq_len, dim)

    self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens

    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens + 1  # extra for EOS
    self.total_tokens = total_tokens

    self.vae = vae
    if exists(self.vae):
        self.vae = vae
        self.image_emb = vae.codebook

    self.transformer = Transformer(dim=dim,
                                   causal=True,
                                   depth=depth,
                                   heads=heads,
                                   dim_head=dim_head,
                                   reversible=reversible,
                                   attn_dropout=attn_dropout,
                                   ff_dropout=ff_dropout)

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (
        ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
        ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
        ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
    )

    self.register_buffer('logits_mask', logits_mask)
def make_imrange(arr: list):
    interpolation = torch.stack(arr)
    imgs = rearrange(make_grid(interpolation, 11), 'c h w -> h w c')
    imgs = imgs.cpu().detach().numpy() if torch.cuda.is_available() else imgs.detach().numpy()
    return imgs
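# Hedged usage sketch (assumes a list of equally sized CxHxW tensors, e.g. an 11-step latent
# interpolation sweep; `decode` is a hypothetical decoder, not part of the original code):
# row = make_imrange([decode(z_a * (1 - t) + z_b * t) for t in torch.linspace(0, 1, 11)])
# plt.imshow(row)   # row is an HxWxC numpy array laid out as a single grid row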
def _beam_search(
    self,
    src: FloatTensor,
    mask: LongTensor,
    direction: str,
    beam_size: int,
    max_len: int,
) -> List[Hypothesis]:
    """run beam search for one direction

    Parameters
    ----------
    src : FloatTensor
        [1, l, d]
    mask: LongTensor
        [1, l]
    direction : str
        one of "l2r" and "r2l"
    beam_size : int
    max_len : int

    Returns
    -------
    List[Hypothesis]
    """
    assert direction in {"l2r", "r2l"}
    assert (
        src.size(0) == 1 and mask.size(0) == 1
    ), f"beam search should only have single source, encounter with batch_size: {src.size(0)}"

    if direction == "l2r":
        start_w = vocab.SOS_IDX
        stop_w = vocab.EOS_IDX
    else:
        start_w = vocab.EOS_IDX
        stop_w = vocab.SOS_IDX

    hypotheses = torch.full(
        (1, max_len + 1),
        fill_value=vocab.PAD_IDX,
        dtype=torch.long,
        device=self.device,
    )
    hypotheses[:, 0] = start_w

    hyp_scores = torch.zeros(1, dtype=torch.float, device=self.device)
    completed_hypotheses: List[Hypothesis] = []

    t = 0
    while len(completed_hypotheses) < beam_size and t < max_len:
        hyp_num = hypotheses.size(0)
        assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}"

        exp_src = repeat(src.squeeze(0), "s e -> b s e", b=hyp_num)
        exp_mask = repeat(mask.squeeze(0), "s -> b s", b=hyp_num)

        decode_outputs = self(exp_src, exp_mask, hypotheses)[:, t, :]
        log_p_t = F.log_softmax(decode_outputs, dim=-1)

        live_hyp_num = beam_size - len(completed_hypotheses)
        exp_hyp_scores = repeat(hyp_scores, "b -> b e", e=vocab_size)
        continuous_hyp_scores = rearrange(exp_hyp_scores + log_p_t, "b e -> (b e)")
        top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuous_hyp_scores, k=live_hyp_num)

        prev_hyp_ids = top_cand_hyp_pos // vocab_size
        hyp_word_ids = top_cand_hyp_pos % vocab_size

        t += 1
        new_hypotheses = []
        new_hyp_scores = []

        for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
                prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
            cand_new_hyp_score = cand_new_hyp_score.detach().item()
            hypotheses[prev_hyp_id, t] = hyp_word_id
            if hyp_word_id == stop_w:
                completed_hypotheses.append(
                    Hypothesis(
                        seq_tensor=hypotheses[prev_hyp_id, 1:t].detach().clone(),  # remove START_W at first
                        score=cand_new_hyp_score,
                        direction=direction,
                    ))
            else:
                new_hypotheses.append(hypotheses[prev_hyp_id].detach().clone())
                new_hyp_scores.append(cand_new_hyp_score)

        if len(completed_hypotheses) == beam_size:
            break

        hypotheses = torch.stack(new_hypotheses, dim=0)
        hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)

    if len(completed_hypotheses) == 0:
        completed_hypotheses.append(
            Hypothesis(
                seq_tensor=hypotheses[0, 1:].detach().clone(),
                score=hyp_scores[0].detach().item(),
                direction=direction,
            ))

    return completed_hypotheses
def evaluate(checkpoint_path, eval_dir, database_json):
    model = SimpleMaskEstimator(513)
    model.load_checkpoint(
        checkpoint_path=checkpoint_path,
        in_checkpoint_path='model',
        consider_mpi=True
    )
    model.eval()
    if dlp_mpi.IS_MASTER:
        print(f'Start to evaluate the checkpoint {checkpoint_path.resolve()} '
              f'and will write the evaluation result to'
              f' {eval_dir / "result.json"}')

    database = JsonDatabase(database_json)
    test_dataset = get_test_dataset(database)

    with torch.no_grad():
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for batch in dlp_mpi.split_managed(
                test_dataset,
                is_indexable=True,
                progress_bar=True,
                allow_single_worker=True
        ):
            model_output = model(pt.data.example_to_device(batch))

            example_id = batch['example_id']
            s = batch['speech_source'][0][None]

            speech_mask = model_output['speech_mask_prediction'].numpy()
            Y = batch['observation_stft']
            Z_mask = speech_mask[0] * Y[0]
            z_mask = pb.transform.istft(Z_mask)[None]

            speech_mask = np.median(speech_mask, axis=0).T
            noise_mask = model_output['noise_mask_prediction'].numpy()
            noise_mask = np.median(noise_mask, axis=0).T
            Y = rearrange(Y, 'c t f -> f c t')
            target_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, speech_mask,
            )
            noise_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, noise_mask,
            )
            beamformer = pb_bss.extraction.get_bf_vector(
                'mvdr_souden',
                target_psd_matrix=target_psd,
                noise_psd_matrix=noise_psd
            )
            Z_bf = pb_bss.extraction.apply_beamforming_vector(beamformer, Y).T
            z_bf = pb.transform.istft(Z_bf)[None]

            y = batch['observation'][0][None]
            s = s[:, :z_bf.shape[1]]
            for key, signal in zip(summary.keys(), [z_mask, z_bf, y]):
                signal = signal[:, :s.shape[1]]
                entry = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=signal,
                    speech_source=s,
                    sample_rate=16000
                ).as_dict()
                entry.pop('mir_eval_selection')
                summary[key][example_id] = entry

    summary_list = dlp_mpi.COMM.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'\n len(summary_list): {len(summary_list)}')
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for partial_summary in summary_list:
            for signal_type, metric in partial_summary.items():
                summary[signal_type].update(metric)
        for signal_type, values in summary.items():
            print(signal_type)
            for metric in next(iter(values.values())).keys():
                mean = np.mean([value[metric] for key, value in values.items() if '_mean' not in key])
                values[metric + '_mean'] = mean
                print(f'{metric}: {mean}')
        result_json_path = eval_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
def train(train_loader, valid_loader, model, optim, criterion, num_epochs):
    print_every = 5
    model.train()

    lowest_val = 1e9
    train_losses = []
    val_losses = []
    total_step = 0
    print("-" * 100)
    print("Starting Training")
    print("-" * 100)
    for epoch in range(num_epochs):
        pbar = tqdm(total=print_every, leave=False)
        total_loss = 0

        for step, data_dict in enumerate(iter(train_loader)):
            total_step += 1
            src, tgt, src_key_padding_mask, tgt_key_padding_mask = (
                data_dict["ids1"],
                data_dict["ids2"],
                data_dict["masks_sent1"],
                data_dict["masks_sent2"],
            )
            src, src_key_padding_mask = src.to("cpu"), src_key_padding_mask.to("cpu")
            tgt, tgt_key_padding_mask = tgt.to("cpu"), tgt_key_padding_mask.to("cpu")
            memory_key_padding_mask = src_key_padding_mask.clone()
            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to("cpu")

            optim.zero_grad()
            outputs = model(
                src,
                tgt_inp,
                src_key_padding_mask,
                tgt_key_padding_mask[:, :-1],
                memory_key_padding_mask,
                tgt_mask,
            )
            loss = criterion(
                rearrange(outputs, "b t v -> (b t) v"),
                rearrange(tgt_out, "b o -> (b o)"),
            )

            loss.backward()
            optim.step_and_update_lr()
            total_loss += loss.item()
            train_losses.append((step, loss.item()))

            pbar.update(1)
            if step % print_every == print_every - 1:
                pbar.close()
                print(
                    f"Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(train_loader)}] \t "
                    f"Train Loss: {total_loss / print_every}")
                total_loss = 0
                pbar = tqdm(total=print_every, leave=False)

        pbar.close()
        val_loss = validate(valid_loader, model, criterion)
        val_losses.append((total_step, val_loss))
        if val_loss < lowest_val:
            lowest_val = val_loss
            torch.save(model, "output/transformer.pth")
        print(f"Val Loss: {val_loss}")
    return train_losses, val_losses