def train(self):
        self.model.train()

        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        meter_loss = AverageMeter('Loss', ':.4e')
        meter_loss_constr = AverageMeter('Constr', ':6.2f')
        meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
        progress = ProgressMeter(
            self.training_loader.epoch_size()['__Video_0'], [
                batch_time, data_time, meter_loss, meter_loss_constr,
                meter_loss_perp
            ],
            prefix="Steps: [{}]".format(self.num_steps))

        data_iter = DALIGenericIterator(self.training_loader, ['data'],
                                        auto_reset=True)
        end = time.time()

        for i in range(self.start_steps, self.num_steps):
            # measure data loading time
            data_time.update(time.time() - end)

            try:
                images = next(data_iter)[0]['data']
            except StopIteration:
                data_iter.reset()
                images = next(data_iter)[0]['data']

            images = images.to('cuda')
            b, d, _, _, c = images.size()
            images = rearrange(images, 'b d h w c -> (b d) c h w')
            images = self.normalize(images.float() / 255.)
            # images = rearrange(images, '(b d) c h w -> b (d c) h w', b=b, d=d, c=c)
            self.optimizer.zero_grad()

            vq_loss, images_recon, perplexity = self.model(images)
            recon_error = F.mse_loss(images_recon, images)
            loss = recon_error + vq_loss
            loss.backward()

            self.optimizer.step()

            meter_loss_constr.update(recon_error.item(), 1)
            meter_loss_perp.update(perplexity.item(), 1)
            meter_loss.update(loss.item(), 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

            if i % 1000 == 0:
                print('saving ...')
                save_checkpoint(
                    self.folder_name, {
                        'steps': i,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'scheduler': self.scheduler.state_dict()
                    }, 'checkpoint%s.pth.tar' % i)

                self.scheduler.step()
                images, images_recon = map(
                    lambda t: rearrange(
                        t, '(b d) c h w -> b d c h w', b=b, d=d, c=c),
                    [images, images_recon])
                images_orig, images_recs = train_visualize(
                    unnormalize=self.unnormalize,
                    images=images[0, :self.n_images_save],
                    n_images=self.n_images_save,
                    image_recs=images_recon[0, :self.n_images_save])

                save_images(file_name=os.path.join(self.path_img_orig,
                                                   f'image_{i}.png'),
                            image=images_orig)
                save_images(file_name=os.path.join(self.path_img_recs,
                                                   f'image_{i}.png'),
                            image=images_recs)

                if self.run_wandb:
                    logs = {
                        'iter': i,
                        'loss_recs': meter_loss_constr.val,
                        'loss': meter_loss.val,
                        'lr': self.scheduler.get_last_lr()[0]
                    }
                    self.run_wandb.log(logs)

        print('saving ...')
        save_checkpoint(
            self.folder_name, {
                'steps': self.num_steps,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }, 'checkpoint%s.pth.tar' % self.num_steps)
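AverageMeter and ProgressMeter are not defined in the snippet above; the update()/.val usage matches the meter helpers from the PyTorch ImageNet reference example. A minimal AverageMeter sketch under that assumption (not part of the original code):

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # n is the number of samples the reported value was averaged over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)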
Example #2
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop

for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seq, coords, mask = batch.seqs, batch.crds, batch.msks

        b, l, _ = seq.shape

        # prepare mask, labels

        seq, coords, mask = seq.argmax(dim = -1).to(DEVICE), coords.to(DEVICE), mask.to(DEVICE).bool()
        coords = rearrange(coords, 'b (l c) d -> b l c d', l = l)

        discretized_distances = get_bucketed_distance_matrix(coords[:, :, 0], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX)

        # predict

        distogram = model(seq, mask = mask)
        distogram = rearrange(distogram, 'b i j c -> b c i j')

        # loss

        loss = F.cross_entropy(
            distogram,
            discretized_distances,
            ignore_index = IGNORE_INDEX
        )
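The snippet stops at the loss. A hedged sketch of how a gradient-accumulation loop like this one is usually closed (the names follow the code above; these lines are not part of the original example):

        # still inside the accumulation loop: scale so the accumulated
        # gradient matches one large batch
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # once per outer iteration: apply and clear the accumulated gradients
    optim.step()
    optim.zero_grad()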
Example #3
    def __init__(self,
                 path: str,
                 fields: Tuple[RawField, RawField, Field, Field, Field,
                               RawField],
                 max_samples: int = None,
                 use_mask: bool = True,
                 use_ground_truth=False,
                 **kwargs):
        print(f"Using ground truth: {use_ground_truth}", flush=True)
        """Create a SignTranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [("sequence", fields[0]), ("signer", fields[1]),
                      ("sgn", fields[2]), ("gls", fields[3]),
                      ("txt", fields[4]), ("use_mask", fields[5])]

        _t = time.time()
        samples = {}
        sample_count = 0
        longest_sequence = 0
        acc = 0
        frames = 0
        for annotation_file, annotation_type in path:
            tmp = load_dataset_file(annotation_file)
            for s in tmp:
                unlistify(s)
                if max_samples is not None and sample_count >= max_samples:  # For debugging to make a smaller dataset
                    break
                seq_id = s["name"]
                # Decompress and augment
                #if use_ground_truth: # Make one hot encoding TODO replace
                if use_ground_truth or annotation_type == "logits_max":
                    feature_length = fields[2][-1].pad_token.shape[0]
                    if isinstance(s['alignments']['pami0'], list):
                        s['alignments']['pami0'] = s['alignments']['pami0'][0]
                    labels = [
                        int(item) if feature_length > 1500 else
                        ((int(item) + 2) // 3)
                        for item in s['alignments']['pami0'].split(" ")
                    ]
                    chunk_length = 16
                    collapsed_labels = []
                    for start in range((len(labels) - chunk_length) + 1):
                        label = int(labels[start + (chunk_length // 2)])
                        collapsed_labels.append(label)
                    #collapsed_labels = [labels[0]]
                    #for label, prev_label in zip(labels[1:], labels[:-1]):
                    #    if label != prev_label:
                    #        collapsed_labels.append(labels)
                    labels = collapsed_labels
                    _feat = torch.zeros((len(labels), feature_length),
                                        dtype=torch.float32)
                    for index, label in enumerate(labels):
                        _feat[index, label] = 1
                if use_ground_truth:
                    _features = _feat
                else:
                    try:
                        features = pickle.loads(lzma.decompress(s['sign']))
                    except TypeError:
                        features = s['sign'].numpy()

                    _features = load_augment(annotation_type, features)
                    if len(_features.shape) == 3:
                        _features = rearrange(_features, "bs t f -> (bs t) f")
                    if annotation_type == "logits_max":
                        frames += len(labels)
                        assert _features.shape == _feat.shape, f"(logit) {_features.shape} =/= {_feat.shape} (gt)"
                        acc += (_features * _feat).sum()

                s["sign"] = _features
                longest_sequence = max(_features.shape[0], longest_sequence)
                if seq_id in samples:
                    if samples[seq_id]["sign"].shape[0] > s["sign"].shape[
                            0]:  # If there are less symbols then pad with zero filled
                        feature_size = s["sign"].shape[1]
                        difference = samples[seq_id]["sign"].shape[0] - s[
                            "sign"].shape[0]
                        s["sign"] = torch.cat([
                            torch.zeros(math.ceil(difference / 2),
                                        feature_size), s["sign"],
                            torch.zeros(math.floor(difference / 2),
                                        feature_size)
                        ],
                                              axis=0)

                    samples[seq_id]["loaded"] += 1
                    assert samples[seq_id]["name"] == s["name"]
                    assert samples[seq_id]["signer"] == s["signer"]
                    assert samples[seq_id]["gloss"] == s["gloss"]
                    assert samples[seq_id]["text"] == s["text"]
                    samples[seq_id]["sign"] = torch.cat(
                        [samples[seq_id]["sign"], s["sign"]], axis=1)
                else:
                    samples[seq_id] = {
                        "name": s["name"],
                        "signer": s["signer"],
                        "gloss": s["gloss"],
                        "text": s["text"],
                        "sign": s["sign"],
                        "loaded": 1,
                    }
                sample_count += 1
            if annotation_type == "logits_max":
                print(f"GT acc: {acc / frames}")
        print(f"Longest sequence: {longest_sequence}")
        print(f"Done loading samples in {time.time()-_t}s")
        examples = []
        for s in samples:
            sample = samples[s]
            if sample['loaded'] == len(path):
                examples.append(
                    data.Example.fromlist(
                        [
                            sample["name"],
                            sample["signer"],
                            # This is for numerical stability
                            sample["sign"] + 1e-8,
                            str(sample["gloss"]).strip(),
                            str(sample["text"]).strip(),
                            use_mask,
                        ],
                        fields,
                    ))
            else:
                print(
                    f"{s} only loaded {sample['loaded']} annotations so has been removed"
                )
        super().__init__(examples, fields, **kwargs)
Example #4
    def forward(self,
                sequence: torch.Tensor,
                sequence_lengths: Optional[torch.Tensor] = None,
                may_deactivate_seq: bool = True) -> torch.Tensor:
        """

        Args:
            sequence (B, N, K, S): Chunked input sequence
            sequence_lengths (B): Sequence lengths along segment dimension (S)
            may_deactivate_seq: If set to `True`, the handling of sequence
                lengths is disabled when all examples in the batch have the
                same length

        """
        # The handling of sequence lengths can be disabled if all examples in a
        # batch have the same length and this length matches the size of the
        # time axis of the input sequence (i.e., the signal is not 0-padded)
        # This speeds up the computations
        if may_deactivate_seq and sequence_lengths is not None and (
                len(sequence_lengths) == 1
                or all(sequence_lengths[1:] == sequence_lengths[:-1])
        ) and sequence_lengths[0] == sequence.shape[-1]:
            sequence_lengths = None

        B, N, K, S = sequence.shape

        # The LSTM only supports 3-dim input. Reshape according to the configured pattern
        lstm_in = rearrange(sequence, f'b n k s -> {self.lstm_reshape_to}')

        # Call lstm
        if sequence_lengths is not None:
            # TODO: don't hardcode this
            if 's' in self.lstm_reshape_to[:4]:
                packed = pack(rearrange(sequence, 'b n k s -> b s k n'),
                              sequence_lengths)
            else:
                assert self.lstm_reshape_to[1] == 'b'
                packed_sequence_lengths = rearrange(
                    sequence_lengths.reshape(B, 1, 1, 1).expand(B, 1, K, 1),
                    f'b n k s -> {self.lstm_reshape_to}').squeeze()
                packed = pack_padded_sequence(lstm_in,
                                              packed_sequence_lengths,
                                              batch_first=True)
        else:
            packed_sequence_lengths = None
            packed = lstm_in

        out = self.rnn(packed)
        if isinstance(out, tuple):
            out = out[0]

        if sequence_lengths is not None and 's' not in self.lstm_reshape_to[:4]:
            out, _ = pad_packed_sequence(out, batch_first=True, total_length=S)

        # FC projection layer
        out = self.fc(out)

        # Apply norm and rearrange back to BxNxKxS
        if sequence_lengths is not None and 's' in self.lstm_reshape_to[:4]:
            out = self.norm(out)
            out = rearrange(unpack(out, sequence_lengths),
                            'b s k n -> b n k s')
        else:
            out = apply_examplewise(self.norm, out, packed_sequence_lengths)
            out = rearrange(out,
                            f'{self.lstm_reshape_to} -> b n k s',
                            b=B,
                            s=S,
                            n=self.feat_size,
                            k=K)

        # Residual connection
        out = out + sequence

        return out
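The pack/unpack aliases above are assumed to wrap torch.nn.utils.rnn; pack_padded_sequence and pad_packed_sequence are also used directly. A self-contained sketch of that round trip, so the sequence-length handling is easier to follow:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
x = torch.randn(3, 10, 8)                      # (batch, time, feat), zero-padded
lengths = torch.tensor([10, 7, 4])             # valid frames per example

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
out, _ = lstm(packed)                          # padded frames are skipped
out, out_lengths = pad_packed_sequence(out, batch_first=True, total_length=10)
assert out.shape == (3, 10, 8) and out_lengths.tolist() == [10, 7, 4]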
Example #5
    def forward(self, x, mask=None):
        b, n, _, h, img_size, axis, seq_len = *x.shape, self.heads, self.image_size, self.axis, self.seq_len
        softmax = torch.softmax

        img_seq_len = img_size**2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len).bool())

        x = F.pad(x, (0, 0, 0, padding), value=0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      qkv)

        q = q * self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
            lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(self.fp16)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim=-1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # split out axis

        q_img, k_img, v_img = map(
            lambda t: rearrange(t, split_axis_einops, h=img_size),
            (q_img, k_img, v_img))

        # similarity

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img,
                                     k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim=-1)

        # mask so image has full attention to text, but causal along axis

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b=bh, x=x)

        mask = repeat(mask, 'b j -> (b h) x i j', h=h, x=x, i=i)
        mask = torch.cat((~mask, causal_mask), dim=-1)

        dots.masked_fill_(mask, mask_value)

        # attention.

        attn = softmax(dots, dim=-1)

        # aggregate

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[
            ..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d',
                                    attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d',
                                   attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # merge back axis

        out_image = rearrange(out_image, merge_axis_einops, x=img_size)

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim=1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        out = self.to_out(out)
        return out[:, :n]
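A toy shape check of the two axis-split patterns used above (the sizes are made up): with axis 0 the image tokens are regrouped by row, with axis 1 by column, and attention then runs along the resulting inner dimension.

import torch
from einops import rearrange

img_size, d = 3, 4
img_tokens = torch.randn(2, img_size ** 2, d)                       # (b, h*w, d)
by_row = rearrange(img_tokens, 'b (h w) c -> b h w c', h=img_size)  # axis == 0
by_col = rearrange(img_tokens, 'b (h w) c -> b w h c', h=img_size)  # axis == 1
assert by_row.shape == by_col.shape == (2, img_size, img_size, d)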
Example #6
    def split_bl_embeddings(self, embeddings):
        R, L = self.img_size**2, self.nh_size**2
        embeddings = rearrange(embeddings, '(b l) f -> b l f', l=L)
        return embeddings
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicator vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient so scorer's weights
        won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keyword arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A representation vector for each image in the batch.
    """
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the indices
        # of the selected patches, other methods return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if it is not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
Example #8
    def get_codebook_indices(self, img):
        b = img.shape[0]
        img = (2 * img) - 1
        _, _, [_, _, indices] = self.model.encode(img)
        return rearrange(indices, '(b n) () -> b n', b=b)
Example #9
    def forward(self, x):
        shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

        rand_flip_fn = lambda t: torch.flip(t, dims=(-1, ))

        flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(
            prob_flip)
        flip_image_one_fn = rand_flip_fn if flip_image_one else identity
        flip_image_two_fn = rand_flip_fn if flip_image_two else identity

        cutout_coordinates_one, _ = cutout_coordinates(x,
                                                       self.cutout_ratio_range)
        cutout_coordinates_two, _ = cutout_coordinates(x,
                                                       self.cutout_ratio_range)

        image_one_cutout = cutout_and_resize(x,
                                             cutout_coordinates_one,
                                             mode=self.cutout_interpolate_mode)
        image_two_cutout = cutout_and_resize(x,
                                             cutout_coordinates_two,
                                             mode=self.cutout_interpolate_mode)

        image_one_cutout = flip_image_one_fn(image_one_cutout)
        image_two_cutout = flip_image_two_fn(image_two_cutout)

        image_one_cutout, image_two_cutout = self.augment1(
            image_one_cutout), self.augment2(image_two_cutout)

        proj_pixel_one, proj_instance_one = self.online_encoder(
            image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(
            image_two_cutout)

        image_h, image_w = shape[2:]

        proj_image_shape = proj_pixel_one.shape[2:]
        proj_image_h, proj_image_w = proj_image_shape

        coordinates = torch.meshgrid(torch.arange(image_h, device=device),
                                     torch.arange(image_w, device=device))

        coordinates = torch.stack(coordinates).unsqueeze(0).float()
        coordinates /= math.sqrt(image_h**2 + image_w**2)
        coordinates[:, 0] *= proj_image_h
        coordinates[:, 1] *= proj_image_w

        proj_coors_one = cutout_and_resize(
            coordinates,
            cutout_coordinates_one,
            output_size=proj_image_shape,
            mode=self.coord_cutout_interpolate_mode)
        proj_coors_two = cutout_and_resize(
            coordinates,
            cutout_coordinates_two,
            output_size=proj_image_shape,
            mode=self.coord_cutout_interpolate_mode)

        proj_coors_one = flip_image_one_fn(proj_coors_one)
        proj_coors_two = flip_image_two_fn(proj_coors_two)

        proj_coors_one, proj_coors_two = map(
            lambda t: rearrange(t, 'b c h w -> (b h w) c'),
            (proj_coors_one, proj_coors_two))
        pdist = nn.PairwiseDistance(p=2)

        num_pixels = proj_coors_one.shape[0]

        proj_coors_one_expanded = proj_coors_one[:, None].expand(
            num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
        proj_coors_two_expanded = proj_coors_two[None, :].expand(
            num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)

        distance_matrix = pdist(proj_coors_one_expanded,
                                proj_coors_two_expanded)
        distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)

        positive_mask_one_two = distance_matrix < self.distance_thres
        positive_mask_two_one = positive_mask_one_two.t()

        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(
                image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(
                image_two_cutout)

        # flatten all the pixel projections

        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')

        target_proj_pixel_one, target_proj_pixel_two = list(
            map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

        # get total number of positive pixel pairs

        positive_pixel_pairs = positive_mask_one_two.sum()

        # get instance level loss

        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)

        loss_instance_one = loss_fn(pred_instance_one,
                                    target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two,
                                    target_proj_instance_one.detach())

        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if positive_pixel_pairs == 0:
            return instance_loss, 0

        if not self.use_pixpro:
            # calculate pix contrast loss

            proj_pixel_one, proj_pixel_two = list(
                map(flatten, (proj_pixel_one, proj_pixel_two)))

            similarity_one_two = F.cosine_similarity(
                proj_pixel_one[..., :, None],
                target_proj_pixel_two[..., None, :],
                dim=1) / self.similarity_temperature
            similarity_two_one = F.cosine_similarity(
                proj_pixel_two[..., :, None],
                target_proj_pixel_one[..., None, :],
                dim=1) / self.similarity_temperature

            loss_pix_one_two = -torch.log(
                similarity_one_two.masked_select(
                    positive_mask_one_two[None, ...]).exp().sum() /
                similarity_one_two.exp().sum())

            loss_pix_two_one = -torch.log(
                similarity_two_one.masked_select(
                    positive_mask_two_one[None, ...]).exp().sum() /
                similarity_two_one.exp().sum())

            pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
        else:
            # calculate pix pro loss

            propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
            propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

            propagated_pixels_one, propagated_pixels_two = list(
                map(flatten, (propagated_pixels_one, propagated_pixels_two)))

            propagated_similarity_one_two = F.cosine_similarity(
                propagated_pixels_one[..., :, None],
                target_proj_pixel_two[..., None, :],
                dim=1)
            propagated_similarity_two_one = F.cosine_similarity(
                propagated_pixels_two[..., :, None],
                target_proj_pixel_one[..., None, :],
                dim=1)

            loss_pixpro_one_two = -propagated_similarity_one_two.masked_select(
                positive_mask_one_two[None, ...]).mean()
            loss_pixpro_two_one = -propagated_similarity_two_one.masked_select(
                positive_mask_two_one[None, ...]).mean()

            pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

        # total loss

        loss = pix_loss * self.alpha + instance_loss
        return loss, positive_pixel_pairs
Example #10
def wiener_filter_predict(observation, desired, filter_order, return_w=False):
    """
    Also known as projection of observation to desired
    (mir_eval.separation._project)

    w = argmin_w ( sum( |x * w - d|^2 ) )
    return x * w

    >>> from paderbox.utils.pretty import pprint
    >>> x = np.array([1, 2, 3, 4, 5])
    >>> y = np.array([1, 2, 1, 2, 1])
    >>> from mir_eval.separation import _project
    >>> _project(x[None], y, 2)
    array([ 0.41754386,  0.78596491,  1.15438596,  1.52280702,  1.89122807,
           -0.24561404])
    >>> wiener_filter_predict(x, y, 2)
    array([ 0.41754386,  0.78596491,  1.15438596,  1.52280702,  1.89122807,
           -0.24561404])
    >>> wiener_filter_predict(np.array([x]), y, 2)
    array([ 0.41754386,  0.78596491,  1.15438596,  1.52280702,  1.89122807,
           -0.24561404])
    >>> wiener_filter_predict(np.array([x, -x]), y, 2)
    array([ 0.41754386,  0.78596491,  1.15438596,  1.52280702,  1.89122807,
           -0.24561404])
    >>> _project(np.array([x, -x]), y, 2)
    array([ 0.41754386,  0.78596491,  1.15438596,  1.52280702,  1.89122807,
           -0.24561404])
    >>> pprint(wiener_filter_predict(np.array([x, y]), y, 2))
    array([1., 2., 1., 2., 1., 0.])
    >>> pprint(_project(np.array([x, y]), y, 2))
    array([1., 2., 1., 2., 1., 0.])

    """
    n_fft = int(2**np.ceil(
        np.log2(observation.shape[-1] + desired.shape[-1] - 1.)))

    if observation.ndim == 1:
        observation = observation[None, :]

    Observation = np.fft.rfft(observation, n=n_fft, axis=-1)
    Desired = np.fft.rfft(desired, n=n_fft, axis=-1)

    Autocorr = np.einsum('KT,kT->KkT', Observation.conj(), Observation)
    Crosscorr = np.einsum('KT,T->KT', Observation.conj(), Desired)

    autocorr = np.fft.irfft(Autocorr)
    crosscorr = np.fft.irfft(Crosscorr)

    R = np.array([[scipy.linalg.toeplitz(a[:filter_order]) for a in aa]
                  for aa in autocorr])
    R = einops.rearrange(
        R,
        'source1 source2 filter1 filter2 -> (source1 filter1) (source2 filter2)'
    )

    p = crosscorr[..., :filter_order]
    p = einops.rearrange(p, 'source filter -> (source filter)')

    from paderbox.math.solve import stable_solve

    w = np.squeeze(stable_solve(R, p[..., None]), axis=-1)
    w = einops.rearrange(w,
                         '(source filter) -> source filter',
                         filter=filter_order)

    if return_w:
        return w
    else:
        return np.sum([
            scipy.signal.fftconvolve(o, filter, axes=(-1))
            for o, filter in zip(observation, w)
        ],
                      axis=0)
Example #11
    def get_codebook_indices(self, img):
        img = map_pixels(img)
        z_logits = self.enc.blocks(img)
        z = torch.argmax(z_logits, dim=1)
        return rearrange(z, 'b h w -> b (h w)')
Example #12
    def forward(self, x):
        shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

        rand_flip_fn = lambda t: torch.flip(t, dims=(-1, ))

        flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(
            prob_flip)
        flip_image_one_fn = rand_flip_fn if flip_image_one else identity
        flip_image_two_fn = rand_flip_fn if flip_image_two else identity

        cutout_coordinates_one, _ = cutout_coordinates(x,
                                                       self.cutout_ratio_range)
        cutout_coordinates_two, _ = cutout_coordinates(x,
                                                       self.cutout_ratio_range)

        image_one_cutout = cutout_and_resize(x,
                                             cutout_coordinates_one,
                                             mode=self.cutout_interpolate_mode)
        image_two_cutout = cutout_and_resize(x,
                                             cutout_coordinates_two,
                                             mode=self.cutout_interpolate_mode)

        image_one_cutout = flip_image_one_fn(image_one_cutout)
        image_two_cutout = flip_image_two_fn(image_two_cutout)

        image_one_cutout, image_two_cutout = self.augment1(
            image_one_cutout), self.augment2(image_two_cutout)

        self.aug1 = image_one_cutout.detach().clone()
        self.aug2 = image_two_cutout.detach().clone()

        proj_pixel_one, proj_instance_one = self.online_encoder(
            image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(
            image_two_cutout)

        proj_pixel_one, proj_pixel_two = get_shared_region(
            proj_pixel_one, proj_pixel_two, cutout_coordinates_one,
            cutout_coordinates_two, flip_image_one_fn, flip_image_two_fn,
            image_one_cutout.shape, self.cutout_interpolate_mode)
        if proj_pixel_one is None or proj_pixel_two is None:
            positive_pixel_pairs = 0
        else:
            positive_pixel_pairs = proj_pixel_one.shape[
                -1] * proj_pixel_one.shape[-2]

        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(
                image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(
                image_two_cutout)
            target_proj_pixel_one, target_proj_pixel_two = get_shared_region(
                target_proj_pixel_one, target_proj_pixel_two,
                cutout_coordinates_one, cutout_coordinates_two,
                flip_image_one_fn, flip_image_two_fn, image_one_cutout.shape,
                self.cutout_interpolate_mode)

        # If max_latent_dim is specified, stochastically extract latents from the shared areas.
        b, c, pp_h, pp_w = proj_pixel_one.shape
        if self.max_latent_dim and (pp_h * pp_w) > self.max_latent_dim:
            # uniform distribution over every position in the shared region
            prob = torch.full((pp_h * pp_w, ), 1 / (pp_h * pp_w))
            latents = [
                proj_pixel_one, proj_pixel_two, target_proj_pixel_one,
                target_proj_pixel_two
            ]
            extracted = []
            for l in latents:
                l = l.reshape(b, c, pp_h * pp_w)
                l = l[:, :,
                      prob.multinomial(num_samples=self.max_latent_dim,
                                       replacement=False)]
                # For compatibility with the existing pixpro code, reshape this stochastic sampling back into a 2d "square".
                #  Note that the actual structure no longer matters going forward. Pixels are only compared to themselves and others without regard
                #  to the original image structure.
                sqdim = int(math.sqrt(self.max_latent_dim))
                extracted.append(l.reshape(b, c, sqdim, sqdim))
            proj_pixel_one, proj_pixel_two, target_proj_pixel_one, target_proj_pixel_two = extracted

        # flatten all the pixel projections
        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
        target_proj_pixel_one, target_proj_pixel_two = list(
            map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

        # get instance level loss
        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)
        loss_instance_one = loss_fn(pred_instance_one,
                                    target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two,
                                    target_proj_instance_one.detach())
        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if positive_pixel_pairs == 0:
            return instance_loss, 0

        # calculate pix pro loss
        propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
        propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

        propagated_pixels_one, propagated_pixels_two = list(
            map(flatten, (propagated_pixels_one, propagated_pixels_two)))

        propagated_similarity_one_two = F.cosine_similarity(
            propagated_pixels_one[..., :, None],
            target_proj_pixel_two[..., None, :],
            dim=1)
        propagated_similarity_two_one = F.cosine_similarity(
            propagated_pixels_two[..., :, None],
            target_proj_pixel_one[..., None, :],
            dim=1)

        loss_pixpro_one_two = -propagated_similarity_one_two.mean()
        loss_pixpro_two_one = -propagated_similarity_two_one.mean()

        pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

        return instance_loss, pix_loss, positive_pixel_pairs
Example #13
    def fwd_classification(self, x, seq_len=None):
        x = rearrange(x, 'b f t -> b t f')
        x = self._rnn_fwd(x, seq_len)
        x = rearrange(x, 'b t f -> b f t')
        y, seq_len_y = self._clf_fwd(x, seq_len)
        return nn.Sigmoid()(y), seq_len_y
    # NOTE: the start of this snippet is not shown; `dataset` is an assumed name for the first argument
    dataloader = DataLoader(dataset,
                            shuffle=True,
                            num_workers=4)

    # training
    min_running_loss = np.inf
    for epoch in range(EPOCHS):
        running_loss = 0.0

        for i, batch in tqdm(enumerate(dataloader)):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass
            inputs = batch['image'].to(device)
            with torch.no_grad():
                targets = rearrange(resnet18(inputs),
                                    'b vec h w -> b (vec h w)')  # h=w=1
                #targets = torch.squeeze(resnet18(inputs))
            outputs = teacher(inputs)
            loss = distillation_loss(outputs,
                                     targets) + compactness_loss(outputs)

            # backward pass
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # print stats
        print(f"Epoch {epoch+1}, iter {i+1} \t loss: {running_loss}")

        if running_loss < min_running_loss:
            print(f"Loss decreased: {min_running_loss} -> {running_loss}.")
Example #15
    def forward(self, feats, coors, edges=None, mask=None, adj_mat=None):
        b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, self.num_nearest_neighbors, self.valid_radius, self.only_sparse_neighbors

        if exists(mask):
            num_nodes = mask.sum(dim=-1)

        use_nearest = num_nearest > 0 or only_sparse_neighbors

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(
            coors, 'b j d -> b () j d')
        rel_dist = (rel_coors**2).sum(dim=-1, keepdim=True)

        i = j = n

        if use_nearest:
            ranking = rel_dist[..., 0]

            if exists(mask):
                rank_mask = mask[:, :, None] * mask[:, None, :]
                ranking.masked_fill_(~rank_mask, 1e5)

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

                if only_sparse_neighbors:
                    num_nearest = int(adj_mat.float().sum(dim=-1).max().item())
                    valid_radius = 0

                self_mask = rearrange(
                    torch.eye(n, device=device, dtype=torch.bool),
                    'i j -> () i j')

                adj_mat = adj_mat.masked_fill(self_mask, False)
                ranking.masked_fill_(self_mask, -1.)
                ranking.masked_fill_(adj_mat, 0.)

            nbhd_ranking, nbhd_indices = ranking.topk(num_nearest,
                                                      dim=-1,
                                                      largest=False)

            nbhd_mask = nbhd_ranking <= valid_radius

            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim=2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices, dim=2)

            if exists(edges):
                edges = batched_index_select(edges, nbhd_indices, dim=2)

            j = num_nearest

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist,
                                           num_encodings=fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        if use_nearest:
            feats_j = batched_index_select(feats, nbhd_indices, dim=1)
        else:
            feats_j = rearrange(feats, 'b j d -> b () j d')

        feats_i = rearrange(feats, 'b i d -> b i () d')
        feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim=-1)

        if exists(edges):
            edge_input = torch.cat((edge_input, edges), dim=-1)

        m_ij = self.edge_mlp(edge_input)

        if exists(mask):
            mask_i = rearrange(mask, 'b i -> b i ()')

            if use_nearest:
                mask_j = batched_index_select(mask, nbhd_indices, dim=1)
                mask = (mask_i * mask_j) & nbhd_mask
            else:
                mask_j = rearrange(mask, 'b j -> b () j')
                mask = mask_i * mask_j

        if exists(self.coors_mlp):
            coor_weights = self.coors_mlp(m_ij)
            coor_weights = rearrange(coor_weights, 'b i j () -> b i j')

            rel_coors = self.coors_norm(rel_coors)

            if exists(mask):
                coor_weights.masked_fill_(~mask, 0.)

            coors_out = einsum('b i j, b i j c -> b i c', coor_weights,
                               rel_coors) + coors
        else:
            coors_out = coors

        if exists(self.node_mlp):
            if exists(mask):
                m_ij_mask = rearrange(mask, '... -> ... ()')
                m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

            if self.m_pool_method == 'mean':
                if exists(mask):
                    # masked mean
                    mask_sum = m_ij_mask.sum(dim=-2)
                    m_i = safe_div(m_ij.sum(dim=-2), mask_sum)
                else:
                    m_i = m_ij.mean(dim=-2)

            elif self.m_pool_method == 'sum':
                m_i = m_ij.sum(dim=-2)

            normed_feats = self.node_norm(feats)
            node_mlp_input = torch.cat((normed_feats, m_i), dim=-1)
            node_out = self.node_mlp(node_mlp_input) + feats
        else:
            node_out = feats

        return node_out, coors_out
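The first two rearranges in this forward build all pairwise displacement vectors purely by broadcasting; a toy sketch of just that step:

import torch
from einops import rearrange

coors = torch.randn(2, 5, 3)                   # (batch, nodes, xyz)
rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)
assert rel_coors.shape == (2, 5, 5, 3) and rel_dist.shape == (2, 5, 5, 1)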
Example #16
    def forward(self, x, B, T, W):
        num_spatial_tokens = (x.size(1) - 1) // T
        H = num_spatial_tokens // W

        if self.attention_type in ['space_only', 'joint_space_time']:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
        elif self.attention_type == 'divided_space_time':
            ## Temporal
            xt = x[:, 1:, :]
            xt = rearrange(xt,
                           'b (h w t) m -> (b h w) t m',
                           b=B,
                           h=H,
                           w=W,
                           t=T)
            res_temporal = self.drop_path(
                self.temporal_attn(self.temporal_norm1(xt)))
            res_temporal = rearrange(res_temporal,
                                     '(b h w) t m -> b (h w t) m',
                                     b=B,
                                     h=H,
                                     w=W,
                                     t=T)
            res_temporal = self.temporal_fc(res_temporal)
            xt = x[:, 1:, :] + res_temporal

            ## Spatial
            init_cls_token = x[:, 0, :].unsqueeze(1)
            cls_token = init_cls_token.repeat(1, T, 1)
            cls_token = rearrange(cls_token, 'b t m -> (b t) m', b=B,
                                  t=T).unsqueeze(1)
            xs = xt
            xs = rearrange(xs,
                           'b (h w t) m -> (b t) (h w) m',
                           b=B,
                           h=H,
                           w=W,
                           t=T)
            xs = torch.cat((cls_token, xs), 1)
            res_spatial = self.drop_path(self.attn(self.norm1(xs)))

            ### Taking care of CLS token
            cls_token = res_spatial[:, 0, :]
            cls_token = rearrange(cls_token, '(b t) m -> b t m', b=B, t=T)
            cls_token = torch.mean(cls_token, 1,
                                   True)  ## averaging for every frame
            res_spatial = res_spatial[:, 1:, :]
            res_spatial = rearrange(res_spatial,
                                    '(b t) (h w) m -> b (h w t) m',
                                    b=B,
                                    h=H,
                                    w=W,
                                    t=T)
            res = res_spatial
            x = xt

            ## Mlp
            x = torch.cat((init_cls_token, x), 1) + torch.cat(
                (cls_token, res), 1)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
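A shape sketch of the two divided space-time reshapes above, with hypothetical sizes: temporal attention sees T tokens per spatial location, spatial attention sees H*W tokens per frame.

import torch
from einops import rearrange

B, H, W, T, M = 2, 4, 4, 8, 16
patch_tokens = torch.randn(B, H * W * T, M)    # patch tokens without the CLS token
xt = rearrange(patch_tokens, 'b (h w t) m -> (b h w) t m', b=B, h=H, w=W, t=T)
xs = rearrange(patch_tokens, 'b (h w t) m -> (b t) (h w) m', b=B, h=H, w=W, t=T)
assert xt.shape == (B * H * W, T, M)
assert xs.shape == (B * T, H * W, M)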
Example #17
    def form_input_patches(self, patches):
        patches = rearrange(patches, 'n b l c h w -> (n b l) c h w')
        return patches
Example #18
    def rs(array):
        h = int(np.sqrt(len(array)))
        return np.around(
            rearrange(array, '(h w) -> h w', h=h).cpu().numpy(), 3)
def extract_patches_from_indicators(x,
                                    indicators,
                                    patch_size,
                                    patch_dropout,
                                    grid_shape,
                                    train,
                                    iterative=False):
    """Extract patches from a batch of images.

  Args:
    x: The batch of images of shape (batch, height, width, channels).
    indicators: The one hot indicators of shape (batch, num_patches, k).
    patch_size: The size of the (squared) patches to extract.
    patch_dropout: Probability to replace a patch by 0 values.
    grid_shape: Pair of height, width describing the layout of the num_patches
      patches.
    train: If the model is being trained. Disable dropout if not.
    iterative: If True, extracts the patches with a for loop rather than
      instantiating the "all patches" tensor and extracting by dot product with
      indicators. `iterative` is more memory efficient.

  Returns:
    The patches extracted from x with shape
      (batch, k, patch_size, patch_size, channels).

  """
    batch_size, height, width, channels = x.shape
    scores_h, scores_w = grid_shape
    k = indicators.shape[-1]
    indicators = einops.rearrange(indicators,
                                  "b (h w) k -> b k h w",
                                  h=scores_h,
                                  w=scores_w)

    scale_height = height // scores_h
    scale_width = width // scores_w
    padded_height = scale_height * scores_h + patch_size - 1
    padded_width = scale_width * scores_w + patch_size - 1
    top_pad = (patch_size - scale_height) // 2
    left_pad = (patch_size - scale_width) // 2
    bottom_pad = padded_height - top_pad - height
    right_pad = padded_width - left_pad - width

    # TODO(jbcdnr): assert padding is positive.

    padded_x = jnp.pad(x, [(0, 0), (top_pad, bottom_pad),
                           (left_pad, right_pad), (0, 0)])

    # Extract the patches. The iterative path fits better in memory as it does not
    # instantiate the "all patches" tensor but iterates over the patches to compute
    # the weighted sum with the indicator variables from topk.
    if not iterative:
        assert patch_dropout == 0., "Patch dropout not implemented."
        patches = utils.extract_images_patches(padded_x,
                                               window_size=(patch_size,
                                                            patch_size),
                                               stride=(scale_height,
                                                       scale_width))

        shape = (batch_size, scores_h, scores_w, patch_size, patch_size,
                 channels)
        chex.assert_shape(patches, shape)

        patches = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators,
                             patches)

    else:
        mask = jnp.ones((batch_size, scores_h, scores_w))
        mask = nn.dropout(mask, patch_dropout, deterministic=not train)

        def accumulate_patches(acc, index_i_j):
            i, j = index_i_j
            patch = jax.lax.dynamic_slice(
                padded_x, (0, i * scale_height, j * scale_width, 0),
                (batch_size, patch_size, patch_size, channels))
            weights = indicators[:, :, i, j]

            is_masked = mask[:, i, j]
            weighted_patch = jnp.einsum("b, bk, bijc -> bkijc", is_masked,
                                        weights, patch)
            chex.assert_equal_shape([acc, weighted_patch])
            acc += weighted_patch
            return acc, None

        indices = jnp.stack(jnp.meshgrid(jnp.arange(scores_h),
                                         jnp.arange(scores_w),
                                         indexing="ij"),
                            axis=-1)
        indices = indices.reshape((-1, 2))
        init_patches = jnp.zeros(
            (batch_size, k, patch_size, patch_size, channels))
        patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)

    return patches
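A minimal, self-contained sketch (toy sizes, not the original utils/chex helpers) of the non-iterative branch: one-hot indicators over the score grid select k patches from the "all patches" tensor via the same einsum.

import jax.numpy as jnp

batch, k, scores_h, scores_w, patch_size, channels = 2, 3, 4, 4, 5, 1
all_patches = jnp.ones((batch, scores_h, scores_w, patch_size, patch_size, channels))
# Hypothetical indicators that always select the top-left grid cell.
indicators = jnp.zeros((batch, k, scores_h, scores_w)).at[:, :, 0, 0].set(1.0)
selected = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators, all_patches)
print(selected.shape)  # (2, 3, 5, 5, 1)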
Example #20
def rs(array):
    # Reshape a flat square-length tensor to an (h, w) grid.
    h = int(np.sqrt(len(array)))
    return rearrange(array, '(h w) -> h w', h=h)
Example #21
def segment(
    signal: torch.Tensor,
    hop_size: int,
    window_size: int,
    sequence_lengths: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Zero-pads and segments the input sequence `signal` along the time dimension `L` (-2).

    Examples:
        >>> import torch
        >>> hop_size = 10
        >>> segmented, _ = segment(torch.randn(1, 50, 3), hop_size, 2 * hop_size)

        # Shape is BxNxKxS (batch x feat x win x frames)
        >>> segmented.shape
        torch.Size([1, 3, 20, 6])

        # The first block is zero-padded with hop_size
        >>> segmented[..., :hop_size, 0]
        tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

        # Last block as well
        >>> segmented[..., -hop_size:, -1]
        tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

        # Sequence lengths are computed
        >>> segmented, sequence_lengths = segment(torch.cat([torch.randn(1, 30, 3), torch.zeros(1, 10, 3)], dim=1),
        ...                                         hop_size, 2*hop_size, torch.tensor([30]))
        >>> sequence_lengths
        tensor([4])

        # All data outside of sequence_lengths is zero
        >>> segmented[0, ..., sequence_lengths[0]:].flatten()
        tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

        # And the last segment within sequence_lengths contains data, but is zero-padded at the end
        # (Conversion to uint8 is to make the doctest compatible with all PyTorch versions)
        >>> (segmented[0, ..., sequence_lengths[0] - 1] == 0).type(torch.uint8)
        tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
               dtype=torch.uint8)

        # Test the corner-cases for computation of sequence lengths

        # One above exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(5))
        (tensor([[[[0, 1, 3, 5],
                  [0, 2, 4, 0],
                  [1, 3, 5, 0],
                  [2, 4, 0, 0]]]]), tensor(4))

        # Exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(4))
        (tensor([[[[0, 1, 3, 5],
                  [0, 2, 4, 0],
                  [1, 3, 5, 0],
                  [2, 4, 0, 0]]]]), tensor(3))
        >>> segment(1 + torch.arange(4)[None, :, None], 2, 4, torch.tensor(4))
        (tensor([[[[0, 1, 3],
                  [0, 2, 4],
                  [1, 3, 0],
                  [2, 4, 0]]]]), tensor(3))

        # One below exact match
        >>> segment(1 + torch.arange(5)[None, :, None], 2, 4, torch.tensor(3))
        (tensor([[[[0, 1, 3, 5],
                  [0, 2, 4, 0],
                  [1, 3, 5, 0],
                  [2, 4, 0, 0]]]]), tensor(3))
        >>> segment(1 + torch.arange(3)[None, :, None], 2, 4, torch.tensor(3))
        (tensor([[[[0, 1, 3],
                  [0, 2, 0],
                  [1, 3, 0],
                  [2, 0, 0]]]]), tensor(3))

        # Shift != size // 2
        >>> segmented, seq_len = segment(torch.arange(5)[None, :, None], 3, 4, torch.tensor(5))
        >>> segmented.shape
        torch.Size([1, 1, 4, 2])
        >>> seq_len
        tensor(2)
        >>> segmented, seq_len = segment(torch.arange(5)[None, :, None], 1, 4, torch.tensor(5))
        >>> segmented.shape
        torch.Size([1, 1, 4, 8])
        >>> seq_len
        tensor(8)

        >>> segmented, seq_len = segment(torch.ones(1, 7912, 64), 50, 100, torch.tensor([7912]))
        >>> segmented.shape
        torch.Size([1, 64, 100, 160])
        >>> seq_len
        tensor([160])


    Args:
        signal ([Bx]LxN): 2D input signal with optional batch dimension
        hop_size: Hop size P
        window_size: Window size K
        sequence_lengths: Not used for the segmentation itself, but if provided, the resulting sequence lengths
            along the segment (S) dimension are returned in addition to the segmented signal. The sequence
            length is then the number of blocks that contain any part of the signal; those blocks may be
            zero-padded.

    Returns:
        [Bx]NxKxS
        S is the number of frames, K is the window size, N is the feature size
    """
    # Add padding for the first and last blocks. The padding on each side is
    # window_size - hop_size (equal to hop_size for 50% overlap), so that the
    # first half of the first block and the last half of the last block are
    # filled with zeros.
    padding = window_size - hop_size
    signal = F.pad(signal, [0, 0, padding, padding])

    segmented = pb.array.segment_axis(signal,
                                      window_size,
                                      hop_size,
                                      axis=-2,
                                      end='pad')
    segmented = rearrange(segmented, '... s k n -> ... n k s')

    if sequence_lengths is not None:
        # Number of hops that contain any part of the (front-padded) signal,
        # i.e. ceil((sequence_lengths + padding) / hop_size).
        sequence_lengths = sequence_lengths + 2 * padding
        sequence_lengths = sequence_lengths - padding
        sequence_lengths = (sequence_lengths - 1) // hop_size + 1
    return segmented, sequence_lengths
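A rough, standalone sketch of the same segmentation using torch.Tensor.unfold in place of pb.array.segment_axis, assuming the padded length is an exact multiple of hop_size so no end-padding is needed:

import torch
import torch.nn.functional as F
from einops import rearrange

hop_size, window_size = 10, 20
signal = torch.randn(1, 50, 3)                        # B x L x N
padding = window_size - hop_size
signal = F.pad(signal, [0, 0, padding, padding])      # zero-pad the time axis on both sides
segmented = signal.unfold(-2, window_size, hop_size)  # B x S x N x K
segmented = rearrange(segmented, 'b s n k -> b n k s')
print(segmented.shape)  # torch.Size([1, 3, 20, 6])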
Example #22
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False
    ):
        super().__init__()
        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), 'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE, or VQGanVAE1024'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        image_seq_len = image_fmap_size ** 2

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) # +1 for <bos>
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size))

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        self.vae = vae
        set_requires_grad(self.vae, False) # freeze VAE from being trained

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable
        )

        self.stable = stable

        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight
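A tiny sketch with made-up sizes of the logits mask constructed above: text positions may only predict text tokens and image positions may only predict image tokens.

import torch
from einops import rearrange

text_seq_len, num_text_tokens, num_image_tokens = 3, 4, 2
seq_len = text_seq_len + 2            # pretend image_seq_len == 2
total_tokens = num_text_tokens + num_image_tokens

seq_range = rearrange(torch.arange(seq_len), 'n -> () n ()')
logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')
logits_mask = (
    ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
    ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
)
print(logits_mask[0].int())
# tensor([[0, 0, 0, 0, 1, 1],   <- text positions: image logits masked out
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [1, 1, 1, 1, 0, 0],   <- image positions: text logits masked out
#         [1, 1, 1, 1, 0, 0]])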
Example #23
    def forward(self, x, mask=None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len
        softmax = torch.softmax

        img_seq_len = img_size**2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len).bool())

        x = F.pad(x, (0, 0, 0, padding), value=0)
        mask = mask[:, :text_len]

        # derive query / keys / values

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      qkv)

        q = q * self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
            lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(self.fp16)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim=-1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        effective_kernel_size = (kernel_size - 1) * dilation + 1
        padding = effective_kernel_size // 2

        k_img, v_img = map(
            lambda t: rearrange(t, 'b (h w) c -> b c h w', h=img_size),
            (k_img, v_img))
        k_img, v_img = map(
            lambda t: F.unfold(
                t, kernel_size, padding=padding, dilation=dilation),
            (k_img, v_img))
        k_img, v_img = map(
            lambda t: rearrange(t, 'b (d j) i -> b i j d', j=kernel_size**2),
            (k_img, v_img))

        # let image attend to all of text

        dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
        dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

        # calculate causal attention for local convolution

        i, j = dots_image.shape[-2:]
        img_seq = torch.arange(img_seq_len)
        k_img_indices = rearrange(img_seq.float(),
                                  '(h w) -> () () h w',
                                  h=img_size)
        k_img_indices = F.pad(
            k_img_indices, (padding, ) * 4, value=img_seq_len
        )  # padding set to be max, so it is never attended to
        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
        k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')

        # mask image attention

        q_img_indices = rearrange(img_seq, 'i -> () i ()')
        causal_mask = q_img_indices < k_img_indices

        # concat text mask with image causal mask

        causal_mask = repeat(causal_mask, '() i j -> b i j', b=b * h)
        mask = repeat(mask, 'b j -> (b h) i j', i=i, h=h)
        mask = torch.cat((~mask, causal_mask), dim=-1)

        # image can attend to all of text

        dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim=-1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[...,
                                                                    text_len:]

        out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image,
                                    v_img)
        out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text,
                                   v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim=1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        out = self.to_out(out)
        return out[:, :n]
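A standalone sketch (toy sizes) of the local key/value gather above: F.unfold collects a kernel_size x kernel_size neighbourhood for every image position, and rearrange exposes it as kernel_size**2 keys per query.

import torch
import torch.nn.functional as F
from einops import rearrange

b, d, img_size, kernel_size, dilation = 2, 8, 4, 3, 1
padding = ((kernel_size - 1) * dilation + 1) // 2

k_img = torch.randn(b, img_size * img_size, d)
k_img = rearrange(k_img, 'b (h w) c -> b c h w', h=img_size)
k_img = F.unfold(k_img, kernel_size, padding=padding, dilation=dilation)
k_img = rearrange(k_img, 'b (d j) i -> b i j d', j=kernel_size ** 2)
print(k_img.shape)  # torch.Size([2, 16, 9, 8]) -> 9 local keys per image position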
Example #24
    def forward(
        self,
        text,
        image = None,
        mask = None,
        return_loss = False
    ):
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        device, total_seq_len = text.device, self.total_seq_len

        # make sure padding in text tokens get unique padding token id

        text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # add <bos>

        text = F.pad(text, (1, 0), value = 0)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                image_size = self.vae.image_size
                assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            image_emb += self.image_pos_emb(image_emb)

            tokens = torch.cat((tokens, image_emb), dim = 1)

            seq_len += image_len

        # when training, if the length exceeds the total text + image length,
        # remove the last token, since it does not need to be trained

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        out = self.transformer(tokens)

        if self.stable:
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss
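A brief illustration of why the logits are rearranged before the loss: F.cross_entropy expects class logits on dimension 1, i.e. (batch, classes, positions).

import torch
import torch.nn.functional as F
from einops import rearrange

logits = torch.randn(2, 7, 10)          # batch, sequence positions, vocabulary
labels = torch.randint(0, 10, (2, 7))   # one class index per position
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
print(loss.shape)  # torch.Size([]) -- a scalar loss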
Example #25
    def expand_emb(self, r, dim_size):
        # Decompose and unsqueeze dimension
        r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
        expand_index = [-1, -1, -1, dim_size, -1, -1]  # -1 indicates no expansion
        r = r.expand(expand_index)
        return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
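A shape walk-through with hypothetical sizes of the axial expansion above: a per-axis relative embedding of shape (b, heads*dim_size, i, j) is broadcast to a full (dim_size*i) x (dim_size*j) bias.

import torch
from einops import rearrange

b, heads, dim_size, i, j = 1, 2, 3, 3, 3
r = torch.randn(b, heads * dim_size, i, j)

r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)  # (1, 2, 3, 1, 3, 3)
r = r.expand(-1, -1, -1, dim_size, -1, -1)                   # (1, 2, 3, 3, 3, 3)
out = rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
print(out.shape)  # torch.Size([1, 2, 9, 9])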
Example #26
    def __init__(self,
                 *,
                 dim,
                 vae,
                 num_text_tokens=10000,
                 text_seq_len=256,
                 depth,
                 heads=8,
                 dim_head=64,
                 reversible=False,
                 attn_dropout=0.,
                 ff_dropout=0):
        super().__init__()
        assert isinstance(
            vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

        num_image_tokens = vae.num_tokens
        image_seq_len = (vae.image_size // (2**vae.num_layers))**2

        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        self.text_pos_emb = nn.Embedding(text_seq_len, dim)
        self.image_pos_emb = nn.Embedding(image_seq_len, dim)

        self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens + 1  # extra for EOS
        self.total_tokens = total_tokens

        self.vae = vae
        if exists(self.vae):
            self.vae = vae
            self.image_emb = vae.codebook

        self.transformer = Transformer(dim=dim,
                                       causal=True,
                                       depth=depth,
                                       heads=heads,
                                       dim_head=dim_head,
                                       reversible=reversible,
                                       attn_dropout=attn_dropout,
                                       ff_dropout=ff_dropout)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (((seq_range >= (text_seq_len - 1)) &
                        (logits_range < num_text_tokens)) |
                       ((seq_range < (text_seq_len - 1)) &
                        (logits_range >= num_text_tokens)) |
                       ((seq_range != (seq_len - 1)) & (logits_range >=
                                                        (total_tokens - 1))))

        self.register_buffer('logits_mask', logits_mask)
Example #27
def make_imrange(arr: list):
    interpolation = torch.stack(arr)
    imgs = rearrange(make_grid(interpolation, 11), 'c h w -> h w c')
    # Detach and move to CPU (a no-op for CPU tensors) before converting to NumPy.
    imgs = imgs.detach().cpu().numpy()
    return imgs
Example #28
    def _beam_search(
        self,
        src: FloatTensor,
        mask: LongTensor,
        direction: str,
        beam_size: int,
        max_len: int,
    ) -> List[Hypothesis]:
        """run beam search for one direction

        Parameters
        ----------
        src : FloatTensor
            [1, l, d]
        mask: LongTensor
            [1, l]
        direction : str
            one of "l2r" and "r2l"
        beam_size : int
        max_len : int

        Returns
        -------
        List[Hypothesis]
        """
        assert direction in {"l2r", "r2l"}
        assert (
            src.size(0) == 1 and mask.size(0) == 1
        ), f"beam search should only have single source, encounter with batch_size: {src.size(0)}"

        if direction == "l2r":
            start_w = vocab.SOS_IDX
            stop_w = vocab.EOS_IDX
        else:
            start_w = vocab.EOS_IDX
            stop_w = vocab.SOS_IDX

        hypotheses = torch.full(
            (1, max_len + 1),
            fill_value=vocab.PAD_IDX,
            dtype=torch.long,
            device=self.device,
        )
        hypotheses[:, 0] = start_w

        hyp_scores = torch.zeros(1, dtype=torch.float, device=self.device)
        completed_hypotheses: List[Hypothesis] = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_len:
            hyp_num = hypotheses.size(0)
            assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}"

            exp_src = repeat(src.squeeze(0), "s e -> b s e", b=hyp_num)
            exp_mask = repeat(mask.squeeze(0), "s -> b s", b=hyp_num)

            decode_outputs = self(exp_src, exp_mask, hypotheses)[:, t, :]
            log_p_t = F.log_softmax(decode_outputs, dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            exp_hyp_scores = repeat(hyp_scores, "b -> b e", e=vocab_size)
            continuous_hyp_scores = rearrange(exp_hyp_scores + log_p_t,
                                              "b e -> (b e)")
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
                continuous_hyp_scores, k=live_hyp_num)

            prev_hyp_ids = top_cand_hyp_pos // vocab_size
            hyp_word_ids = top_cand_hyp_pos % vocab_size

            t += 1
            new_hypotheses = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
                    prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
                cand_new_hyp_score = cand_new_hyp_score.detach().item()
                hypotheses[prev_hyp_id, t] = hyp_word_id

                if hyp_word_id == stop_w:
                    completed_hypotheses.append(
                        Hypothesis(
                            seq_tensor=hypotheses[prev_hyp_id, 1:t].detach().
                            clone(),  # drop the leading start token
                            score=cand_new_hyp_score,
                            direction=direction,
                        ))
                else:
                    new_hypotheses.append(
                        hypotheses[prev_hyp_id].detach().clone())
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            hypotheses = torch.stack(new_hypotheses, dim=0)
            hyp_scores = torch.tensor(new_hyp_scores,
                                      dtype=torch.float,
                                      device=self.device)

        if len(completed_hypotheses) == 0:
            completed_hypotheses.append(
                Hypothesis(
                    seq_tensor=hypotheses[0, 1:].detach().clone(),
                    score=hyp_scores[0].detach().item(),
                    direction=direction,
                ))

        return completed_hypotheses
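A small illustration with toy sizes of the beam expansion and scoring step above: the single encoder output is repeated for every live hypothesis, and candidate scores are flattened before topk.

import torch
from einops import rearrange, repeat

src = torch.randn(1, 5, 8)                      # [1, l, d]
hyp_num, vocab_size = 3, 10
exp_src = repeat(src.squeeze(0), 's e -> b s e', b=hyp_num)
print(exp_src.shape)                            # torch.Size([3, 5, 8])

hyp_scores = torch.zeros(hyp_num)
log_p_t = torch.randn(hyp_num, vocab_size).log_softmax(dim=-1)
exp_hyp_scores = repeat(hyp_scores, 'b -> b e', e=vocab_size)
flat_scores = rearrange(exp_hyp_scores + log_p_t, 'b e -> (b e)')
top_scores, top_pos = torch.topk(flat_scores, k=2)
prev_hyp_ids, hyp_word_ids = top_pos // vocab_size, top_pos % vocab_size
print(prev_hyp_ids.shape, hyp_word_ids.shape)   # torch.Size([2]) torch.Size([2])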
Example #29
def evaluate(checkpoint_path, eval_dir, database_json):
    model = SimpleMaskEstimator(513)

    model.load_checkpoint(
        checkpoint_path=checkpoint_path,
        in_checkpoint_path='model',
        consider_mpi=True
    )
    model.eval()
    if dlp_mpi.IS_MASTER:
        print(f'Start to evaluate the checkpoint {checkpoint_path.resolve()} '
              f'and will write the evaluation result to'
              f' {eval_dir / "result.json"}')
    database = JsonDatabase(database_json)
    test_dataset = get_test_dataset(database)
    with torch.no_grad():
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for batch in dlp_mpi.split_managed(
                test_dataset, is_indexable=True,
                progress_bar=True,
                allow_single_worker=True
        ):
            model_output = model(pt.data.example_to_device(batch))

            example_id = batch['example_id']
            s = batch['speech_source'][0][None]

            speech_mask = model_output['speech_mask_prediction'].numpy()
            Y = batch['observation_stft']
            Z_mask = speech_mask[0] * Y[0]
            z_mask = pb.transform.istft(Z_mask)[None]

            speech_mask = np.median(speech_mask, axis=0).T
            noise_mask = model_output['noise_mask_prediction'].numpy()
            noise_mask = np.median(noise_mask, axis=0).T
            Y = rearrange(Y, 'c t f -> f c t')
            target_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, speech_mask,
            )
            noise_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, noise_mask,
            )
            beamformer = pb_bss.extraction.get_bf_vector(
                'mvdr_souden',
                target_psd_matrix=target_psd,
                noise_psd_matrix=noise_psd

            )
            Z_bf = pb_bss.extraction.apply_beamforming_vector(beamformer, Y).T
            z_bf = pb.transform.istft(Z_bf)[None]

            y = batch['observation'][0][None]
            s = s[:, :z_bf.shape[1]]
            for key, signal in zip(summary.keys(), [z_mask, z_bf, y]):
                signal = signal[:, :s.shape[1]]
                entry = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=signal, speech_source=s,
                    sample_rate=16000
                ).as_dict()
                entry.pop('mir_eval_selection')
                summary[key][example_id] = entry

    summary_list = dlp_mpi.COMM.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'\n len(summary_list): {len(summary_list)}')
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for partial_summary in summary_list:
            for signal_type, metric in partial_summary.items():
                summary[signal_type].update(metric)
        for signal_type, values in summary.items():
            print(signal_type)
            for metric in next(iter(values.values())).keys():
                mean = np.mean([value[metric] for key, value in values.items()
                                if '_mean' not in key])
                values[metric + '_mean'] = mean
                print(f'{metric}: {mean}')

        result_json_path = eval_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #30
def train(train_loader, valid_loader, model, optim, criterion, num_epochs):
    print_every = 5
    model.train()

    lowest_val = 1e9
    train_losses = []
    val_losses = []
    total_step = 0
    print("-" * 100)
    print("Starting Training")
    print("-" * 100)
    for epoch in range(num_epochs):
        pbar = tqdm(total=print_every, leave=False)
        total_loss = 0

        for step, data_dict in enumerate(iter(train_loader)):
            total_step += 1
            src, tgt, src_key_padding_mask, tgt_key_padding_mask = (
                data_dict["ids1"],
                data_dict["ids2"],
                data_dict["masks_sent1"],
                data_dict["masks_sent2"],
            )
            src, src_key_padding_mask = src.to("cpu"), src_key_padding_mask.to(
                "cpu")
            tgt, tgt_key_padding_mask = tgt.to("cpu"), tgt_key_padding_mask.to(
                "cpu")

            memory_key_padding_mask = src_key_padding_mask.clone()
            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to("cpu")

            optim.zero_grad()
            outputs = model(
                src,
                tgt_inp,
                src_key_padding_mask,
                tgt_key_padding_mask[:, :-1],
                memory_key_padding_mask,
                tgt_mask,
            )
            loss = criterion(
                rearrange(outputs, "b t v -> (b t) v"),
                rearrange(tgt_out, "b o -> (b o)"),
            )

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()
            train_losses.append((step, loss.item()))
            pbar.update(1)
            if step % print_every == print_every - 1:
                pbar.close()
                print(
                    f"Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(train_loader)}] \t "
                    f"Train Loss: {total_loss / print_every}")
                total_loss = 0
                pbar = tqdm(total=print_every, leave=False)
        pbar.close()
        val_loss = validate(valid_loader, model, criterion)
        val_losses.append((total_step, val_loss))
        if val_loss < lowest_val:
            lowest_val = val_loss
            torch.save(model, "output/transformer.pth")
        print(f"Val Loss: {val_loss}")
    return train_losses, val_losses