def step2(x1, x2, x3, x4, h, mem):
    """Add the response of four pre-split input chunks to ``mem``.

    ``h`` is turned into a row, projected through each of the weights
    ``W1``..``W4`` (looked up from the enclosing module), and each
    projection is transposed and multiplied with its matching chunk.
    The four products are summed left to right and ``mem`` is added last.

    Shape notes are carried over from the original annotations and are
    not checked here: the row is 1x20 and each projection is 1x10.

    :param x1: input chunk paired with ``W1``.
    :param x2: input chunk paired with ``W2``.
    :param x3: input chunk paired with ``W3``.
    :param x4: input chunk paired with ``W4``.
    :param h: state vector.
    :param mem: running value the new response is added to.
    :return: sum of the four chunk/projection products plus ``mem``.
    """
    row = h.dimshuffle('x', 0)  # 1x20
    projections = [T.dot(row, weight) for weight in (W1, W2, W3, W4)]  # each 1x10
    terms = [T.dot(chunk, T.transpose(proj))
             for chunk, proj in zip((x1, x2, x3, x4), projections)]

    # Fold left to right so the additions are built in the original order.
    response = terms[0]
    for term in terms[1:]:
        response = response + term
    return response + mem
def step3(x, h, mem):
    """Add the response of ``x`` against four separate projections of ``h`` to ``mem``.

    ``x`` is cut into four 10-column slices (columns 0-9, 10-19, 20-29 and
    30-39).  ``h`` is turned into a row and projected through each of the
    weights ``W1``..``W4`` (looked up from the enclosing module).  Slice
    ``i`` is multiplied with the transposed projection ``i``; the four
    products are summed left to right and ``mem`` is added last.

    Shape notes are carried over from the original annotations and are
    not checked here: each slice is seq_len x 10, the row is 1x20 and
    each projection is 1x10.

    :param x: input matrix with at least 40 columns in use.
    :param h: state vector.
    :param mem: running value the new response is added to.
    :return: sum of the four slice/projection products plus ``mem``.
    """
    x_slices = [x[:, start:start + 10] for start in (0, 10, 20, 30)]  # each seq_len x 10

    row = h.dimshuffle('x', 0)  # 1x20
    projections = [T.dot(row, weight) for weight in (W1, W2, W3, W4)]  # each 1x10

    # Accumulate in slice order so the additions match the original.
    response = None
    for x_part, proj in zip(x_slices, projections):
        term = T.dot(x_part, T.transpose(proj))
        response = term if response is None else response + term
    return response + mem
def step1(x, h, mem):
    """Add the response of ``x`` against one fused projection of ``h`` to ``mem``.

    ``h`` is turned into a row and projected once through the weight ``W``
    (looked up from the enclosing module).  Both ``x`` and the projection
    are then cut into the same four 10-column slices (columns 0-9, 10-19,
    20-29 and 30-39).  Each slice of ``x`` is multiplied with the
    transposed matching slice of the projection; the four products are
    summed left to right and ``mem`` is added last.

    Shape notes are carried over from the original annotations and are
    not checked here: each slice of ``x`` is seq_len x 10 and each slice
    of the projection is 1x10.

    :param x: input matrix with at least 40 columns in use.
    :param h: state vector.
    :param mem: running value the new response is added to.
    :return: sum of the four slice products plus ``mem``.
    """
    column_bounds = ((0, 10), (10, 20), (20, 30), (30, 40))
    x_slices = [x[:, lo:hi] for lo, hi in column_bounds]  # each seq_len x 10

    # One projection for all four slices instead of four separate ones.
    projected = T.dot(h.dimshuffle('x', 0), W)

    response = None
    for x_part, (lo, hi) in zip(x_slices, column_bounds):
        h_part = projected[:, lo:hi]  # 1x10
        term = T.dot(x_part, T.transpose(h_part))
        response = term if response is None else response + term
    return response + mem
예제 #4
0
def test_vae():
    """Smoke-test a weight-sharing variational encoder/decoder on MNIST.

    Steps, in order:

    1. Load MNIST through ``odin.dataset.load_mnist`` and build a Dense
       encoder (512 units) and a Dense decoder (784 units).  The decoder
       is given the transpose of the encoder's ``W``, so both layers
       share one weight variable.
    2. Run the untrained model on the first 16 training samples, once
       with its default output and once in reconstruction mode, printing
       the output shapes and plotting inputs next to outputs.
    3. Print the parameter names, then train with SGD (learning rate
       0.01) on the categorical cross-entropy objective for two passes
       over ``X_train``, logging the cost of every processed batch.
    4. Recompile the reconstruction function and plot the first 16 test
       samples next to their reconstructions.

    Takes no arguments and returns nothing; the results are only visible
    through the printed output and the plots.
    """
    # Single source of truth: the model, the batch iterator and the
    # batch-size filter below must all agree on this value.
    batch_size = 64
    n_epochs = 2

    ds = odin.dataset.load_mnist()

    W = T.variable(T.np_glorot_uniform(shape=(784, 512)), name='W')
    WT = T.transpose(W)
    encoder = odin.nnet.Dense((None, 28, 28), num_units=512, W=W, name='encoder')
    decoder = odin.nnet.Dense((None, 512), num_units=784, W=WT, name='decoder2')

    vae = odin.nnet.VariationalEncoderDecoder(encoder=encoder, decoder=decoder,
        prior_logsigma=1.7, batch_size=batch_size)

    # ====== prediction ====== #
    x = ds['X_train'][:16]

    predict_fn = T.function(inputs=vae.input_var, outputs=vae(training=False))
    print("Predictions:", predict_fn(x)[0].shape)

    reconstruct_fn = T.function(
        inputs=vae.input_var,
        outputs=vae.set_reconstruction_mode(True)(training=False))
    y = reconstruct_fn(x)[0].reshape(-1, 28, 28)
    print("Predictions:", y.shape)

    odin.visual.plot_images(x)
    odin.visual.plot_images(y)
    odin.visual.plot_show()

    print('Params:', [p.name for p in vae.get_params(False)])
    print('Params(globals):', [p.name for p in vae.get_params(True)])
    # ====== Optimizer ====== #
    cost, updates = vae.get_optimization(
        objective=odin.objectives.categorical_crossentropy,
        optimizer=lambda x, y: odin.optimizers.sgd(x, y, learning_rate=0.01),
        globals=True, training=True)

    train_fn = T.function(inputs=vae.input_var, outputs=cost, updates=updates)
    # Kept separate from the symbolic `cost` above instead of rebinding it.
    cost_history = []
    # NOTE: `/` is left as in the original; it is integer division only on
    # Python 2 without `from __future__ import division`.
    niter = ds['X_train'].iter_len() / batch_size
    # `range` instead of the Python-2-only `xrange`; identical for 2 passes.
    for _ in range(n_epochs):
        for i, batch in enumerate(ds['X_train'].iter(batch_size)):
            # Skip batches that are not exactly `batch_size` rows —
            # presumably because the model was built with a fixed batch
            # size; confirm against VariationalEncoderDecoder.
            if batch.shape[0] != batch_size:
                continue
            cost_history.append(train_fn(batch))
            odin.logger.progress(i, niter, str(cost_history[-1]))
    odin.visual.print_bar(cost_history)

    # ====== reconstruction ====== #
    reconstruct_fn = T.function(
        inputs=vae.input_var,
        outputs=vae.set_reconstruction_mode(True)(training=False))
    X_test = ds['X_test'][:16]
    X_reco = reconstruct_fn(X_test)[0].reshape(-1, 28, 28)
    odin.visual.plot_images(X_test)
    odin.visual.plot_images(X_reco)
    odin.visual.plot_show()