def single_decoder_model(
    h,
    n_out=3,
    config=None,
    activation="relu",
    upsample_config="subpixel",
    coords=False,
):
    """Decode a flat code ``h`` into a feature map with ``n_out`` channels.

    ``h`` is projected by ``nn.nin`` to ``4 * 4 * config[-1]`` values and
    reshaped to ``[-1, 4, 4, config[-1]]``. It then passes through a
    ``nn.conv2d`` to ``config[-1]`` channels and a residual block. The decoder
    then visits the widths ``config[-2], ..., config[0]``. For each width it
    applies a residual block and then ``nn.upsample`` to that width, using the
    matching ``upsample_config`` entry. A last residual block and a
    ``nn.conv2d`` to ``n_out`` channels finish the network.

    Args:
        h: Input code tensor. Its shape is not checked here (presumably
            ``[batch, features]`` — TODO confirm against callers).
        n_out: Width passed to the final ``nn.conv2d``.
        config: Sequence of channel counts; ``config[-1]`` is the bottleneck
            width. Required despite the ``None`` default.
        activation: Forwarded to ``nn.model_arg_scope``.
        upsample_config: Either one method name, used for every upsampling
            step, or a sequence of ``len(config) - 1`` names. In the sequence
            form, entry ``i`` is used for the step that upsamples to
            ``config[i]``.
        coords: Forwarded to ``nn.model_arg_scope``.

    Returns:
        The output tensor of the final ``nn.conv2d``.

    Raises:
        TypeError: if ``config`` is None.
        ValueError: if ``upsample_config`` does not have exactly
            ``len(config) - 1`` entries.

    NOTE(review): ``single_decoder_model`` is defined again further down this
    file, and the last definition is the one bound at import time. Confirm
    which variant callers actually get.
    """
    if config is None:
        raise TypeError(
            "single_decoder_model() requires `config`, a sequence of channel counts"
        )
    # There is one upsampling step per transition between consecutive widths.
    n_upsamples = len(config) - 1
    if isinstance(upsample_config, str):
        # A single method name applies to every upsampling block.
        upsample_config = [upsample_config] * n_upsamples
    # Raise instead of assert so the check survives `python -O`.
    if len(upsample_config) != n_upsamples:
        raise ValueError(
            "upsample_config must have len(config) - 1 = {} entries, got {}".format(
                n_upsamples, len(upsample_config)
            )
        )
    with nn.model_arg_scope(activation=activation, coords=coords):
        h = nn.nin(h, 4 * 4 * config[-1])
        h = tf.reshape(h, [-1, 4, 4, config[-1]])
        h = nn.conv2d(h, config[-1])
        h = nn.residual_block(h)
        # Walk the widths from deepest to shallowest, pairing each with the
        # upsampling method configured for it.
        for nf, u_method in zip(config[-2::-1], upsample_config[-1::-1]):
            h = nn.residual_block(h)
            h = nn.upsample(h, nf, method=u_method)
        h = nn.residual_block(h)
        h = nn.conv2d(h, n_out)
    return h
def decoder_model(
    h1,
    h2,
    config,
    extra_resnets,
    activation="relu",
    coords=False,
    n_out=3,
    upsample_method=None,
):
    """Decode two flat codes ``h1`` and ``h2`` into one feature map.

    Embedding: each code is projected by ``nn.nin`` to
    ``4 * 4 * config[-1]`` values, reshaped to ``[-1, 4, 4, config[-1]]``
    and refined by ``extra_resnets`` residual blocks.

    Merge: the two results are concatenated on the last axis and passed
    through ``nn.conv2d`` to ``config[-1]`` channels.

    Upsampling: for each width in ``config[-2], ..., config[0]``, a residual
    block is followed by ``nn.upsample`` to that width.

    Output: a last residual block and a ``nn.conv2d`` to ``n_out`` channels
    produce the result.

    Args:
        h1: First input code tensor (shape presumably ``[batch, features]``
            — TODO confirm).
        h2: Second input code tensor, embedded the same way as ``h1``.
        config: Sequence of channel counts; ``config[-1]`` is the bottleneck
            width.
        extra_resnets: Number of residual blocks applied to each embedded
            code before merging.
        activation: Forwarded to ``nn.model_arg_scope``.
        coords: Forwarded to ``nn.model_arg_scope``.
        n_out: Width of the final ``nn.conv2d`` (default 3, as before).
        upsample_method: If given, passed as ``method=`` to ``nn.upsample``.
            If None, ``nn.upsample`` is called without it, so its own default
            applies.

    Returns:
        The output tensor of the final ``nn.conv2d``.
    """
    # Only forward `method` when requested, so the default call is unchanged.
    upsample_kwargs = {} if upsample_method is None else {"method": upsample_method}
    with nn.model_arg_scope(activation=activation, coords=coords):
        h1 = nn.nin(h1, 4 * 4 * config[-1])
        h1 = tf.reshape(h1, [-1, 4, 4, config[-1]])
        h2 = nn.nin(h2, 4 * 4 * config[-1])
        h2 = tf.reshape(h2, [-1, 4, 4, config[-1]])
        # Layers are created in the same interleaved order as before; keep it
        # (presumably relevant to auto-generated variable names — verify).
        for _ in range(extra_resnets):
            h1 = nn.residual_block(h1)
            h2 = nn.residual_block(h2)
        # Concatenate on the last axis, then bring the width back to config[-1].
        h = tf.concat([h1, h2], axis=-1)
        h = nn.conv2d(h, config[-1])
        for nf in config[-2::-1]:
            h = nn.residual_block(h)
            h = nn.upsample(h, nf, **upsample_kwargs)
        h = nn.residual_block(h)
        h = nn.conv2d(h, n_out)
    return h
def _hourglass_embed(v, nf, extra_resnets):
    """Lift a flat conditioning input ``v`` onto a 4x4 grid with ``nf`` channels.

    ``v`` is projected by ``nn.nin`` to ``4 * 4 * nf`` values and reshaped to
    ``[-1, 4, 4, nf]``. It is then refined by ``extra_resnets`` residual
    blocks. Called from inside hourglass_model's ``nn.model_arg_scope`` block.
    """
    e = nn.nin(v, 4 * 4 * nf)
    e = tf.reshape(e, [-1, 4, 4, nf])
    for _ in range(extra_resnets):
        e = nn.residual_block(e)
    return e


def hourglass_model(
    x,
    config,
    extra_resnets,
    alpha=None,
    pi=None,
    n_out=3,
    activation="relu",
    upsample_method="subpixel",
    coords=False,
):
    """Build an encoder/decoder network with skip connections.

    Encoder: ``nn.conv2d`` to ``config[0]`` channels plus a residual block.
    Then, for each width in ``config[1:]``, a ``nn.downsample`` to that width
    and a residual block. Each stage output is kept as a skip input.
    ``extra_resnets`` more residual blocks follow at the bottleneck.

    Conditioning (optional): ``alpha`` and ``pi``, when not None, are each
    embedded on a ``[-1, 4, 4, config[-1]]`` grid. They are concatenated with
    the bottleneck on the last axis (alpha first, then pi) and projected back
    to ``config[-1]`` channels with ``nn.conv2d``. The concat assumes the
    bottleneck is 4x4 spatially — TODO confirm the input size against
    ``config``.

    Decoder: for each width in ``config[-2], ..., config[0]``, a residual
    block fed with the matching encoder stage (deepest first) via
    ``skipin``, then ``nn.upsample`` to that width with ``upsample_method``.
    A last residual block and ``nn.conv2d`` to ``n_out`` channels follow.

    Args:
        x: Input tensor for the encoder.
        config: Sequence of channel counts, shallowest first.
        extra_resnets: Number of extra residual blocks at the bottleneck and
            in each conditioning embedding.
        alpha: Optional conditioning input; None disables it.
        pi: Optional conditioning input; None disables it.
        n_out: Width of the final ``nn.conv2d``.
        activation: Forwarded to ``nn.model_arg_scope``.
        upsample_method: Passed as ``method=`` to ``nn.upsample``.
        coords: Forwarded to ``nn.model_arg_scope``.

    Returns:
        The output tensor of the final ``nn.conv2d``.

    NOTE(review): passing ``alpha``/``pi`` widens the bottleneck concat.
    Graphs built with and without them presumably have different variables,
    so checkpoints are not interchangeable — verify before reusing weights.
    """
    with nn.model_arg_scope(activation=activation, coords=coords):
        # Encoder; every downsampled stage is remembered as a skip input.
        skips = []
        h = nn.conv2d(x, config[0])
        h = nn.residual_block(h)
        for nf in config[1:]:
            h = nn.downsample(h, nf)
            h = nn.residual_block(h)
            skips.append(h)
        for _ in range(extra_resnets):
            h = nn.residual_block(h)

        # Optional conditioning, embedded in the order alpha, then pi.
        extras = []
        for cond in (alpha, pi):
            if cond is not None:
                extras.append(_hourglass_embed(cond, config[-1], extra_resnets))
        if extras:
            h = tf.concat([h] + extras, axis=-1)
            # Project the widened concat back to config[-1] channels.
            # NOTE(review): confirm whether this projection should also run
            # when no conditioning is present.
            h = nn.conv2d(h, config[-1])

        # Decoder: pair each target width with the skip of matching depth,
        # deepest first.
        for nf, skip in zip(config[-2::-1], reversed(skips)):
            h = nn.residual_block(h, skipin=skip)
            h = nn.upsample(h, nf, method=upsample_method)
        h = nn.residual_block(h)
        h = nn.conv2d(h, n_out)
    return h
def single_decoder_model(h, n_out=3, config=None, activation="relu"):
    """Decode a flat code ``h`` into a feature map with ``n_out`` channels.

    ``h`` is projected by ``nn.nin`` to ``4 * 4 * config[-1]`` values and
    reshaped to ``[-1, 4, 4, config[-1]]``. It then passes through
    ``nn.conv2d`` to ``config[-1]`` channels and a residual block. For each
    width in ``config[-2], ..., config[0]``, a residual block and
    ``nn.upsample`` to that width follow. A last residual block and
    ``nn.conv2d`` to ``n_out`` channels finish the network.

    Args:
        h: Input code tensor (shape presumably ``[batch, features]`` — TODO
            confirm).
        n_out: Width of the final ``nn.conv2d`` (default 3, as before).
        config: Sequence of channel counts. If None, the module-level
            ``convconf`` is used, looked up at call time. That name is not
            defined in the visible code — presumably set elsewhere in the
            module; verify.
        activation: Forwarded to ``nn.model_arg_scope`` (default "relu", as
            before).

    Returns:
        The output tensor of the final ``nn.conv2d``.

    NOTE(review): ``single_decoder_model`` is defined more than once in this
    file, and only the last definition is bound at import time. Confirm
    which variant callers actually get.
    """
    if config is None:
        # Preserve the original behaviour of reading the global configuration.
        config = convconf
    with nn.model_arg_scope(activation=activation):
        h = nn.nin(h, 4 * 4 * config[-1])
        h = tf.reshape(h, [-1, 4, 4, config[-1]])
        h = nn.conv2d(h, config[-1])
        h = nn.residual_block(h)
        for nf in config[-2::-1]:
            h = nn.residual_block(h)
            h = nn.upsample(h, nf)
        h = nn.residual_block(h)
        h = nn.conv2d(h, n_out)
    return h
def single_decoder_model(h, n_out=3, config=None, activation="relu", upsample_method="subpixel"):
    """Map a flat code ``h`` to a feature map with ``n_out`` channels.

    Bottleneck: the code is lifted onto a ``[-1, 4, 4, config[-1]]`` grid via
    ``nn.nin`` and ``tf.reshape``, then passed through ``nn.conv2d`` and a
    residual block.

    Upsampling: the remaining widths are visited in reverse order
    (``config[-2]`` down to ``config[0]``). Each visit is a residual block
    followed by ``nn.upsample`` to that width with ``upsample_method``.

    Output: one more residual block and a ``nn.conv2d`` emit the ``n_out``
    output channels.

    Args:
        h: Input code tensor (shape presumably ``[batch, features]`` — TODO
            confirm).
        n_out: Width of the final ``nn.conv2d``.
        config: Sequence of channel counts, bottleneck width last. Despite
            the ``None`` default it is indexed unconditionally, so callers
            must supply it.
        activation: Forwarded to ``nn.model_arg_scope``.
        upsample_method: Passed as ``method=`` to every ``nn.upsample``.

    Returns:
        The output tensor of the final ``nn.conv2d``.

    NOTE(review): this is the last ``single_decoder_model`` definition
    visible in this file, so it is the one bound at import time.
    """
    with nn.model_arg_scope(activation=activation):
        base_width = config[-1]
        # Lift the code onto the 4x4 bottleneck grid.
        out = tf.reshape(nn.nin(h, 4 * 4 * base_width), [-1, 4, 4, base_width])
        out = nn.residual_block(nn.conv2d(out, base_width))
        # Walk back out through the widths, deepest first.
        for width in config[-2::-1]:
            out = nn.upsample(nn.residual_block(out), width, method=upsample_method)
        out = nn.residual_block(out)
        return nn.conv2d(out, n_out)