Example 1
def build_flower(
    train: bool = True,
    input_shape: Tuple[int, int] = (256, 512),
    data_format=None,
    use_tfa: bool = True,
) -> tf.keras.Model:
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    # Input
    if data_format == 'channels_first':
        inputs = tf.keras.Input(shape=(6, ) + input_shape,
                                dtype=tf.float32,
                                name='inputs')
    else:
        inputs = tf.keras.Input(shape=input_shape + (6, ),
                                dtype=tf.float32,
                                name='inputs')

    # Split input.
    axis = _get_axis(data_format)
    img_prv, img_nxt = Split(2, axis=axis)(inputs)

    # Encode and decode both frames with the shared encoder/decoder stacks.
    encs_prv, encs_nxt = encoder(img_prv, img_nxt, True, train=True)
    decs_prv, decs_nxt = decoder(encs_prv, encs_nxt, True, train=True)

    outputs = flower(encs_prv[-1],
                     encs_nxt[-1],
                     decs_prv,
                     decs_nxt,
                     output_multiscale=train,
                     use_tfa=use_tfa)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='qpwc_net')
    return model
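The helpers `_get_axis` and `Split` are not shown in these examples. The following is a minimal sketch of what they are assumed to do, plus a usage note; the names match the calls above, but the bodies are assumptions, not the repository's implementation.

import tensorflow as tf

def _get_axis(data_format: str) -> int:
    # Assumed behavior: return the channel (feature) axis for the data format.
    return 1 if data_format == 'channels_first' else -1

class Split(tf.keras.layers.Layer):
    """Hypothetical layer: split a tensor into `num` chunks along `axis`."""

    def __init__(self, num: int, axis: int = -1, **kwargs):
        super().__init__(**kwargs)
        self.num = num
        self.axis = axis

    def call(self, x):
        return tf.split(x, self.num, axis=self.axis)

# Usage sketch (requires the rest of the module, e.g. encoder/decoder/flower):
# model = build_flower(train=False, input_shape=(256, 512))
# model.summary()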
Example 2
def __init__(self,
             model: tf.keras.Model,
             clip_factor: float = 0.01,
             eps: float = 1e-3):
    # Wrap an existing functional model, reusing its inputs and outputs so
    # that this subclass shares the wrapped model's graph and weights.
    super().__init__(model.inputs, model.outputs)
    self.model = model
    self.clip_factor = clip_factor
    self.eps = eps
    data_format = tf.keras.backend.image_data_format()
    self.axis = _get_axis(data_format)
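This constructor appears without its enclosing class. Because it calls `super().__init__(model.inputs, model.outputs)`, the class presumably subclasses `tf.keras.Model` and wraps an existing functional model. A self-contained sketch under that assumption follows; the class name `WrappedModel` is hypothetical and `_get_axis` is inlined.

import tensorflow as tf

class WrappedModel(tf.keras.Model):  # hypothetical class name
    def __init__(self,
                 model: tf.keras.Model,
                 clip_factor: float = 0.01,
                 eps: float = 1e-3):
        # Functional-style init: reuse the wrapped model's graph and weights.
        super().__init__(model.inputs, model.outputs)
        self.model = model
        self.clip_factor = clip_factor
        self.eps = eps
        data_format = tf.keras.backend.image_data_format()
        self.axis = 1 if data_format == 'channels_first' else -1

# The wrapper shares the inner model's layers rather than copying them:
x = tf.keras.Input(shape=(8,))
inner = tf.keras.Model(x, tf.keras.layers.Dense(4)(x))
wrapped = WrappedModel(inner)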
Example 3
def build_interpolator(input_shape: Tuple[int, int],
                       data_format=None,
                       use_tfa: bool = True,
                       *args,
                       **kwargs):
    # input
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    if data_format == 'channels_first':
        inputs = tf.keras.Input(shape=(6, ) + input_shape,
                                dtype=tf.float32,
                                name='inputs')
    else:
        inputs = tf.keras.Input(shape=input_shape + (6, ),
                                dtype=tf.float32,
                                name='inputs')

    # Split input.
    axis = _get_axis(data_format)
    img_prv, img_nxt = Split(2, axis=axis)(inputs)

    encs_prv, encs_nxt = encoder(img_prv, img_nxt, True)
    decs_prv, decs_nxt = decoder(encs_prv, encs_nxt, True)
    flower_block = Flower(len(decs_prv),
                          output_multiscale=True,
                          use_tfa=use_tfa)
    flows_01 = flower_block((encs_nxt[-1], encs_prv[-1], decs_nxt, decs_prv))

    # NOTE: everything above this point must exactly match build_flower(),
    # because weight transfer relies on layer construction order rather than
    # on explicit bookkeeping of layer correspondences.
    flows_10 = flower_block((encs_prv[-1], encs_nxt[-1], decs_prv, decs_nxt))

    outputs = interpolator(img_prv,
                           img_nxt,
                           decs_prv,
                           decs_nxt,
                           flows_01,
                           flows_10,
                           *args,
                           use_tfa=use_tfa,
                           **kwargs)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name='qpwc_net')
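The NOTE above refers to transferring weights from a model built by `build_flower()` into this one, relying on the two models sharing a prefix of identically constructed layers. A sketch of how such a transfer could be done under that assumption (not the repository's actual transfer code):

import tensorflow as tf

def transfer_shared_weights(src: tf.keras.Model, dst: tf.keras.Model) -> None:
    # Walk both models' layers in construction order and copy weights for the
    # shared prefix (encoder, decoder, Flower block). Stop at the first layer
    # whose weights no longer line up, i.e. where the interpolator-only part
    # begins.
    for src_layer, dst_layer in zip(src.layers, dst.layers):
        src_w = src_layer.get_weights()
        dst_w = dst_layer.get_weights()
        if len(src_w) != len(dst_w):
            break
        if any(s.shape != d.shape for s, d in zip(src_w, dst_w)):
            break
        dst_layer.set_weights(src_w)

# flower_net = build_flower(train=True, input_shape=(256, 512))
# interp_net = build_interpolator(input_shape=(256, 512))
# transfer_shared_weights(flower_net, interp_net)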
Example 4
def flower(enc_prv,
           enc_nxt,
           decs_prv,
           decs_nxt,
           output_multiscale: bool = True,
           use_tfa: bool = True):
    """ Frame interpolation stack. """
    data_format = tf.keras.backend.image_data_format()
    axis = _get_axis(data_format)  # feature axis

    # How many encoding/decoding layers?
    n = len(decs_prv)

    # flo_01 = fwd, i.e. warp(nxt,flo_01)==prv
    flow = Flow(use_tfa=use_tfa)
    flo_01 = flow((enc_prv, enc_nxt))
    flos = [flo_01]

    for i in range(n):
        # Get inputs at current layer ...
        dec_prv = decs_prv[i]
        dec_nxt = decs_nxt[i]

        # Create layers at the current level.
        upsample = Upsample(scale=2.0)
        upflow = UpFlow(use_tfa=use_tfa)

        # Compute the current-stage motion block from the upsampled previous
        # flow and the decoder features at this level.
        # NOTE(ycho): unlike plain spatial upsampling, the flow values are also
        # multiplied by 2 to match the doubled resolution.
        flo_01_u = upsample(flo_01)
        flo_01 = upflow((dec_prv, dec_nxt, flo_01_u))
        flos.append(flo_01)

    # Final full-res flow is ONLY upsampled.
    flo_01 = Upsample(scale=2.0)(flo_01)
    flos.append(flo_01)

    if output_multiscale:
        outputs = flos
    else:
        outputs = [flo_01]
    return outputs
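`Upsample(scale=2.0)` is assumed here to double the spatial resolution and additionally multiply the values by `scale`, which is what the NOTE about the extra factor of 2 suggests; the image path in `interpolator()` then uses `scale=1.0`. A minimal sketch of a layer with that assumed behavior, not the repository's implementation:

import tensorflow as tf

class Upsample(tf.keras.layers.Layer):
    """Hypothetical: bilinear 2x spatial upsampling plus multiplication by `scale`.

    For flow fields, scale=2.0 keeps displacements consistent with the doubled
    resolution; for images, scale=1.0 leaves pixel values unchanged.
    """

    def __init__(self, scale: float = 2.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, x):
        # Assumes channels_last layout for brevity.
        size = 2 * tf.shape(x)[1:3]  # (H, W) -> (2H, 2W)
        x = tf.image.resize(x, size, method='bilinear')
        return self.scale * x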
Example 5
def decoder(encs_prv, encs_nxt, use_skip: bool = True, train: bool = True):
    data_format = tf.keras.backend.image_data_format()
    axis = _get_axis(data_format)

    # Build the up-convolution layers, shared between both frames.
    layers = []
    for num_filters in [128, 64, 32, 16]:
        # NOTE(ycho): the decoder does not include a layer at the same
        # resolution as the input.
        conv = UpConv(num_filters)
        if not train:
            conv.trainable = False
        layers.append(conv)

    # Apply the decoder to the previous-frame encodings (with skip connections).
    f = encs_prv[-1]
    i = -2
    decs_prv = []
    for layer in layers:
        f = layer(f)
        if use_skip:
            f = tf.concat([f, encs_prv[i]], axis=axis)
            i -= 1
        decs_prv.append(f)

    # Apply the decoder to the next-frame encodings (with skip connections).
    f = encs_nxt[-1]
    i = -2
    decs_nxt = []
    for layer in layers:
        f = layer(f)
        if use_skip:
            f = tf.concat([f, encs_nxt[i]], axis=axis)
            i -= 1
        decs_nxt.append(f)
    return (decs_prv, decs_nxt)
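`UpConv` is not defined in these examples. From its use here it presumably roughly doubles spatial resolution while projecting to `num_filters` channels; one plausible stand-in is a strided transposed convolution (the kernel size and activation below are guesses, not the repository's choices):

import tensorflow as tf

class UpConv(tf.keras.layers.Layer):
    """Hypothetical stand-in: 2x transposed convolution with `filters` channels."""

    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.conv = tf.keras.layers.Conv2DTranspose(
            filters, kernel_size=4, strides=2, padding='same',
            activation=tf.nn.leaky_relu)

    def call(self, x):
        return self.conv(x)

Note that setting `conv.trainable = False` before the layer has ever been called, as the decoder above does, is fine in Keras; the flag is honored once the weights are created.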
Example 6
def interpolator(img_prv,
                 img_nxt,
                 decs_prv,
                 decs_nxt,
                 flos_01,
                 flos_10,
                 output_multiscale: bool = True,
                 use_tfa: bool = True):
    """ Frame interpolation stack. """
    data_format = tf.keras.backend.image_data_format()
    axis = _get_axis(data_format)  # feature axis

    # How many encoding/decoding layers?
    n = len(decs_prv)

    # Create a downsampled image pyramid for each input frame, so that each
    # level matches the resolution of the corresponding decoder feature map.
    imgs_prv = [img_prv]
    imgs_nxt = [img_nxt]
    for _ in range(n + 1):
        pool = Downsample()
        imgs_prv.append(pool(imgs_prv[-1]))
        imgs_nxt.append(pool(imgs_nxt[-1]))

    # The interpolated middle frame at the coarsest level is composed directly
    # from the downsampled input images and the coarsest flows.
    # flo_01 = fwd, i.e. warp(nxt, flo_01) == prv
    # flo_10 = bwd, i.e. warp(prv, flo_10) == nxt
    # The flo_10 path is only needed for the interpolator.
    img = FrameInterpolate(up=False, name='img_0')(
        (imgs_prv[-1], imgs_nxt[-1], flos_01[0], flos_10[0]))
    imgs = [img]

    # Refine the interpolated frame coarse-to-fine over the n decoder levels;
    # the final full-resolution output below is only upsampled, not refined.
    for i in range(n):
        # Get inputs at current layer ...
        dec_prv = decs_prv[i]
        dec_nxt = decs_nxt[i]
        img_prv = imgs_prv[n - i]
        img_nxt = imgs_nxt[n - i]

        # Upsample the previous interpolation and refine it with the decoder
        # features, the flows at this level, and the downsampled inputs.
        # (scale=1.0: unlike the flow path, image values are not rescaled.)
        img_u = Upsample(scale=1.0)(img)
        img = FrameInterpolate(up=True, name='img_{}'.format(i + 1))(
            (dec_prv, dec_nxt, flos_01[i + 1], flos_10[i + 1], img_u))
        imgs.append(img)

    # Final full-res img is ONLY upsampled.
    img = Upsample(scale=1.0, name='img_{}'.format(n + 1))(img)
    imgs.append(img)

    if output_multiscale:
        return imgs
    # Otherwise, return only the final full-resolution image.
    return imgs[-1]
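When `output_multiscale` is true, the returned image pyramid lends itself to per-scale supervision against the ground-truth middle frame. A sketch of that idea, assuming a channels_last layout; this is illustrative only and not the repository's training code:

import tensorflow as tf

def multiscale_l1_loss(imgs_pred, img_true, weights=None):
    # Hypothetical per-scale loss: compare each predicted scale with the
    # ground-truth middle frame resized to that scale's resolution.
    if weights is None:
        weights = [1.0] * len(imgs_pred)
    total = 0.0
    for w, pred in zip(weights, imgs_pred):
        size = tf.shape(pred)[1:3]  # assumes channels_last layout
        true_s = tf.image.resize(img_true, size, method='bilinear')
        total += w * tf.reduce_mean(tf.abs(pred - true_s))
    return total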