def build_flower(
        train=True,
        input_shape: Tuple[int, int] = (256, 512),
        data_format=None,
        use_tfa: bool = True,
) -> tf.keras.Model:
    """Build the optical-flow ("flower") network as a `tf.keras.Model`.

    The model takes a single 6-channel input (two frames concatenated
    along the channel axis), splits it into the previous/next frames,
    runs the shared encoder/decoder, and produces flow via `flower(...)`.

    Args:
        train: If True, the model outputs the multi-scale flow pyramid
            (useful for multi-scale training losses); otherwise only the
            final flow. NOTE(review): `train` is NOT forwarded to
            `encoder`/`decoder` below — those are hard-coded to
            `train=True` (see the original `# hmm...` marker). Confirm
            whether that is intentional before relying on `train=False`.
        input_shape: Spatial (height, width) of the input frames.
        data_format: 'channels_first' or 'channels_last'; defaults to
            the global Keras image data format.
        use_tfa: Forwarded to the flow layers (presumably toggles
            tensorflow-addons ops — TODO confirm).

    Returns:
        A `tf.keras.Model` named 'qpwc_net' mapping the 6-channel input
        to the flow output(s).
    """
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    # Input: 6 channels = two frames stacked along the channel axis.
    if data_format == 'channels_first':
        inputs = tf.keras.Input(shape=(6, ) + input_shape,
                                dtype=tf.float32, name='inputs')
    else:
        inputs = tf.keras.Input(shape=input_shape + (6, ),
                                dtype=tf.float32, name='inputs')
    # Split input into the two frames along the channel axis.
    axis = _get_axis(data_format)
    img_prv, img_nxt = Split(2, axis=axis)(inputs)
    # hmm...
    # NOTE(review): this construction sequence is mirrored verbatim in
    # `build_interpolator()`, which relies on an EXACT layer-order match
    # for weight transfer — do not reorder these calls.
    encs_prv, encs_nxt = encoder(img_prv, img_nxt, True, train=True)
    decs_prv, decs_nxt = decoder(encs_prv, encs_nxt, True, train=True)
    outputs = flower(encs_prv[-1], encs_nxt[-1], decs_prv, decs_nxt,
                     output_multiscale=train, use_tfa=use_tfa)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='qpwc_net')
    return model
def __init__(self, model: tf.keras.Model,
             clip_factor: float = 0.01, eps: float = 1e-3):
    """Wrap `model` as a Keras model sharing its inputs/outputs.

    Args:
        model: The underlying model whose inputs/outputs this wrapper
            exposes directly.
        clip_factor: Clipping hyperparameter stored for later use.
        eps: Small numeric constant stored for later use.
    """
    super().__init__(model.inputs, model.outputs)
    # Keep a handle on the wrapped model and the hyperparameters.
    self.model = model
    self.eps = eps
    self.clip_factor = clip_factor
    # Channel/feature axis follows the global Keras image data format.
    self.axis = _get_axis(tf.keras.backend.image_data_format())
def build_interpolator(input_shape: Tuple[int, int],
                       data_format=None,
                       use_tfa: bool = True,
                       *args, **kwargs):
    """Build the frame-interpolation network as a `tf.keras.Model`.

    Reuses the same input/encoder/decoder construction sequence as
    `build_flower()` (this must match exactly — see the comment below),
    then computes bidirectional flow with a shared `Flower` block and
    synthesizes the middle frame via `interpolator(...)`.

    Args:
        input_shape: Spatial (height, width) of the input frames.
        data_format: 'channels_first' or 'channels_last'; defaults to
            the global Keras image data format.
        use_tfa: Forwarded to flow/interpolation layers.
        *args, **kwargs: Forwarded verbatim to `interpolator(...)`.

    Returns:
        A `tf.keras.Model` named 'qpwc_net'.
    """
    # Input: 6 channels = two frames stacked along the channel axis.
    if data_format is None:
        data_format = tf.keras.backend.image_data_format()
    if data_format == 'channels_first':
        inputs = tf.keras.Input(shape=(6, ) + input_shape,
                                dtype=tf.float32, name='inputs')
    else:
        inputs = tf.keras.Input(shape=input_shape + (6, ),
                                dtype=tf.float32, name='inputs')
    # Split input into the two frames along the channel axis.
    axis = _get_axis(data_format)
    img_prv, img_nxt = Split(2, axis=axis)(inputs)
    encs_prv, encs_nxt = encoder(img_prv, img_nxt, True)
    decs_prv, decs_nxt = decoder(encs_prv, encs_nxt, True)
    # One shared Flower block computes both flow directions.
    flower_block = Flower(len(decs_prv), output_multiscale=True,
                          use_tfa=use_tfa)
    flows_01 = flower_block((encs_nxt[-1], encs_prv[-1], decs_nxt, decs_prv))
    # ^^^^ ALL code blocks above must EXACTLY match
    # build_flower() in order for the transfer to work.
    # this is because we are not very meticulous about
    # bookkeeping layer correspondences.
    flows_10 = flower_block((encs_prv[-1], encs_nxt[-1], decs_prv, decs_nxt))
    outputs = interpolator(img_prv, img_nxt, decs_prv, decs_nxt,
                           flows_01, flows_10, use_tfa=use_tfa,
                           *args, **kwargs)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name='qpwc_net')
def flower(enc_prv, enc_nxt, decs_prv, decs_nxt,
           output_multiscale: bool = True,
           use_tfa: bool = True):
    """Coarse-to-fine optical-flow stack.

    Starts from a flow estimate at the encoder bottleneck and refines it
    once per decoder level, upsampling between levels; the final
    full-resolution flow is produced by upsampling only (no refinement).

    Fix: removed the unused `data_format`/`axis` locals (computed but
    never read). Layer-creation order is unchanged.

    Args:
        enc_prv, enc_nxt: Bottleneck encoder features of the two frames.
        decs_prv, decs_nxt: Per-level decoder features of the two frames.
        output_multiscale: If True, return the full flow pyramid
            (coarse to fine); otherwise a single-element list with the
            final flow.
        use_tfa: Forwarded to the `Flow`/`UpFlow` layers.

    Returns:
        List of flow tensors (pyramid, or just the final flow).
    """
    # How many decoding layers?
    n = len(decs_prv)

    # Coarsest flow, estimated directly from the bottleneck features.
    # flo_01 = fwd, i.e. warp(nxt, flo_01) == prv
    flow = Flow(use_tfa=use_tfa)
    flo_01 = flow((enc_prv, enc_nxt))
    flos = [flo_01]

    for i in range(n):
        # Create layers at the current level.
        upsample = Upsample(scale=2.0)
        upflow = UpFlow(use_tfa=use_tfa)
        # Refine: upsampled previous-level flow + decoder features.
        # NOTE(ycho): Unlike typical upsampling, also mulx2 — flow
        # magnitudes double when spatial resolution doubles.
        flo_01_u = upsample(flo_01)
        flo_01 = upflow((decs_prv[i], decs_nxt[i], flo_01_u))
        flos.append(flo_01)

    # Final full-res flow is ONLY upsampled.
    flo_01 = Upsample(scale=2.0)(flo_01)
    flos.append(flo_01)

    return flos if output_multiscale else [flo_01]
def decoder(encs_prv, encs_nxt, use_skip: bool = True, train: bool = True):
    """Apply a shared up-convolution decoder to both frame streams.

    Fix: the per-stream application loop was duplicated verbatim for the
    prv/nxt streams; it is now a single private helper. Layer creation
    and application order are unchanged (layers built once, then the prv
    stream is fully applied before the nxt stream).

    Args:
        encs_prv, encs_nxt: Encoder feature pyramids (coarsest last) for
            the previous/next frames.
        use_skip: If True, concatenate the matching encoder level onto
            each decoder output (skip connections).
        train: If False, freeze the decoder convolutions.

    Returns:
        Tuple `(decs_prv, decs_nxt)` of per-level decoder feature lists.
    """
    data_format = tf.keras.backend.image_data_format()
    axis = _get_axis(data_format)

    # Build the up-convolution stack once; both streams share weights.
    layers = []
    for num_filters in [128, 64, 32, 16]:
        # NOTE(ycho): does not include layer of equal size as input
        conv = UpConv(num_filters)
        if not train:
            conv.trainable = False
        layers.append(conv)

    def _apply(encs):
        # Run the shared stack over one stream, optionally concatenating
        # the matching (next-coarser) encoder level as a skip connection.
        f = encs[-1]
        i = -2
        decs = []
        for l in layers:
            f = l(f)
            if use_skip:
                f = tf.concat([f, encs[i]], axis=axis)
                i -= 1
            decs.append(f)
        return decs

    # prv stream first, then nxt — same order as before the refactor.
    decs_prv = _apply(encs_prv)
    decs_nxt = _apply(encs_nxt)
    return (decs_prv, decs_nxt)
def interpolator(img_prv, img_nxt, decs_prv, decs_nxt, flos_01, flos_10,
                 output_multiscale: bool = True,
                 use_tfa: bool = True):
    """Coarse-to-fine frame-interpolation stack.

    Synthesizes the middle frame from the two input frames and the
    bidirectional flow pyramids, refining once per decoder level; the
    final full-resolution image is produced by upsampling only.

    Fixes (graph-equivalent cleanups):
      * removed the unused `data_format`/`axis` locals;
      * removed the per-iteration `img_prv`/`img_nxt` rebindings that
        shadowed the parameters and were never read;
      * removed the per-iteration `Upsample(scale=2.0)` layer that was
        instantiated but never connected to the graph.

    Args:
        img_prv, img_nxt: The two full-resolution input frames.
        decs_prv, decs_nxt: Per-level decoder features of the two frames.
        flos_01: Forward flow pyramid; warp(nxt, flo_01) == prv.
        flos_10: Backward flow pyramid; warp(prv, flo_10) == nxt.
        output_multiscale: If True, return the image pyramid (coarse to
            fine); otherwise only the final image tensor. NOTE(review):
            unlike `flower()`, the single-output case returns a bare
            tensor, not a one-element list — kept for compatibility.
        use_tfa: Accepted for signature symmetry with `flower()`; not
            read in this body.

    Returns:
        List of image tensors (pyramid), or the final image tensor.
    """
    # How many encoding/decoding layers?
    n = len(decs_prv)

    # Downsampled image pyramid: one extra level so the coarsest entry
    # matches the bottleneck resolution. A fresh Downsample layer per
    # level, shared between the two streams (as before).
    imgs_prv = [img_prv]
    imgs_nxt = [img_nxt]
    for _ in range(n + 1):
        pool = Downsample()
        imgs_prv.append(pool(imgs_prv[-1]))
        imgs_nxt.append(pool(imgs_nxt[-1]))

    # Coarsest middle frame, from the coarsest images + coarsest flows.
    # flo_01 = fwd, i.e. warp(nxt, flo_01) == prv
    # flo_10 = bwd, i.e. warp(prv, flo_10) == nxt
    img = FrameInterpolate(up=False, name='img_0')(
        (imgs_prv[-1], imgs_nxt[-1], flos_01[0], flos_10[0]))
    imgs = [img]

    for i in range(n):
        # Upsampled previous estimate + decoder features + flows.
        # scale=1.0: image VALUES are not rescaled when the spatial
        # resolution grows (unlike flow, which is scaled by 2).
        img_u = Upsample(scale=1.0)(img)
        img = FrameInterpolate(up=True, name='img_{}'.format(i + 1))(
            (decs_prv[i], decs_nxt[i], flos_01[i + 1], flos_10[i + 1], img_u))
        imgs.append(img)

    # Final full-res img is ONLY upsampled.
    img = Upsample(scale=1.0, name='img_{}'.format(n + 1))(img)
    imgs.append(img)

    if output_multiscale:
        return imgs
    # o.w. only return last img.
    return imgs[-1]