def forward(self, src_input, tgt_input, accum_field=None,
             _index=slice(None), **kwargs):
     """Run a subset of pyramid aligners coarse-to-fine and compose fields.

     Iterates the aligners selected by `_index` from coarsest to finest
     (the selected indices and modules are both reversed in lockstep),
     predicting a residual field at each level and composing it onto the
     accumulated field, then upsamples the final field by the finest
     level processed.

     Args:
         src_input: source image, or a list of per-level source images
             (indexed by pyramid level).
         tgt_input: target image, or a matching list of per-level images.
         accum_field: optional starting field; NOTE(review): on the first
             iteration it is composed with the new residual (L18 branch)
             but NOT used to warp `src` (that warp is gated on
             `prev_level`, which is still None) — confirm this is the
             intended contract for a caller-supplied field.
         _index: slice selecting which aligner levels to run.
         **kwargs: forwarded to each aligner module.

     Returns:
         The composed displacement field, upsampled to full resolution
         (by `prev_level` levels).

     NOTE(review): if `_index` selects no aligners, `prev_level` stays
     None and the final `upsample(prev_level)` call will fail — confirm
     callers never pass an empty selection.
     """
     prev_level = None
     for i, aligner in zip(reversed(range(len(self.list))[_index]),
                           reversed(self.list[_index])):
         # pick (or build by downsampling) the inputs for this level
         if isinstance(src_input, list) and isinstance(tgt_input, list):
             src, tgt = src_input[i], tgt_input[i]
         else:
             src, tgt = downsample(i)(src_input), downsample(i)(tgt_input)
         if prev_level is not None:
             # bring the coarser field up to this level's resolution;
             # fields are stored NHWC, upsampling runs in NCHW
             accum_field = (upsample(prev_level - i)
                            (accum_field.permute(0, 3, 1, 2))
                            .permute(0, 2, 3, 1))
             # warp the source by the field accumulated so far
             src = gridsample_residual(src, accum_field,
                                       padding_mode='border')
         factor = 2 / src.shape[-1]  # scale to [-1,1]
         res_field = aligner(src, tgt, **kwargs) * factor
         if accum_field is not None:
             # compose: resample the old field through the new residual,
             # then add the residual
             resampled = gridsample_residual(
                 accum_field.permute(0, 3, 1, 2), res_field,
                 padding_mode='border').permute(0, 2, 3, 1)
             accum_field = res_field + resampled
         else:
             accum_field = res_field
         prev_level = i
     # upsample the final field back to full resolution
     accum_field = (upsample(prev_level)
                    (accum_field.permute(0, 3, 1, 2))
                    .permute(0, 2, 3, 1))
     return accum_field
    def forward(self, stack, target_level, vis=None, use_preencoder=False):
        """Compute a coarse-to-fine alignment field for `stack`.

        Optionally pre-encodes the stack, runs the encoder pyramid, then
        walks from the coarsest level down to `target_level`: at each
        level a residual field is predicted, composed with the field
        accumulated so far, and the result is upsampled for the next
        (finer) level.

        Args:
            stack: input image stack tensor; at each encoding level the
                first half of the channels is treated as the source and
                the second half as the target.
            target_level: finest pyramid level to align down to.
            vis: optional path prefix; when set, debug GIFs are written.
            use_preencoder: if truthy, run the preencoder first; the
                special value "only" returns just the preencoded stack.

        Returns:
            (field_so_far, residuals): the composed displacement field
            and the list of per-level residual fields (coarsest first).
        """
        if vis is not None:
            gif(vis + 'input', gif_prep(stack))

        if use_preencoder:
            # run the preencoder
            residual = self.pe(stack)
            stack = stack + residual
            if vis is not None:
                # visualize the preencoder output
                zm = (stack == 0).data
                # BUGFIX: use .item() to extract the Python scalar.
                # The old `.data[0]` indexing fails on PyTorch >= 0.4,
                # where reductions like mean/min/max return 0-dim tensors.
                print('residual me,mi,ma {},{},{}'.format(
                    torch.mean(residual[~zm]).item(),
                    torch.min(residual[~zm]).item(),
                    torch.max(residual[~zm]).item()))
                gif(vis + 'pre_enc_residual', gif_prep(residual))
                gif(vis + 'pre_enc_output', gif_prep(stack))
            if use_preencoder == "only":
                # only run the preencoder and return the results
                return stack

        # encode the stack at every pyramid level (minus topskips)
        encodings = [self.enclist[0](stack)]
        for idx in range(1, self.size - self.topskips):
            encodings.append(self.enclist[idx](self.down(encodings[-1]),
                                               vis=vis))

        # start from a zero displacement field at the coarsest resolution
        # NOTE(review): batch dim is hard-coded to 1 — presumably relies
        # on broadcasting for larger batches; confirm before changing.
        rdim = stack.shape[-2] // (2**(self.size - 1 - self.topskips))
        field_so_far = torch.zeros((1, rdim, rdim, 2),
                                   device=encodings[0].device)  # zero field
        residuals = []
        for i in range(self.size - 1 - self.topskips, target_level - 1, -1):
            if i >= self.skip:
                inputs_i = encodings[i]
                # warp the source half of the channels by the field
                # accumulated so far before predicting this residual
                resampled_source = gridsample_residual(
                    inputs_i[:, 0:inputs_i.size(1) // 2],
                    field_so_far,
                    padding_mode='zeros')
                new_input_i = torch.cat(
                    (resampled_source, inputs_i[:, inputs_i.size(1) // 2:]), 1)
                # rescale the prediction into normalized field units
                factor = (self.TRAIN_SIZE / (2.**i)) / new_input_i.size()[-1]
                rfield = self.mlist[i](new_input_i) * factor
                residuals.append(rfield)
                # Resample field_so_far using rfield. Add rfield to the result
                # to produce the new field_so_far.
                resampled_field_so_far = gridsample_residual(
                    field_so_far.permute(0, 3, 1, 2),
                    rfield,
                    padding_mode='border').permute(0, 2, 3, 1)
                field_so_far = rfield + resampled_field_so_far
            if i != target_level:
                # fields are NHWC; upsampling runs in NCHW
                field_so_far = self.up(field_so_far.permute(0, 3, 1,
                                                            2)).permute(
                                                                0, 2, 3, 1)
        return field_so_far, residuals
Example #3
0
    def forward(self, src, tgt, prediction, masks=None):
        """Masked MSE + field-penalty loss for a predicted field.

        Warps `src` by `prediction` and compares it to `tgt` under the
        image masks, penalizes the field under the field masks, and
        returns mse_loss + lambda1 * field_loss. Masks are generated
        when absent or empty, otherwise prepared from the given ones.
        """
        if masks is None or masks.nelement() == 0:
            masks = gen_masks(src, tgt, prediction)
        else:
            masks = prepare_masks(src, tgt, masks)

        src, tgt = src.to(prediction.device), tgt.to(prediction.device)

        def _weighted_mean(loss_map, source_masks, target_masks):
            # Mask-weighted mean of a loss map. Source-space masks must
            # first be warped by the prediction so they line up with the
            # warped source; target-space masks apply directly.
            if not (source_masks or target_masks):
                return loss_map.mean()
            weights = torch.ones_like(loss_map)
            if source_masks is not None:
                for m in source_masks:
                    m = gridsample_residual(m, prediction,
                                            padding_mode='border')
                    loss_map = loss_map * m
                    weights = weights * m
            if target_masks is not None:
                for m in target_masks:
                    loss_map = loss_map * m
                    weights = weights * m
            return loss_map.sum() / weights.sum()

        warped_src = gridsample_residual(src, prediction,
                                         padding_mode='zeros')
        mse_loss = _weighted_mean((warped_src - tgt)**2,
                                  masks['src_masks'],
                                  masks['tgt_masks'])

        field_loss = _weighted_mean(self.field_penalty([prediction]),
                                    masks['src_field_masks'],
                                    masks['tgt_field_masks'])

        return mse_loss + self.lambda1 * field_loss
 def forward(self, src_input, tgt_input, accum_field=None):
     """Align coarse-to-fine through all pyramid levels.

     Walks levels from coarsest (self.height - 1) down to 0. At each
     level the accumulated field (if any) is upsampled, used to warp
     the source, and composed with the freshly predicted residual.

     Args:
         src_input: source image, or a list of per-level source images.
         tgt_input: target image, or a matching list of per-level images.
         accum_field: optional starting displacement field (NHWC).

     Returns:
         The composed displacement field after the finest level.
     """
     for level in reversed(range(self.height)):
         # pick (or build by downsampling) this level's inputs
         if isinstance(src_input, list) and isinstance(tgt_input, list):
             src, tgt = src_input[level], tgt_input[level]
         else:
             src = downsample(level)(src_input)
             tgt = downsample(level)(tgt_input)
         if accum_field is not None:
             # fields are NHWC; upsampling runs in NCHW
             accum_field = upsample()(
                 accum_field.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
             src = gridsample_residual(src, accum_field,
                                       padding_mode='border')
         scale = 2 / src.shape[-1]  # scale to [-1,1]
         res_field = self.list[level](src, tgt) * scale
         if accum_field is None:
             accum_field = res_field
         else:
             # compose: resample the old field through the new residual,
             # then add the residual
             warped = gridsample_residual(
                 accum_field.permute(0, 3, 1, 2), res_field,
                 padding_mode='border').permute(0, 2, 3, 1)
             accum_field = res_field + warped
     return accum_field