Exemplo n.º 1
0
def cost_volume_to_flow(cvol: tf.Tensor, data_format: str = None):
    """
    Predict optical flow from cost volume by finding the argmax of correlation.

    The channel axis of `cvol` is interpreted as a flattened q x q search
    window, where q = sqrt(number of channels) (the channel count is assumed
    to be a perfect square -- TODO confirm against the cost-volume layer).
    The argmax index along that axis is unravelled into a (row, column) pair
    and re-expressed as an offset from the window center, (q - 1) / 2.

    Args:
        cvol: Cost volume. The channel axis is the last axis for
            'channels_last' and the third-from-last axis otherwise. Any
            number of leading axes (e.g. a batch axis) is accepted in
            either layout.
        data_format: 'channels_last' or anything else (treated as
            channels-first). Defaults to
            `tf.keras.backend.image_data_format()` when None.

    Returns:
        A float32 tensor with the (di, dj) offsets stacked along the
        channel axis.
    """
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    if data_format == 'channels_last':
        axis = -1
        trailing = ['c']
    else:
        axis = -3
        trailing = ['c', '_', '_']

    # Pad the pattern with one placeholder per leading axis so that batched
    # and unbatched inputs are handled identically in both layouts, without
    # relying on ellipsis support in `einops.parse_shape`.
    leading = ['_'] * (len(cvol.shape) - len(trailing))
    dims = einops.parse_shape(cvol, ' '.join(leading + trailing))['c']
    imax = tf.argmax(cvol, axis=axis)

    # unravel_index: flat index -> (row, column) inside the q x q window.
    imax = tf.cast(imax, tf.float32)
    q = tf.sqrt(tf.cast(dims, tf.float32))
    di = tf.floor(imax / q)
    dj = imax - di * q

    # delta from center
    di = di - (q - 1) / 2
    dj = dj - (q - 1) / 2

    return tf.stack([di, dj], axis=axis)
Exemplo n.º 2
0
    def call(self, y_true, y_pred):
        """
        Compare a flow prediction against a ground truth of higher resolution.

        The ground truth is mean-pooled down to the prediction's spatial
        size, its magnitude is rescaled by the resolution ratio, and the
        residual (scaled by the inverse pixel count of the prediction) is
        handed to `self.loss`. Both tensors are compared in 'n h w c'
        layout regardless of `self.data_format`.
        """
        channels_first = (self.data_format == 'channels_first')
        pattern = 'n c h w' if channels_first else 'n h w c'
        true_shape = einops.parse_shape(y_true, pattern)
        pred_shape = einops.parse_shape(y_pred, pattern)

        # Flow vectors shrink together with the image they are measured in.
        flow_scale = pred_shape['h'] / true_shape['h']

        # Normalize the residual by the number of predicted pixels.
        loss_scale = 1.0 / (pred_shape['w'] * pred_shape['h'])

        # NOTE(ycho): In general, will be integer multiples,
        # so we use einops.reduce() instead of e.g. resize_bilinear()
        stride_h = true_shape['h'] // pred_shape['h']
        stride_w = true_shape['w'] // pred_shape['w']

        if channels_first:
            y_pred = einops.rearrange(y_pred, 'n c h w -> n h w c')
            pool_pattern = 'n c (h sh) (w sw) -> n h w c'
        else:
            pool_pattern = 'n (h sh) (w sw) c -> n h w c'

        pooled = einops.reduce(y_true, pool_pattern, 'mean',
                               sh=stride_h, sw=stride_w)
        y_true = flow_scale * pooled

        # finally, we call the loss...
        return self.loss(loss_scale * (y_true - y_pred))
Exemplo n.º 3
0
    def test6(x):
        """Flatten to (pixels, channels) and restore the layout via parse_shape."""
        # Collapse batch and spatial axes so every row is one pixel's channels.
        flat = rearrange(x, 'b c h w -> (b h w) c')
        # Stand-in for a dot-product: only the size of the second axis changes.
        halved = flat[:, ::2]
        assert halved.shape == (10 * 30 * 40, 10)

        # Recover b, h and w from the input; its channel axis is ignored ('_')
        # because the channel count has changed in the meantime.
        sizes = parse_shape(x, 'b _ h w')
        restored = rearrange(halved, '(b h w) c2 -> b c2 h w', **sizes)
        assert restored.shape == (10, 10, 30, 40)
        return restored
Exemplo n.º 4
0
 def forward(self, x):
     """
     Apply self-attention over the spatial positions of `x`.

     Queries, keys and values come from `self.query_conv`,
     `self.key_conv` and `self.value_conv`. Attention weights are a
     softmax over the key positions (dim 2), and the attended values are
     added back to `x`, scaled by `self.gamma`.

     Returns:
         (out, attention): the residual output in 'b c h w' layout and the
         attention weights of shape (b, h*w, h*w).
     """
     dims = parse_shape(x, 'b c h w')

     # Flatten the spatial grid so that each position becomes one token.
     queries = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
     keys = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
     values = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')

     scores = torch.bmm(queries, keys)
     attention = F.softmax(scores, dim=2)

     attended = torch.bmm(attention, values)
     # Restore the spatial layout using the sizes taken from the input.
     attended = rearrange(attended, 'b (h w) c -> b c h w', **dims)

     out = x + self.gamma * attended
     return out, attention
Exemplo n.º 5
0
    def call(self, y_true, y_pred):
        """
        Compare a flow prediction against a ground truth of higher resolution.

        The ground truth is mean-pooled down to the prediction's spatial
        size and rescaled by the resolution ratio. The residual (scaled by
        the inverse pixel count of the prediction) is flattened to one row
        per pixel, passed through `self.loss_func`, and averaged.
        """
        channels_first = (self.data_format == 'channels_first')
        pattern = '_ c h w' if channels_first else '_ h w c'
        true_shape = einops.parse_shape(y_true, pattern)
        pred_shape = einops.parse_shape(y_pred, pattern)

        # Flow vectors shrink together with the image they are measured in.
        flow_scale = pred_shape['h'] / true_shape['h']

        # Normalize the residual by the number of predicted pixels.
        loss_scale = 1.0 / (pred_shape['w'] * pred_shape['h'])

        # NOTE(ycho): In general, will be integer multiples,
        # so we use einops.reduce() instead of e.g. resize_bilinear()
        stride_h = true_shape['h'] // pred_shape['h']
        stride_w = true_shape['w'] // pred_shape['w']

        if channels_first:
            pool_pattern = 'n c (h sh) (w sw) -> n c h w'
            flatten_pattern = 'n c h w -> (n h w) c'
        else:
            pool_pattern = 'n (h sh) (w sw) c -> n h w c'
            flatten_pattern = 'n h w c -> (n h w) c'

        pooled = einops.reduce(y_true, pool_pattern, 'mean',
                               sh=stride_h, sw=stride_w)
        y_true = flow_scale * pooled

        # Finally, we call the loss...
        residual = loss_scale * (y_true - y_pred)

        # NOTE(ycho): Treat ALF loss distribution over the flow channels.
        per_pixel = einops.rearrange(residual, flatten_pattern)
        return tf.reduce_mean(self.loss_func(per_pixel))
Exemplo n.º 6
0
def parse_image_shape(tensor: tf.Tensor,
                      data_format=None,
                      blocklist: List[str] = None):
    """
    Parse the named axis sizes of a rank-4 image tensor.

    Args:
        tensor: Tensor laid out as 'n c h w' ('channels_first') or
            'n h w c' (any other data format).
        data_format: Layout of `tensor`. Defaults to
            `tf.keras.backend.image_data_format()` when None.
        blocklist: Axis names (any of 'n', 'c', 'h', 'w') to leave out of
            the result. Defaults to no exclusions.

    Returns:
        The dict produced by `einops.parse_shape`, mapping each axis name
        that is not blocked to its size.
    """
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    # Compare by value: `is` tests object identity and is not reliable for
    # strings.
    if data_format == 'channels_first':
        axes = ['n', 'c', 'h', 'w']
    else:
        axes = ['n', 'h', 'w', 'c']

    # Strings are immutable, so the pattern has to be rebuilt rather than
    # edited in place; blocked axes become the anonymous placeholder '_'.
    blocked = set(blocklist) if blocklist else set()
    pattern = ' '.join('_' if ax in blocked else ax for ax in axes)
    return einops.parse_shape(tensor, pattern)
Exemplo n.º 7
0
def get_spatial_shape(x: tf.Tensor, data_format: str = None):
    """
    Parse the spatial (and, if present, batch) axis sizes of an image tensor.

    Args:
        x: Image tensor of rank 3 (unbatched) or rank 4 (batched).
        data_format: 'channels_first' or anything else (treated as
            channels-last). Defaults to
            `tf.keras.backend.image_data_format()` when None.

    Returns:
        The dict produced by `einops.parse_shape` with keys 'h' and 'w',
        plus 'n' when `x` has a batch axis.
    """
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()

    # 3D pattern
    if data_format == 'channels_first':
        pattern = '_ h w'
    else:
        pattern = 'h w _'

    # Batch dimension: decided from the static rank so the check is a plain
    # Python bool and is defined for unbatched input as well.
    if len(x.shape) >= 4:
        pattern = 'n ' + pattern

    return einops.parse_shape(x, pattern)
Exemplo n.º 8
0
def tf_warp(img, flow, data_format=None):
    """
    Warp `img` by `flow` using bilinear interpolation.

    A pixel grid is built from the spatial size of `img`, displaced by
    `flow`, and the four integer neighbours of every displaced coordinate
    are fetched through `get_pixel_value` and blended.

    Args:
        img: Image tensor to sample from.
        flow: Displacement field with two channels; channel 0 is added to
            the x (column) coordinate and channel 1 to the y (row)
            coordinate. A rank of 4 or more is treated as batched.
        data_format: 'channels_first' or anything else (treated as
            channels-last). Defaults to
            `tf.keras.backend.image_data_format()` when None.

    Returns:
        The weighted sum of the four sampled corner values.

    NOTE(review): `get_pixel_value` is defined elsewhere; whether it
    accepts unbatched coordinates has not been verified here.
    """
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    # 3D pattern
    if data_format == 'channels_first':
        pattern = '_ h w'
        axis = -3
    else:
        pattern = 'h w _'
        axis = -1

    # Check if batched; always bound, so unbatched input no longer raises
    # UnboundLocalError further down.
    is_batch = len(flow.shape) >= 4

    if is_batch:
        pattern = '_ ' + pattern

    shape = einops.parse_shape(img, pattern)
    W, H = shape['w'], shape['h']

    # Compute grid coordinates.
    x, y = tf.meshgrid(tf.range(W), tf.range(H))

    # Add channel dims
    x = tf.expand_dims(x, axis=axis)
    y = tf.expand_dims(y, axis=axis)

    # Add batch dims
    if is_batch:
        x = tf.expand_dims(x, axis=0)
        y = tf.expand_dims(y, axis=0)

    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    grid = tf.concat([x, y], axis=axis)

    flows = grid + flow
    max_y = tf.cast(H - 1, tf.int32)
    max_x = tf.cast(W - 1, tf.int32)
    zero = tf.constant(0, dtype=tf.int32)

    # Deal with individual components: split along the channel axis, which
    # must hold exactly two entries. Works for either layout and rank.
    x, y = tf.unstack(flows, num=2, axis=axis)

    x0 = tf.cast(x, tf.int32)
    x1 = x0 + 1
    y0 = tf.cast(y, tf.int32)
    y1 = y0 + 1

    # clip to range [0, W - 1] / [0, H - 1] to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # get pixel value at corner coords
    Ia = get_pixel_value(img, x0, y0, data_format)
    Ib = get_pixel_value(img, x0, y1, data_format)
    Ic = get_pixel_value(img, x1, y0, data_format)
    Id = get_pixel_value(img, x1, y1, data_format)

    # recast as float for delta calculation
    x0 = tf.cast(x0, tf.float32)
    x1 = tf.cast(x1, tf.float32)
    y0 = tf.cast(y0, tf.float32)
    y1 = tf.cast(y1, tf.float32)

    # calculate deltas: each corner is weighted by the area of the
    # rectangle spanned by the sample point and the opposite corner.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # add dimension for addition (restores the channel axis for broadcast)
    wa = tf.expand_dims(wa, axis=axis)
    wb = tf.expand_dims(wb, axis=axis)
    wc = tf.expand_dims(wc, axis=axis)
    wd = tf.expand_dims(wd, axis=axis)

    # compute output
    out = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    return out