def cond_erg_enc_net(_incoming, _cond, noise_sigma=0.05, drop_ratio_conv=0.1):
    """Build the conditional energy-encoder from two identical conv stages.

    Each stage applies the same steps:
      1. a rescaled ``L.DropoutLayer`` on its input;
      2. concatenation of the condition via ``plu.concat_tc``;
      3. a ``conv`` (4x4, stride 2, pad 1, W ~ Normal(0.02), no bias,
         LeakyRectify(0.02)) wrapped in ``batch_norm``.

    The first stage has 128 filters and the second has 256.

    Args:
        _incoming: input layer.
        _cond: condition layer, concatenated before every convolution.
        noise_sigma: currently unused. The Gaussian-noise input layer it was
            meant for is disabled. The parameter is kept so existing callers
            keep working.
        drop_ratio_conv: dropout probability used before each convolution.

    Returns:
        The ``batch_norm``-wrapped output of the 256-filter stage.
    """
    net = _incoming
    for width in (128, 256):
        # Dropout first, then attach the condition to the dropped features.
        net = plu.concat_tc(
            L.DropoutLayer(net, p=drop_ratio_conv, rescale=True), _cond)
        net = batch_norm(
            conv(net,
                 num_filters=width,
                 filter_size=4,
                 stride=2,
                 pad=1,
                 W=I.Normal(0.02),
                 b=None,
                 nonlinearity=NL.LeakyRectify(0.02)))
    return net
def cond_gen_net(_seed, _cond, num_noise_slices=(0, 0, 0, 0, 0)):
    """Build the conditional generator network.

    Structure:
      - The first noise slice of ``_seed``, plus the condition, goes through
        a ``batch_norm``-wrapped dense layer.
      - The result is reshaped to ``(batch, 256, 4, 4)``.
      - Four ``deconv`` stages follow. Before each one, the condition and
        further noise slices of ``_seed`` are concatenated on via
        ``plu.concat_tcn``.

    Args:
        _seed: noise input layer; slices of it are injected at every stage.
        _cond: condition layer, concatenated at every stage.
        num_noise_slices: number of noise slices to inject at each of the
            five stages (indexed 0..4). Any indexable sequence is accepted.
            The default is a tuple rather than a list, so it cannot be
            mutated between calls.

    Returns:
        The final ``deconv`` layer (``npc`` filters, tanh nonlinearity).

    Side effects:
        Prints the ``nz`` value reported by the ``plu`` helpers (the noise
        size the built graph requires).
    """
    # Stage 0: dense projection of noise slice + condition to 256x4x4.
    net, ns = plu.concat_tn(None, _seed, 0, num_noise_slices[0])
    net = plu.concat_tc(net, _cond)
    net = batch_norm(
        L.DenseLayer(net,
                     256 * 4**2,
                     W=I.Normal(0.02),
                     b=None,
                     nonlinearity=NL.rectify))
    net = L.ReshapeLayer(net, ([0], 256, 4, 4))

    # Stages 1-3: batch-normalised ReLU deconv layers, given as
    # (num_filters, filter_size, stride); all use crop=0.
    hidden_stages = ((384, 3, 1), (256, 4, 2), (96, 4, 2))
    for stage, (num_filters, filter_size, stride) in enumerate(hidden_stages,
                                                               1):
        net, ns = plu.concat_tcn(net, _cond, _seed, ns,
                                 num_noise_slices[stage])
        net = batch_norm(
            deconv(net,
                   num_filters=num_filters,
                   filter_size=filter_size,
                   stride=stride,
                   crop=0,
                   W=I.Normal(0.02),
                   b=None,
                   nonlinearity=NL.rectify))

    # Stage 4: output layer; tanh and no batch norm.
    net, ns = plu.concat_tcn(net, _cond, _seed, ns, num_noise_slices[4])
    net = deconv(net,
                 num_filters=npc,
                 filter_size=3,
                 stride=1,
                 crop=0,
                 W=I.Normal(0.02),
                 b=None,
                 nonlinearity=NL.tanh)
    # Single-argument print(...) prints the same text on Python 2 and 3.
    print("===> graph requires nz>=%d <===" % ns)
    return net
def cond_erg_dec_net(_emb, _cond):
    """Build the conditional energy-decoder from two stride-2 ``deconv`` stages.

    Before each stage, the condition is concatenated onto the features with
    ``plu.concat_tc``. Both stages share these settings:
      - 4x4 filters, stride 2, crop 1;
      - W ~ Normal(0.02);
      - zero-initialised bias.

    The first stage has 128 filters with LeakyRectify(0.02). The second has
    ``npc`` filters and no nonlinearity.

    Args:
        _emb: embedding layer to decode.
        _cond: condition layer, concatenated before each stage.

    Returns:
        The final (linear) ``deconv`` layer.
    """
    out = _emb
    for filters, activation in ((128, NL.LeakyRectify(0.02)), (npc, None)):
        out = deconv(plu.concat_tc(out, _cond),
                     num_filters=filters,
                     filter_size=4,
                     stride=2,
                     crop=1,
                     W=I.Normal(0.02),
                     b=I.Constant(0),
                     nonlinearity=activation)
    return out
def erg_dec_net(_emb, _cond):
    """Build the (optionally conditional) energy-decoder.

    It has three ``deconv`` stages, all with W ~ Normal(0.02) and b=None:
      1. 128 filters, 3x3, stride 1, crop 0, LeakyRectify(0.01),
         wrapped in ``batch_norm``;
      2. 64 filters, 4x4, stride 2, crop 0, LeakyRectify(0.01);
      3. ``npc`` filters, 4x4, stride 2, crop 1, no nonlinearity.

    When ``_cond`` is not None, it is concatenated onto the features before
    every stage.

    Args:
        _emb: embedding layer to decode.
        _cond: condition layer, or None for an unconditional decoder.

    Returns:
        The final (linear) ``deconv`` layer.
    """
    def with_cond(net):
        # Use an identity check against None: `!= None` would dispatch to
        # the operand's __ne__, which symbolic/array types may overload.
        return net if _cond is None else plu.concat_tc(net, _cond)

    _deconv1 = batch_norm(
        deconv(with_cond(_emb),
               num_filters=128,
               filter_size=3,
               stride=1,
               crop=0,
               W=I.Normal(0.02),
               b=None,
               nonlinearity=NL.LeakyRectify(0.01)))
    # NOTE(review): the next two layers have neither batch_norm nor a bias
    # (b=None). cond_erg_dec_net uses b=I.Constant(0) for its un-normalised
    # layers. Confirm that dropping the bias here is intended.
    _deconv2 = deconv(with_cond(_deconv1),
                      num_filters=64,
                      filter_size=4,
                      stride=2,
                      crop=0,
                      W=I.Normal(0.02),
                      b=None,
                      nonlinearity=NL.LeakyRectify(0.01))
    _deconv3 = deconv(with_cond(_deconv2),
                      num_filters=npc,
                      filter_size=4,
                      stride=2,
                      crop=1,
                      W=I.Normal(0.02),
                      b=None,
                      nonlinearity=None)
    return _deconv3
def erg_enc_net(_input, _cond=None):
    """Build the (optionally conditional) energy-encoder.

    It has three ``conv`` stages, each wrapped in ``batch_norm``, all with
    W ~ Normal(0.02), no bias and LeakyRectify(0.1):
      1. 64 filters, 4x4, stride 2, pad 1;
      2. 128 filters, 4x4, stride 2, pad 0;
      3. 256 filters, 3x3, stride 1, pad 0.

    When ``_cond`` is not None, it is concatenated onto the features before
    every stage.

    Args:
        _input: input layer.
        _cond: optional condition layer; None builds an unconditional encoder.

    Returns:
        The ``batch_norm``-wrapped output of the 256-filter stage.
    """
    def with_cond(net):
        # Use an identity check against None: `!= None` would dispatch to
        # the operand's __ne__, which symbolic/array types may overload.
        return net if _cond is None else plu.concat_tc(net, _cond)

    _conv1 = batch_norm(
        conv(with_cond(_input),
             num_filters=64,
             filter_size=4,
             stride=2,
             pad=1,
             W=I.Normal(0.02),
             b=None,
             nonlinearity=NL.LeakyRectify(0.1)))
    _conv2 = batch_norm(
        conv(with_cond(_conv1),
             num_filters=128,
             filter_size=4,
             stride=2,
             pad=0,
             W=I.Normal(0.02),
             b=None,
             nonlinearity=NL.LeakyRectify(0.1)))
    _emb = batch_norm(
        conv(with_cond(_conv2),
             num_filters=256,
             filter_size=3,
             stride=1,
             pad=0,
             W=I.Normal(0.02),
             b=None,
             nonlinearity=NL.LeakyRectify(0.1)))
    return _emb
def gen_net(_seed, num_noise_slices, _cond=None):
    """Build the (optionally conditional) generator network.

    Structure:
      - The first noise slice of ``_seed`` (plus the condition, if given)
        feeds a 1024-unit dense layer.
      - A second dense layer is reshaped to ``(batch, 128, 4, 4)``.
      - Three ``deconv`` stages follow (128 and 64 filters with ReLU and
        ``batch_norm``, then ``npc`` filters with sigmoid).

    Before every stage after the first, further noise slices of ``_seed``
    are concatenated onto the features, together with the condition when
    ``_cond`` is not None.

    Args:
        _seed: noise input layer; slices of it are injected at every stage.
        num_noise_slices: number of noise slices to inject at each of the
            five stages (indexed 0..4).
        _cond: optional condition layer; None builds an unconditional
            generator.

    Returns:
        The final ``deconv`` layer (``npc`` filters, sigmoid nonlinearity).

    Side effects:
        Prints the ``nz`` value reported by the ``plu`` helpers (the noise
        size the built graph requires).
    """
    def inject(net, ns, num_slices):
        # Concatenate `num_slices` noise slices of _seed (and _cond, if any)
        # onto `net`. Returns (new_net, updated ns) like the plu helpers.
        if _cond is None:
            return plu.concat_tn(net, _seed, ns, num_slices)
        return plu.concat_tcn(net, _cond, _seed, ns, num_slices)

    # Stage 0: first noise slice (+ condition) -> dense 1024 units.
    net, ns = plu.concat_tn(None, _seed, 0, num_noise_slices[0])
    if _cond is not None:
        net = plu.concat_tc(net, _cond)
    net = batch_norm(
        L.DenseLayer(net,
                     1024,
                     W=I.Normal(0.02),
                     b=None,
                     nonlinearity=NL.rectify))

    # Stage 1: dense layer reshaped to 128x4x4 feature maps.
    net, ns = inject(net, ns, num_noise_slices[1])
    net = batch_norm(
        L.DenseLayer(net,
                     128 * 4**2,
                     W=I.Normal(0.02),
                     b=None,
                     nonlinearity=NL.rectify))
    net = L.ReshapeLayer(net, ([0], 128, 4, 4))

    # Stage 2: 3x3, stride-1 deconv.
    net, ns = inject(net, ns, num_noise_slices[2])
    net = batch_norm(
        deconv(net,
               num_filters=128,
               filter_size=3,
               stride=1,
               crop=0,
               W=I.Normal(0.02),
               b=None,
               nonlinearity=NL.rectify))

    # Stage 3: 4x4, stride-2 deconv.
    net, ns = inject(net, ns, num_noise_slices[3])
    net = batch_norm(
        deconv(net,
               num_filters=64,
               filter_size=4,
               stride=2,
               crop=0,
               W=I.Normal(0.02),
               b=None,
               nonlinearity=NL.rectify))

    # Stage 4: output layer; sigmoid and no batch norm.
    net, ns = inject(net, ns, num_noise_slices[4])
    net = deconv(net,
                 num_filters=npc,
                 filter_size=4,
                 stride=2,
                 crop=1,
                 W=I.Normal(0.02),
                 b=None,
                 nonlinearity=NL.sigmoid)
    # Single-argument print(...) prints the same text on Python 2 and 3.
    print("===> graph requires nz>=%d <===" % ns)
    return net