# Beispiel #1
def dec_down(
        gs, zs_posterior, training, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu",
        n_latent_scales=2):
    """Decoder top-down pass with a non-spatial prior at each latent scale.

    Consumes the bottom-up features ``gs`` (popped from the end) and, when
    ``training`` is True, the posterior samples ``zs_posterior`` (popped
    from the front) as feedback; otherwise the prior samples are fed back.

    Args:
        gs: list of bottom-up feature tensors; fully consumed via ``pop``.
        zs_posterior: list of posterior samples, one per latent scale;
            only consumed when ``training`` is True.
        training: feed back posterior samples (True) or prior samples (False).
        init: forwarded to ``model_arg_scope``.
        dropout_p: dropout probability, forwarded to ``model_arg_scope``.
        n_scales: number of resolution scales to process.
        n_residual_blocks: residual blocks per scale; must be even since half
            run before and half after the latent injection.
        activation: activation name, forwarded to ``model_arg_scope``.
        n_latent_scales: number of leading scales that carry latents.

    Returns:
        Tuple ``(hs, ps, zs)`` of hidden units, prior parameters and prior
        samples.
    """
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # prior parameters
        zs = []  # prior samples
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # hidden units of this level
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                # prior (no spatial correlations modeled)
                p = latent_parameters(h)
                ps.append(p)
                z_prior = latent_sample(p)
                zs.append(z_prior)

                # feedback: posterior sample when training, else prior sample
                z = zs_posterior.pop(0) if training else z_prior
                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    # inject the latent, then project back to the original
                    # channel count before the residual block
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        # all bottom-up features must have been consumed
        assert not gs
        if training:
            # likewise all posterior samples
            assert not zs_posterior

        return hs, ps, zs
def enc_down(
        gs, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu",
        n_latent_scales=2):
    """Encoder top-down pass producing posteriors and posterior samples.

    Consumes the bottom-up features ``gs`` from the end. The loop stops as
    soon as all ``n_latent_scales`` latent scales have been processed, so
    ``gs`` is intentionally not required to be fully consumed.

    Args:
        gs: list of bottom-up feature tensors; consumed via ``pop``.
        init: forwarded to ``model_arg_scope``.
        dropout_p: dropout probability, forwarded to ``model_arg_scope``.
        n_scales: number of resolution scales to process.
        n_residual_blocks: residual blocks per scale; must be even since half
            run before and half after the latent sample feedback.
        activation: activation name, forwarded to ``model_arg_scope``.
        n_latent_scales: number of leading scales that carry latents.

    Returns:
        Tuple ``(hs, qs, zs)`` of hidden units, posterior parameters and
        posterior samples.
    """
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        # outputs
        hs = []  # hidden units
        qs = []  # posterior parameters
        zs = []  # samples from posteriors
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # hidden units of this level
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                # posterior parameters
                q = latent_parameters(h)
                qs.append(q)
                # posterior sample
                z = latent_sample(q)
                zs.append(z)
                # sample feedback into the hidden state
                for i in range(n_residual_blocks // 2):
                    gz = tf.concat([gs.pop(), z], axis=-1)
                    h = nn.residual_block(h, gz)
                    hs.append(h)
            else:
                # no latents below this scale -- no need to go down further
                break
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        # no `assert not gs` here: the early break can leave gs non-empty
        return hs, qs, zs
# Beispiel #3
def enc_up(x,
           c,
           init=False,
           dropout_p=0.5,
           n_scales=1,
           n_residual_blocks=2,
           activation="elu",
           n_filters=64,
           max_filters=128):
    """Bottom-up encoder pass over ``x``; ``c`` is accepted but unused."""
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        collected = []  # hidden units from every residual block
        # the conditioning input is deliberately not concatenated here:
        # xc = tf.concat([x, c], axis=-1)
        h = nn.nin(x, n_filters)
        for scale in range(n_scales):
            # residual blocks at this scale
            for _ in range(n_residual_blocks):
                h = nn.residual_block(h)
                collected.append(h)
            # downsample into the next scale, growing channels up to the cap
            if scale + 1 < n_scales:
                n_filters = min(max_filters, 2 * n_filters)
                h = nn.downsample(h, n_filters)

        return collected
# Beispiel #4
def dec_up(
    c,
    init=False,
    dropout_p=0.5,
    n_scales=1,
    n_residual_blocks=2,
    activation="elu",
    n_filters=64,
    max_filters=128,
):
    """Bottom-up decoder pass over the conditioning tensor ``c``."""
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        outputs = []  # hidden units from every residual block
        h = nn.nin(c, n_filters)
        for scale in range(n_scales):
            # residual blocks at this scale
            for _ in range(n_residual_blocks):
                h = nn.residual_block(h)
                outputs.append(h)
            # prepare the input of the next scale, channels capped at
            # max_filters
            if scale + 1 < n_scales:
                n_filters = min(max_filters, n_filters * 2)
                h = nn.downsample(h, n_filters)
        return outputs
# Beispiel #5
def enc_up(
    x,
    c,
    init=False,
    dropout_p=0.5,
    n_scales=1,
    n_residual_blocks=2,
    activation="elu",
    n_filters=64,
    max_filters=128,
):
    """Bottom-up encoder pass; the conditioning ``c`` is currently unused."""
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        feature_maps = []
        # conditioning intentionally disabled:
        # xc = tf.concat([x, c], axis=-1)
        h = nn.nin(x, n_filters)
        for scale in range(n_scales):
            # residual stack of this scale
            for _ in range(n_residual_blocks):
                h = nn.residual_block(h)
                feature_maps.append(h)
            # transition to the next, coarser scale
            if scale + 1 < n_scales:
                n_filters = min(max_filters, 2 * n_filters)
                h = nn.downsample(h, n_filters)
        return feature_maps
# Beispiel #6
def enc_down(gs,
             init=False,
             dropout_p=0.5,
             n_scales=1,
             n_residual_blocks=2,
             activation="elu",
             n_latent_scales=2):
    """Top-down encoder pass yielding hidden units, posteriors and samples.

    Pops bottom-up features from the end of ``gs``; stops early once all
    latent scales are processed, so ``gs`` may not be exhausted.
    """
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        hs = []  # hidden units
        qs = []  # posterior parameters
        zs = []  # posterior samples
        h = nn.nin(gs[-1], gs[-1].shape.as_list()[-1])
        for scale in range(n_scales):
            for _ in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if scale >= n_latent_scales:
                break  # no latents below this scale
            q = latent_parameters(h)  # posterior parameters
            qs.append(q)
            z = latent_sample(q)  # posterior sample
            zs.append(z)
            # feed the sample back alongside the bottom-up features
            for _ in range(n_residual_blocks // 2):
                gz = tf.concat([gs.pop(), z], axis=-1)
                h = nn.residual_block(h, gz)
                hs.append(h)
            # transition to the next scale
            if scale + 1 < n_scales:
                h = nn.upsample(h, gs[-1].shape.as_list()[-1])

        return hs, qs, zs
# Beispiel #7
def dec_up(c,
           init=False,
           dropout_p=0.5,
           n_scales=1,
           n_residual_blocks=2,
           activation="elu",
           n_filters=64,
           max_filters=128):
    """Bottom-up decoder pass over the conditioning tensor ``c``."""
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        features = []
        h = nn.nin(c, n_filters)
        for scale in range(n_scales):
            # residual stack of this scale
            for _ in range(n_residual_blocks):
                h = nn.residual_block(h)
                features.append(h)
            # move on to the next scale; channel growth capped at max_filters
            if scale + 1 < n_scales:
                n_filters = min(max_filters, 2 * n_filters)
                h = nn.downsample(h, n_filters)
        return features
def cfn(
        x, init=False, dropout_p=0.5,
        n_scales=1, n_residual_blocks=2, activation="elu", n_filters=64,
        max_filters=128):
    """Bottom-up pass ending in a single 1x1 global feature map."""
    with model_arg_scope(
            init=init, dropout_p=dropout_p, activation=activation):
        collected = []  # per-block hidden units plus the final feature
        h = nn.nin(x, n_filters)
        for scale in range(n_scales):
            # residual stack of this scale
            for _ in range(n_residual_blocks):
                h = nn.residual_block(h)
                collected.append(h)
            # transition to the next scale
            if scale + 1 < n_scales:
                n_filters = min(max_filters, 2 * n_filters)
                h = nn.downsample(h, n_filters)
        # fold all spatial positions into channels:
        # (b, h, w, c) -> (b, 1, 1, h*w*c)
        batch, height, width, channels = h.shape.as_list()
        h = tf.reshape(h, [batch, 1, 1, height * width * channels])
        h = nn.nin(h, 2 * max_filters)
        collected.append(h)
        return collected
# Beispiel #9
def dec_down(gs,
             zs_posterior,
             training,
             init=False,
             dropout_p=0.5,
             n_scales=1,
             n_residual_blocks=2,
             activation="elu",
             n_latent_scales=2):
    """Decoder top-down pass with an optionally autoregressive prior.

    Bottom-up features ``gs`` are consumed from the end; during training the
    posterior samples ``zs_posterior`` are consumed from the front and fed
    back instead of the prior samples.

    Args:
        gs: list of bottom-up feature tensors; fully consumed via ``pop``.
        zs_posterior: list of posterior samples, one per latent scale; only
            consumed when ``training`` is True.
        training: feed back posterior samples (True) or prior samples (False).
        init: forwarded to ``model_arg_scope``.
        dropout_p: dropout probability, forwarded to ``model_arg_scope``.
        n_scales: number of resolution scales to process.
        n_residual_blocks: residual blocks per scale; must be even since half
            run before and half after the latent injection.
        activation: activation name, forwarded to ``model_arg_scope``.
        n_latent_scales: number of leading scales that carry latents.

    Returns:
        Tuple ``(hs, ps, zs)`` of hidden units, prior parameters and prior
        samples.
    """
    assert n_residual_blocks % 2 == 0
    gs = list(gs)
    zs_posterior = list(zs_posterior)
    with model_arg_scope(init=init, dropout_p=dropout_p,
                         activation=activation):
        # outputs
        hs = []  # hidden units
        ps = []  # prior parameters
        zs = []  # prior samples
        # prepare input
        n_filters = gs[-1].shape.as_list()[-1]
        h = nn.nin(gs[-1], n_filters)
        for l in range(n_scales):
            # level module
            ## hidden units of this level
            for i in range(n_residual_blocks // 2):
                h = nn.residual_block(h, gs.pop())
                hs.append(h)
            if l < n_latent_scales:
                ## prior
                spatial_shape = h.shape.as_list()[1]
                n_h_channels = h.shape.as_list()[-1]
                if spatial_shape == 1:
                    ### 1x1 feature map: no spatial correlations to model
                    p = latent_parameters(h)
                    ps.append(p)
                    z_prior = latent_sample(p)
                    zs.append(z_prior)
                else:
                    ### four autoregressively modeled groups
                    # (space_to_depth with block size 2 turns each 2x2 patch
                    # into channel groups that are then modeled sequentially)
                    if training:
                        z_posterior_groups = nn.split_groups(zs_posterior[0])
                    p_groups = []
                    z_groups = []
                    p_features = tf.space_to_depth(nn.residual_block(h), 2)
                    for i in range(4):
                        p_group = latent_parameters(p_features,
                                                    num_filters=n_h_channels)
                        p_groups.append(p_group)
                        z_group = latent_sample(p_group)
                        z_groups.append(z_group)
                        # autoregressive feedback: sampled from the posterior
                        # group during training, from the prior group otherwise
                        if training:
                            feedback = z_posterior_groups.pop(0)
                        else:
                            feedback = z_group
                        # prepare input for next group
                        if i + 1 < 4:
                            p_features = nn.residual_block(
                                p_features, feedback)
                    if training:
                        # all posterior groups must have been consumed
                        assert not z_posterior_groups
                    # complete prior parameters
                    p = nn.merge_groups(p_groups)
                    ps.append(p)
                    # complete prior sample
                    z_prior = nn.merge_groups(z_groups)
                    zs.append(z_prior)
                ## vae feedback: posterior sample when training, else prior
                if training:
                    ## posterior
                    z = zs_posterior.pop(0)
                else:
                    ## prior
                    z = z_prior
                for i in range(n_residual_blocks // 2):
                    n_h_channels = h.shape.as_list()[-1]
                    # inject latent, then project back to the original
                    # channel count before the residual block
                    h = tf.concat([h, z], axis=-1)
                    h = nn.nin(h, n_h_channels)
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            else:
                for i in range(n_residual_blocks // 2):
                    h = nn.residual_block(h, gs.pop())
                    hs.append(h)
            # prepare input to next level
            if l + 1 < n_scales:
                n_filters = gs[-1].shape.as_list()[-1]
                h = nn.upsample(h, n_filters)

        # every bottom-up feature must have been consumed
        assert not gs
        if training:
            # likewise every posterior sample
            assert not zs_posterior

        return hs, ps, zs