Example no. 1

# Module-level imports assumed from usage in the samples below: nn is Lasagne
# and T is theano.tensor; u (utils), m (models), and c (config) are
# project-local modules whose names are inferred from how they are called.
import os

import numpy as np
import theano
import theano.tensor as T
import lasagne as nn

import config as c
import models as m
import utils as u

def main(save_to='params',
         dataset='mm',
         kl_loss='true',  # use KL divergence in z-space instead of MSE
         diffs='false',
         seq_length=30,
         num_epochs=1,
         lstm_n_hid=1024,
         max_per_epoch=-1):
    kl_loss = kl_loss.lower() == 'true'
    diffs = diffs.lower() == 'true'

    # set up functions for data pre-processing and model training
    input_var = T.tensor4('inputs')

    # different experimental setup for the Moving MNIST vs. Pulp Fiction datasets
    if dataset == 'pf':
        img_size = 64
        cae_weights = c.pf_cae_params
        cae_specstr = c.pf_cae_specstr
        split_layer = 'conv7'
        inpvar = T.tensor4('input')
        net = m.build_cae(inpvar, specstr=cae_specstr, shape=(img_size, img_size))
        convs_from_img,_ = m.encoder_decoder(cae_weights, specstr=cae_specstr,
                layersplit=split_layer, shape=(img_size, img_size), poolinv=True)
        laydict = dict((l.name, l) for l in nn.layers.get_all_layers(net))
        zdec_in_shape = nn.layers.get_output_shape(laydict[split_layer])
        deconv_weights = c.pf_deconv_params
        vae_weights = c.pf_vae_params
        img_from_convs = m.deconvoluter(deconv_weights, specstr=cae_specstr, shape=zdec_in_shape)
        L = 2
        vae_n_hid = 1500
        binary = False
        z_dim = 256
        l_tup = l_z_mu, l_z_ls, l_x_mu_list, l_x_ls_list, l_x_list, l_x = \
               m.build_vae(input_var, L=L, binary=binary, z_dim=z_dim, n_hid=vae_n_hid,
                        shape=(zdec_in_shape[2], zdec_in_shape[3]), channels=zdec_in_shape[1])
        u.load_params(l_x, vae_weights)
        datafile = 'data/pf.hdf5'
        frame_skip = 3  # every 3rd frame in sequence
        z_decode_layer = l_x_mu_list[0]
        pixel_shift = 0.5
        samples_per_image = 4
        tr_batch_size = 16 # must be a multiple of samples_per_image
    elif dataset == 'mm':
        img_size = 64
        cvae_weights = c.mm_cvae_params
        L = 2
        vae_n_hid = 1024
        binary = True
        z_dim = 32
        zdec_in_shape = (None, 1, img_size, img_size)
        l_tup = l_z_mu, l_z_ls, l_x_mu_list, l_x_ls_list, l_x_list, l_x = \
            m.build_vcae(input_var, L=L, z_dim=z_dim, n_hid=vae_n_hid, binary=binary,
                       shape=(zdec_in_shape[2], zdec_in_shape[3]), channels=zdec_in_shape[1])
        u.load_params(l_x, cvae_weights)
        datafile = 'data/moving_mnist.hdf5'
        frame_skip = 1
        w, h = img_size, img_size  # width/height of the raw input images in the hdf5 file
        z_decode_layer = l_x_list[0]
        pixel_shift = 0
        samples_per_image = 1
        tr_batch_size = 128 # must be a multiple of samples_per_image

    # functions for moving to/from image or conv-space, and z-space
    z_mat = T.matrix('z')
    zenc = theano.function([input_var], nn.layers.get_output(l_z_mu, deterministic=True))
    zdec = theano.function([z_mat], nn.layers.get_output(z_decode_layer, {l_z_mu:z_mat},
        deterministic=True).reshape((-1, zdec_in_shape[1]) + zdec_in_shape[2:]))
    zenc_ls = theano.function([input_var], nn.layers.get_output(l_z_ls, deterministic=True))

    # functions for encoding sequences of z's
    print('compiling functions')
    z_var = T.tensor3('z_in')
    z_ls_var = T.tensor3('z_ls_in')
    tgt_mu_var = T.tensor3('z_tgt')
    tgt_ls_var = T.tensor3('z_ls_tgt')
    learning_rate = theano.shared(nn.utils.floatX(1e-4))

    # separate function definitions if we are using MSE and predicting only z, or KL divergence
    # and predicting both mean and sigma of z
    if kl_loss:
        def kl(p_mu, p_sigma, q_mu, q_sigma):
            return 0.5 * T.sum(T.sqr(p_sigma)/T.sqr(q_sigma) + T.sqr(q_mu - p_mu)/T.sqr(q_sigma)
                               - 1 + 2*T.log(q_sigma) - 2*T.log(p_sigma))
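        # Sanity check of the closed form: for diagonal Gaussians
        # p = N(mu_p, diag(sigma_p^2)) and q = N(mu_q, diag(sigma_q^2)),
        #   KL(p || q) = 0.5 * sum_i( sigma_p_i^2 / sigma_q_i^2
        #                             + (mu_q_i - mu_p_i)^2 / sigma_q_i^2
        #                             - 1 + 2*log(sigma_q_i) - 2*log(sigma_p_i) )
        # which kl() computes term-for-term; callers pass T.exp(log_sigma), so the
        # encoder's posterior plays the role of p and the LSTM prediction that of q.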
        lstm, _ = m.Z_VLSTM(z_var, z_ls_var, z_dim=z_dim, nhid=lstm_n_hid, training=True)
        z_mu_expr, z_ls_expr = nn.layers.get_output([lstm['output_mu'], lstm['output_ls']])
        z_mu_expr_det, z_ls_expr_det = nn.layers.get_output([lstm['output_mu'],
            lstm['output_ls']], deterministic=True)
        loss = kl(tgt_mu_var, T.exp(tgt_ls_var), z_mu_expr, T.exp(z_ls_expr))
        te_loss = kl(tgt_mu_var, T.exp(tgt_ls_var), z_mu_expr_det, T.exp(z_ls_expr_det))
        params = nn.layers.get_all_params(lstm['output'], trainable=True)
        updates = nn.updates.adam(loss, params, learning_rate=learning_rate)
        train_fn = theano.function([z_var, z_ls_var, tgt_mu_var, tgt_ls_var], loss, 
                updates=updates)
        test_fn = theano.function([z_var, z_ls_var, tgt_mu_var, tgt_ls_var], te_loss)
    else:
        lstm, _ = m.Z_LSTM(z_var, z_dim=z_dim, nhid=lstm_n_hid, training=True)
        loss = nn.objectives.squared_error(nn.layers.get_output(lstm['output']),
                tgt_mu_var).mean()
        te_loss = nn.objectives.squared_error(nn.layers.get_output(lstm['output'],
            deterministic=True), tgt_mu_var).mean()
        params = nn.layers.get_all_params(lstm['output'], trainable=True)
        updates = nn.updates.adam(loss, params, learning_rate=learning_rate)
        train_fn = theano.function([z_var, tgt_mu_var], loss, updates=updates)
        test_fn = theano.function([z_var, tgt_mu_var], te_loss)

    if dataset == 'pf':
        z_from_img = lambda x: zenc(convs_from_img(x))
        z_ls_from_img = lambda x: zenc_ls(convs_from_img(x))
        img_from_z = lambda z: img_from_convs(zdec(z))
    elif dataset == 'mm':
        z_from_img = zenc
        z_ls_from_img = zenc_ls
        img_from_z = zdec

    # training loop
    print('training for {} epochs'.format(num_epochs))
    # frames per raw batch: each of the tr_batch_size sequences needs seq_length+1
    # frames at stride frame_skip, with samples_per_image sequences sharing frames
    nbatch = (seq_length + 1) * tr_batch_size * frame_skip // samples_per_image
    data = u.DataH5PyStreamer(datafile, batch_size=nbatch)

    # for taking arrays of uint8 (non square) and converting them to batches of sequences
    def transform_data(ims_batch, center=False):
        imb = u.raw_to_floatX(ims_batch, pixel_shift=pixel_shift,
                center=center)[np.random.randint(frame_skip)::frame_skip]
        zbatch = np.zeros((tr_batch_size, seq_length+1, z_dim), dtype=theano.config.floatX)
        zsigbatch = np.zeros((tr_batch_size, seq_length+1, z_dim), dtype=theano.config.floatX)
        for i in range(samples_per_image):
            chunk = tr_batch_size // samples_per_image
            if diffs:
                zf = z_from_img(imb).reshape((chunk, seq_length+1, -1))
                zbatch[i*chunk:(i+1)*chunk, 1:] = zf[:,1:] - zf[:,:-1]
                if kl_loss:
                    zls = z_ls_from_img(imb).reshape((chunk, seq_length+1, -1))
                    zsigbatch[i*chunk:(i+1)*chunk, 1:] = zls[:,1:] - zls[:,:-1]
            else:
                zbatch[i*chunk:(i+1)*chunk] = z_from_img(imb).reshape((chunk, seq_length+1, -1))
                if kl_loss:
                    zsigbatch[i*chunk:(i+1)*chunk] = z_ls_from_img(imb).reshape((chunk,
                        seq_length+1, -1))
        if kl_loss:
            return zbatch[:,:-1,:], zsigbatch[:,:-1,:], zbatch[:,1:,:], zsigbatch[:,1:,:]
        return zbatch[:,:-1,:], zbatch[:,1:,:]
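    # e.g. with the 'mm' defaults (tr_batch_size=128, seq_length=30, z_dim=32),
    # transform_data returns input/target arrays of shape (128, 30, 32): inputs
    # cover frames 0..29 in z-space and targets frames 1..30 (plus matching
    # log-sigma arrays when kl_loss is set), i.e. one-step-ahead prediction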

    # we need sequences of images, so we do not shuffle the data during training
    hist = u.train_with_hdf5(data, num_epochs=num_epochs, train_fn=train_fn, test_fn=test_fn,
                     train_shuffle=False,
                     max_per_epoch=max_per_epoch,
                     tr_transform=lambda x: transform_data(x[0], center=False),
                     te_transform=lambda x: transform_data(x[0], center=True))

    hist = np.asarray(hist)
    u.save_params(lstm['output'], os.path.join(save_to, 'lstm_{}.npz'.format(hist[-1,-1])))

    # build functions to sample from LSTM
    # separate cell_init and hid_init from the other learned model parameters
    all_param_values = nn.layers.get_all_param_values(lstm['output'])
    init_indices = [i for i,p in enumerate(nn.layers.get_all_params(lstm['output']))
            if 'init' in str(p)]
    init_values = [all_param_values[i] for i in init_indices]
    params_noinit = [p for i,p in enumerate(all_param_values) if i not in init_indices]
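    # hid_init/cell_init were learned as single row vectors during training; at
    # sampling time they are tiled across the batch (the np.dot broadcast in the
    # loop below) and fed in through explicit state variables instead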

    # build model without learnable init values, and load non-init parameters
    if kl_loss:
        lstm_sample, state_vars = m.Z_VLSTM(z_var, z_ls_var, z_dim=z_dim, nhid=lstm_n_hid,
                training=False)
    else:
        lstm_sample, state_vars = m.Z_LSTM(z_var, z_dim=z_dim, nhid=lstm_n_hid, training=False)
    nn.layers.set_all_param_values(lstm_sample['output'], params_noinit)

    # extract the layers representing the hidden and cell states, and have sample_fn
    # return their outputs
    state_layers_keys = [k for k in lstm_sample.keys() if 'hidfinal' in k or 'cellfinal' in k]
    state_layers_keys = sorted(state_layers_keys)
    state_layers_keys = sorted(state_layers_keys, key = lambda x:int(x.split('_')[1]))
    state_layers = [lstm_sample[s] for s in state_layers_keys]
    if kl_loss:
        sample_fn = theano.function([z_var, z_ls_var] + state_vars,
                nn.layers.get_output([lstm['output_mu'], lstm['output_ls']] + state_layers,
                    deterministic=True))
    else:
        sample_fn = theano.function([z_var] + state_vars,
                nn.layers.get_output([lstm['output']] + state_layers, deterministic=True))
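    # sample_fn takes a z sequence of shape (batch, time, z_dim) (plus the matching
    # log-sigma sequence when kl_loss is set) and one array per recurrent state; it
    # returns the predicted sequence(s) followed by each layer's final hidden and
    # cell state, so consecutive calls can resume where the previous one stopped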

    from images2gif import writeGif
    from PIL import Image

    # sample approximately 30 different generated video sequences
    te_stream = data.streamer(training=True, shuffled=False)
    interval = max(1, data.ntrain // data.batch_size // 30)  # guard against short datasets
    for idx,imb in enumerate(te_stream.get_epoch_iterator()):
        if idx % interval != 0:
            continue
        z_tup = transform_data(imb[0], center=True)
        seg_idx = np.random.randint(z_tup[0].shape[0])
        if kl_loss:
            z_in, z_ls_in = z_tup[0], z_tup[1]
            z_last, z_ls_last = z_in[seg_idx:seg_idx+1], z_ls_in[seg_idx:seg_idx+1]
            z_vars = [z_last, z_ls_last]
        else:
            z_in = z_tup[0]
            z_last = z_in[seg_idx:seg_idx+1]
            z_vars = [z_last]
        images = []
        state_values = [np.dot(np.ones((z_last.shape[0],1), dtype=theano.config.floatX), s)
                for s in init_values]
        output_list = sample_fn(*(z_vars + state_values))

        # use whole sequence of predictions for output
        z_pred = output_list[0]
        state_values = output_list[2 if kl_loss else 1:]
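        # output_list is [z_mu, z_log_sigma, state_1, state_2, ...] when kl_loss
        # is set and [z, state_1, state_2, ...] otherwise, hence the 2-vs-1 offset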

        rec = img_from_z(z_pred.reshape(-1, z_dim))
        for k in range(rec.shape[0]):
            images.append(Image.fromarray(u.get_picture_array(rec, index=k, shift=pixel_shift)))
        # slice prediction to feed into lstm
        z_pred = z_pred[:,-1:,:]
        if kl_loss:
            z_ls_pred = output_list[1][:,-1:,:]
            z_vars = [z_pred, z_ls_pred]
        else:
            z_vars = [z_pred]
        for i in range(30):  # predict 30 frames after the end of the priming video
            output_list = sample_fn(*(z_vars + state_values))
            z_pred = output_list[0]
            state_values = output_list[2 if kl_loss else 1:]
            rec = img_from_z(z_pred.reshape(-1, z_dim))
            images.append(Image.fromarray(u.get_picture_array(rec, index=0, shift=pixel_shift)))
            if kl_loss:
                z_ls_pred = output_list[1]
                z_vars = [z_pred, z_ls_pred]
            else:
                z_vars = [z_pred]
        writeGif("sample_{}.gif".format(idx),images,duration=0.1,dither=0)
Example no. 2

# (relies on the same module-level imports as Example no. 1)
def main(data_file='',
         num_epochs=10,
         batch_size=128,
         L=2,
         z_dim=256,
         n_hid=1500,
         binary='false',
         img_size=64,
         init_from='',
         save_to='params',
         split_layer='conv7',
         pxsh=0.5,
         specstr=c.pf_cae_specstr,
         cae_weights=c.pf_cae_params,
         deconv_weights=c.pf_deconv_params):
    binary = binary.lower() == 'true'

    # pre-trained function for extracting convolutional features from images
    cae = m.build_cae(input_var=None,
                      specstr=specstr,
                      shape=(img_size, img_size))
    laydict = dict((l.name, l) for l in nn.layers.get_all_layers(cae))
    convshape = nn.layers.get_output_shape(laydict[split_layer])
    convs_from_img, _ = m.encoder_decoder(cae_weights,
                                          specstr=specstr,
                                          layersplit=split_layer,
                                          shape=(img_size, img_size))
    # pre-trained function for returning to images from convolutional features
    img_from_convs = m.deconvoluter(deconv_weights,
                                    specstr=specstr,
                                    shape=convshape)
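    # together these form a round trip: convs_from_img maps (batch, channels,
    # img_size, img_size) images to features of shape convshape, and
    # img_from_convs maps such features back to images; the VAE below is
    # trained entirely in this feature space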

    # Create VAE model
    print("Building model and compiling functions...")
    print("L = {}, z_dim = {}, n_hid = {}, binary={}".format(
        L, z_dim, n_hid, binary))
    input_var = T.tensor4('inputs')
    n_chan, w, h = convshape[1], convshape[2], convshape[3]  # n_chan avoids shadowing the config module c
    l_tup = l_z_mu, l_z_ls, l_x_mu_list, l_x_ls_list, l_x_list, l_x = \
            m.build_vae(input_var, L=L, binary=binary, z_dim=z_dim, n_hid=n_hid,
                   shape=(w, h), channels=n_chan)

    if len(init_from) > 0:
        print("loading from {}".format(init_from))
        u.load_params(l_x, init_from)

    # build loss, updates, training, prediction functions
    loss, _ = u.build_vae_loss(input_var,
                               *l_tup,
                               deterministic=False,
                               binary=binary,
                               L=L)
    test_loss, test_prediction = u.build_vae_loss(input_var,
                                                  *l_tup,
                                                  deterministic=True,
                                                  binary=binary,
                                                  L=L)
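    # u.build_vae_loss is project code (not shown); from its arguments it is
    # presumably the usual negative ELBO: a reconstruction term (Bernoulli-style
    # for binary=True, Gaussian otherwise, averaged over L posterior samples)
    # plus the KL of q(z|x) from the unit Gaussian prior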

    lr = theano.shared(nn.utils.floatX(1e-5))
    params = nn.layers.get_all_params(l_x, trainable=True)
    updates = nn.updates.adam(loss, params, learning_rate=lr)
    train_fn = theano.function([input_var], loss, updates=updates)
    val_fn = theano.function([input_var], test_loss)
    ae_fn = theano.function([input_var], test_prediction)

    # run training loop
    def data_transform(x, do_center):
        floatx_ims = u.raw_to_floatX(x,
                                     pixel_shift=pxsh,
                                     square=True,
                                     center=do_center)
        return convs_from_img(floatx_ims)
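    # the full pre-processing path: raw uint8 frames -> floatX images (shifted by
    # pxsh; square=True plus the center flag control the cropping performed inside
    # u.raw_to_floatX) -> convolutional features, which are what the VAE trains on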

    print("training for {} epochs".format(num_epochs))
    data = u.DataH5PyStreamer(data_file, batch_size=batch_size)
    hist = u.train_with_hdf5(
        data,
        num_epochs=num_epochs,
        train_fn=train_fn,
        test_fn=val_fn,
        tr_transform=lambda x: data_transform(x[0], do_center=False),
        te_transform=lambda x: data_transform(x[0], do_center=True))

    # generate examples, save training history
    te_stream = data.streamer(shuffled=True)
    imb, = next(te_stream.get_epoch_iterator())
    orig_feats = data_transform(imb, do_center=True)
    reconstructed_feats = ae_fn(orig_feats).reshape(orig_feats.shape)
    orig_feats_deconv = img_from_convs(orig_feats)
    reconstructed_feats_deconv = img_from_convs(reconstructed_feats)
    for i in range(reconstructed_feats_deconv.shape[0]):
        u.get_image_pair(orig_feats_deconv, reconstructed_feats_deconv, index=i, shift=pxsh)\
                .save('output_{}.jpg'.format(i))
    hist = np.asarray(hist)
    np.savetxt('vae_convs_train_hist.csv', hist, delimiter=',', fmt='%.5f')
    u.save_params(
        l_x, os.path.join(save_to, 'vae_convs_{}.npz'.format(hist[-1, -1])))