Example #1
0
    def cost(self, readouts, outputs):
        """Return the per-step cost for spectral, f0 and voicing targets.

        ``outputs`` is split along its last axis: every column except the
        last two is treated as spectral features (``sp``) and scored by
        ``self.gmm_emitter.cost``; the second-to-last column is ``f0`` and
        the last column is the voicing indicator ``voiced``.

        The result is ``Gaussian(f0, mu, sigma) * voiced`` (so the f0 term
        is weighted by ``voiced``) minus the Bernoulli log-likelihood of
        ``voiced`` under ``binary``, plus the GMM cost of ``sp``. ``mu``,
        ``sigma`` and ``binary`` come from ``self.components(readouts)``.
        """
        mu, sigma, binary = self.components(readouts)
        outputs_shape = outputs.shape
        outputs_ndim = outputs.ndim

        # Collapse all leading axes so feature columns can be sliced as rows.
        # outputs_shape[-1] is (number of spectral features + 2).
        outputs = outputs.reshape((-1, outputs_shape[-1]))
        outputs = outputs.T

        # Spectral features: all but the last two columns, restored to the
        # original leading shape. set_subtensor returns a copy of the shape
        # vector whose last entry is -1, so reshape infers the reduced width.
        sp = outputs[:-2].T
        sp = sp.reshape(tensor.set_subtensor(outputs_shape[-1], -1),
                        ndim=outputs_ndim)

        f0 = outputs[-2].reshape(outputs_shape[:-1], ndim=outputs_ndim - 1)
        voiced = outputs[-1].reshape(outputs_shape[:-1],
                                     ndim=outputs_ndim - 1)

        # Append a broadcastable trailing axis to f0. list(...) keeps this
        # working on Python 3, where range() cannot be concatenated with a
        # list.
        f0 = f0.dimshuffle(*(list(range(f0.ndim)) + ['x']))

        # Reduce binary's rank by one (merging its last two axes);
        # presumably its last axis has size 1 so it lines up with voiced.
        binary = binary.flatten(ndim=binary.ndim - 1)

        # Bernoulli log-likelihood of voiced given binary; xlogy0 yields 0
        # where its first argument is 0 instead of 0 * log(0).
        c_b = tensor.xlogx.xlogy0(voiced, binary) + tensor.xlogx.xlogy0(
            1 - voiced, 1 - binary)

        cost_gmm = self.gmm_emitter.cost(readouts, sp)
        return Gaussian(f0, mu, sigma) * voiced - c_b + cost_gmm
Example #2
0
 def cost(self, X):
     """Reduce the element-wise ``Gaussian`` cost of a three-item input.

     ``X`` must hold exactly three items, forwarded positionally to
     ``Gaussian`` (presumably target, mean and scale, mirroring the other
     ``Gaussian(x, mu, sigma)`` calls). The element-wise result is summed
     when ``self.use_sum`` is truthy and averaged otherwise.

     Raises:
         ValueError: if ``X`` does not contain exactly three items.
     """
     if len(X) == 3:
         elementwise = Gaussian(X[0], X[1], X[2])
         reduce_fn = elementwise.sum if self.use_sum else elementwise.mean
         return reduce_fn()
     raise ValueError("The number of inputs does not match.")
Example #3
0
    def cost(self, readouts, outputs):
        """Return the per-step cost for f0 and voicing targets.

        ``outputs`` carries two features on its last axis: index 0 is
        ``f0`` and index 1 is the voicing indicator ``voiced``. The result
        is ``Gaussian(f0, mu, sigma) * voiced`` (the f0 term weighted by
        ``voiced``) minus the Bernoulli log-likelihood of ``voiced`` under
        ``binary``. ``mu``, ``sigma`` and ``binary`` come from
        ``self.components(readouts)``.
        """
        mu, sigma, binary = self.components(readouts)
        outputs_shape = outputs.shape
        outputs_ndim = outputs.ndim

        # Collapse all leading axes; outputs_shape[-1] is expected to be 2.
        outputs = outputs.reshape((-1, outputs_shape[-1]))
        outputs = outputs.T
        f0 = outputs[0].reshape(outputs_shape[:-1], ndim=outputs_ndim - 1)
        voiced = outputs[1].reshape(outputs_shape[:-1],
                                    ndim=outputs_ndim - 1)

        # Append a broadcastable trailing axis to f0. list(...) keeps this
        # working on Python 3, where range() cannot be concatenated with a
        # list.
        f0 = f0.dimshuffle(*(list(range(f0.ndim)) + ['x']))

        # Reduce binary's rank by one (merging its last two axes);
        # presumably its last axis has size 1 so it lines up with voiced.
        binary = binary.flatten(ndim=binary.ndim - 1)

        # Bernoulli log-likelihood of voiced given binary; xlogy0 yields 0
        # where its first argument is 0 instead of 0 * log(0).
        c_b = tensor.xlogx.xlogy0(voiced, binary) + tensor.xlogx.xlogy0(
            1 - voiced, 1 - binary)

        return Gaussian(f0, mu, sigma) * voiced - c_b
Example #4
0
def main(args):
    """Train an LSTM with a Gaussian output model on Blizzard frames.

    A 4-layer ReLU input MLP (x_1..x_4) feeds an LSTM, followed by a 4-layer
    ReLU output MLP (theta_1..theta_4) with a linear mean head (theta_mu)
    and a softplus scale head (theta_sig). The training cost is the mean of
    ``Gaussian(x, theta_mu, theta_sig)`` (named 'nll'), optimised with Adam.

    The LSTM state at the end of each training batch is stored in the
    shared variable ``rnn_tm1`` and reused as the next batch's initial
    state, except that it is replaced by ``rnn.get_init_state`` whenever
    ``step_count`` is a multiple of ``reset_freq``. The validation graph
    always starts from the initial state.

    Args:
        args: mapping of option names to values; numeric options are
            converted with int()/float(). Keys read: 'trial', 'data_path',
            'save_path', 'monitoring_freq', 'force_saving_freq',
            'reset_freq', 'epoch', 'batch_size', 'm_batch_size', 'x_dim',
            'z_dim' (read but otherwise unused), 'rnn_dim', 'lr', 'debug'.

    Side effects:
        Prints the run configuration, builds and compiles Theano graphs,
        and runs the Training main loop; save_path is passed to the
        Picklize and EarlyStopping extensions.
    """

    trial = int(args['trial'])
    pkl_name = 'rnn_gauss_%d' % trial
    # Presumably Monitoring prefixes channel names with the dataset name,
    # turning 'nll' on valid_data into 'valid_nll' — confirm.
    channel_name = 'valid_nll'

    data_path = args['data_path']
    save_path = args['save_path']

    monitoring_freq = int(args['monitoring_freq'])
    force_saving_freq = int(args['force_saving_freq'])
    reset_freq = int(args['reset_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    m_batch_size = int(args['m_batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    # Widths of the input (x -> state) and output (state -> x) MLPs.
    x2s_dim = 800
    s2x_dim = 800
    target_dim = x_dim

    # Normalisation statistics. data_path is string-concatenated, so it
    # presumably ends with a path separator.
    file_name = 'blizzard_unseg_tbptt'
    normal_params = np.load(data_path + file_name + '_normal.npz')
    X_mean = normal_params['X_mean']
    X_std = normal_params['X_std']

    model = Model()
    train_data = Blizzard_tbptt(name='train',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    valid_data = Blizzard_tbptt(name='valid',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        # Test values for Theano's test-value debugging. The leading axis
        # (length 15) is time, since scan below iterates over axis 0.
        x.tag.test_value = np.zeros((15, batch_size, x_dim),
                                    dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim),
                                      dtype=theano.config.floatX)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    # Constant bias for the softplus scale head.
    init_b_sig = InitCell('const', mean=0.6)

    # Input MLP: x -> x_1 -> x_2 -> x_3 -> x_4.
    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_2 = FullyConnectedLayer(name='x_2',
                              parent=['x_1'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_3 = FullyConnectedLayer(name='x_3',
                              parent=['x_2'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_4 = FullyConnectedLayer(name='x_4',
                              parent=['x_3'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_4'],
               parent_dim=[x2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    # Output MLP: previous state -> theta_1 .. theta_4 -> (mu, sig).
    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_2 = FullyConnectedLayer(name='theta_2',
                                  parent=['theta_1'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_3 = FullyConnectedLayer(name='theta_3',
                                  parent=['theta_2'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_4 = FullyConnectedLayer(name='theta_4',
                                  parent=['theta_3'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_4'],
                                   parent_dim=[s2x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_4'],
                                    parent_dim=[s2x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # Every node whose parameters are collected into params below; the
    # list order fixes the insertion order of the params dict.
    nodes = [
        rnn, x_1, x_2, x_3, x_4, theta_1, theta_2, theta_3, theta_4, theta_mu,
        theta_sig
    ]

    params = OrderedDict()
    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())
    params = init_tparams(params)

    # Training-state carry-over between batches. rnn_tm1 has width
    # rnn_dim * 2 — presumably hidden and cell state concatenated; confirm
    # against LSTM.get_init_state.
    step_count = sharedX(0, name='step_count')
    last_rnn = np.zeros((batch_size, rnn_dim * 2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    # Start from the initial state whenever step_count % reset_freq == 0,
    # otherwise continue from the state left by the previous batch.
    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
                   rnn.get_init_state(batch_size), rnn_tm1)

    x_1_temp = x_1.fprop([x], params)
    x_2_temp = x_2.fprop([x_1_temp], params)
    x_3_temp = x_3.fprop([x_2_temp], params)
    x_4_temp = x_4.fprop([x_3_temp], params)

    def inner_fn(x_t, s_tm1):
        """One LSTM step from the embedded input x_t and previous state."""

        s_t = rnn.fprop([[x_t], [s_tm1]], params)

        return s_t

    (s_temp, updates) = theano.scan(fn=inner_fn,
                                    sequences=[x_4_temp],
                                    outputs_info=[s_0])

    # Attach any updates scan produced (likely none: inner_fn only runs
    # the LSTM step).
    for k, v in updates.iteritems():
        k.default_update = v

    # Remember the final state for the next training batch.
    shared_updates[rnn_tm1] = s_temp[-1]
    # Shift states by one step so the output for time t is computed from
    # the state before x_t was consumed (s_0 for t = 0).
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    theta_1_temp = theta_1.fprop([s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)

    recon = Gaussian(x, theta_mu_temp, theta_sig_temp)
    recon_term = recon.mean()
    recon_term.name = 'nll'

    # Validation graph: same parameters, but always starts from the
    # initial state and carries nothing over between batches.
    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    m_s_0 = rnn.get_init_state(m_batch_size)

    (m_s_temp, m_updates) = theano.scan(fn=inner_fn,
                                        sequences=[m_x_4_temp],
                                        outputs_info=[m_s_0])

    for k, v in m_updates.iteritems():
        k.default_update = v

    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

    m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
    m_recon_term = m_recon.mean()
    m_recon_term.name = 'nll'

    # Summary statistics reported by the monitor.
    max_x = m_x.max()
    mean_x = m_x.mean()
    min_x = m_x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = m_theta_mu_temp.max()
    mean_theta_mu = m_theta_mu_temp.mean()
    min_theta_mu = m_theta_mu_temp.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = m_theta_sig_temp.max()
    mean_theta_sig = m_theta_sig_temp.mean()
    min_theta_sig = m_theta_sig_temp.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    # Registers the step_count increment and the rnn_tm1 carry-over.
    model.set_updates(shared_updates)

    optimizer = Adam(lr=lr)

    monitor_fn = theano.function(inputs=[m_x],
                                 outputs=[
                                     m_recon_term, max_theta_sig,
                                     mean_theta_sig, min_theta_sig, max_x,
                                     mean_x, min_x, max_theta_mu,
                                     mean_theta_mu, min_theta_mu
                                 ],
                                 on_unused_input='ignore')

    # The Iterator start/end values are hard-coded example indices: a
    # train subset [0, 112640) and validation [2040064, 2152704) for
    # monitoring; training below iterates [0, 2040064).
    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   monitor_fn=monitor_fn,
                   ddout=[
                       m_recon_term, max_theta_sig, mean_theta_sig,
                       min_theta_sig, max_x, mean_x, min_x, max_theta_mu,
                       mean_theta_mu, min_theta_mu
                   ],
                   data=[
                       Iterator(train_data, m_batch_size, start=0, end=112640),
                       Iterator(valid_data,
                                m_batch_size,
                                start=2040064,
                                end=2152704)
                   ]),
        Picklize(freq=monitoring_freq,
                 force_save_freq=force_saving_freq,
                 path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      force_save_freq=force_saving_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data,
                                      batch_size,
                                      start=0,
                                      end=2040064),
                        model=model,
                        optimizer=optimizer,
                        cost=recon_term,
                        outputs=[recon_term],
                        extension=extension)
    mainloop.run()
Example #5
0
def main(args):
    """Train an LSTM with a Gaussian output head on UK-DALE data.

    An input ReLU layer (x_1) feeds an LSTM. At every step the new LSTM
    state goes through a ReLU layer (theta_1) and linear / softplus heads
    (theta_mu / theta_sig), and a sample is drawn with Gaussian_sample.
    When ``flgAgg == -1`` the Gaussian cost is evaluated on ``x`` itself;
    otherwise it is evaluated on ``y`` (the second variable returned by
    ``train_data.theano_vars()``). The training cost ('nll') is the cost
    summed over the time axis and averaged over the batch.

    Args:
        args: mapping of option names to values; numeric options are
            converted with int()/float(). Keys read: 'trial', 'data_path',
            'save_path', 'flgMSE', 'monitoring_freq', 'epoch',
            'batch_size', 'x_dim', 'z_dim', 'y_dim', 'flgAgg', 'rnn_dim',
            'lr', 'debug' ('flgMSE', 'z_dim' and 'y_dim' are read but not
            otherwise used).

    Side effects:
        Prints the run configuration, runs the Training main loop and
        writes per-step monitor values to ``<save_path>/output.csv``.
    """
    trial = int(args['trial'])
    pkl_name = 'rnn_gauss_%d' % trial
    channel_name = 'valid_nll'

    data_path = args['data_path']
    save_path = args['save_path']
    flgMSE = int(args['flgMSE'])

    monitoring_freq = int(args['monitoring_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    y_dim = int(args['y_dim'])
    flgAgg = int(args['flgAgg'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print("trial no. %d" % trial)
    print("batch size %d" % batch_size)
    print("learning rate %f" % lr)
    print("saving pkl file '%s'" % pkl_name)
    print("to the save path '%s'" % save_path)

    # Widths of the input (x -> state) and output (state -> x) layers.
    x2s_dim = 340
    s2x_dim = 340
    # NOTE(review): `k`, `windows`, `appliances`, `period`, `n_steps`,
    # `stride_train` and `stride_test` are not defined in this function;
    # they must be module-level globals elsewhere in the file — confirm.
    target_dim = k  # x_dim - 1

    model = Model()
    train_data = UKdale(name='train',
                        prep='normalize',
                        cond=False,
                        path=data_path,
                        windows=windows,
                        appliances=appliances,
                        numApps=flgAgg,
                        period=period,
                        n_steps=n_steps,
                        stride_train=stride_train,
                        stride_test=stride_test)

    X_mean = train_data.X_mean
    X_std = train_data.X_std

    valid_data = UKdale(name='valid',
                        prep='normalize',
                        cond=False,
                        path=data_path,
                        X_mean=X_mean,
                        X_std=X_std,
                        windows=windows,
                        appliances=appliances,
                        numApps=flgAgg,
                        period=period,
                        n_steps=n_steps,
                        stride_train=stride_train,
                        stride_test=stride_test)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x, y = train_data.theano_vars()

    if debug:
        # The leading axis (length 15) is time: scan iterates over axis 0.
        # The former `mask.tag.test_value` assignment was removed because
        # this script has no `mask` variable (NameError).
        # TODO: add a test value for `y` once its layout is confirmed.
        x.tag.test_value = np.zeros((15, batch_size, x_dim), dtype=np.float32)

    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_1'],
               parent_dim=[x2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_1'],
                                   parent_dim=[s2x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_1'],
                                    parent_dim=[s2x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # corr and binary are constructed but disabled: they are not in
    # `nodes`, so their parameters are never initialised or used.
    corr = FullyConnectedLayer(name='corr',
                               parent=['theta_1'],
                               parent_dim=[s2x_dim],
                               nout=1,
                               unit='tanh',
                               init_W=init_W,
                               init_b=init_b)

    binary = FullyConnectedLayer(name='binary',
                                 parent=['theta_1'],
                                 parent_dim=[s2x_dim],
                                 nout=1,
                                 unit='sigmoid',
                                 init_W=init_W,
                                 init_b=init_b)

    nodes = [rnn, x_1, theta_1, theta_mu, theta_sig]  #, corr, binary

    params = OrderedDict()

    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())

    params = init_tparams(params)

    s_0 = rnn.get_init_state(batch_size)

    x_1_temp = x_1.fprop([x], params)

    def inner_fn(x_t, s_tm1):
        """One step: advance the LSTM, emit Gaussian params and a sample.

        The former `coeff` head was dropped: `coeff` was never defined in
        this function (NameError) and its output was unused.
        """
        s_t = rnn.fprop([[x_t], [s_tm1]], params)
        theta_1_t = theta_1.fprop([s_t], params)
        theta_mu_t = theta_mu.fprop([theta_1_t], params)
        theta_sig_t = theta_sig.fprop([theta_1_t], params)

        pred = Gaussian_sample(theta_mu_t, theta_sig_t)
        return s_t, theta_mu_t, theta_sig_t, pred

    ((s_temp, theta_mu_temp, theta_sig_temp, pred_temp),
     updates) = theano.scan(fn=inner_fn,
                            sequences=[x_1_temp],
                            outputs_info=[s_0, None, None, None])

    # Attach scan's updates (presumably the random-stream updates from
    # Gaussian_sample) so they run whenever the graph is evaluated.
    for upd_var, upd_val in updates.items():
        upd_var.default_update = upd_val

    # Flatten (time, batch) into rows for the Gaussian cost. The former
    # corr_in / binary_in reshapes were removed: corr_temp / binary_temp
    # were only defined inside a disabled string block (NameError).
    x_shape = x.shape
    n_rows = x_shape[0] * x_shape[1]
    theta_mu_in = theta_mu_temp.reshape((n_rows, -1))
    theta_sig_in = theta_sig_temp.reshape((n_rows, -1))

    if flgAgg == -1:
        # NOTE(review): `prediction` was never defined here (NameError).
        # The sampled prediction pred_temp is used instead; this assumes
        # its last axis matches x (target_dim == x_dim) — confirm.
        prediction = pred_temp
        prediction.name = 'x_reconstructed'
        mse = T.mean((prediction - x)**2)  # CHECK RESHAPE with an assertion
        # theano.tensor exposes abs_, not abs.
        mae = T.mean(T.abs_(prediction - x))
        mse.name = 'mse'
        mae.name = 'mae'
        pred_in = x.reshape((n_rows, -1))
    else:
        pred_temp = pred_temp.reshape((pred_temp.shape[0], pred_temp.shape[1]))
        pred_temp.name = 'pred_' + str(flgAgg)
        #y[:,:,flgAgg].reshape((y.shape[0],y.shape[1],1))
        mse = T.mean((pred_temp - y.T)**2)  # CHECK RESHAPE with an assertion
        mae = T.mean(T.abs_(pred_temp - y.T))
        mse.name = 'mse'
        mae.name = 'mae'
        pred_in = y.reshape((n_rows, -1), ndim=2)

    recon = Gaussian(pred_in, theta_mu_in, theta_sig_in)
    recon = recon.reshape((x_shape[0], x_shape[1]))
    #recon = recon * mask
    # Sum over time (axis 0), then average over the batch.
    recon_term = recon.sum(axis=0).mean()
    recon_term.name = 'nll'

    max_x = x.max()
    mean_x = x.mean()
    min_x = x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = theta_mu_in.max()
    mean_theta_mu = theta_mu_in.mean()
    min_theta_mu = theta_mu_in.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = theta_sig_in.max()
    mean_theta_sig = theta_sig_in.mean()
    min_theta_sig = theta_sig_in.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    model.inputs = [x, y]
    model.params = params
    model.nodes = nodes

    optimizer = Adam(lr=lr)

    extension = [
        GradientClipping(batch_size=batch_size),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   ddout=[
                       recon_term, max_theta_sig, mean_theta_sig,
                       min_theta_sig, max_x, mean_x, min_x, max_theta_mu,
                       mean_theta_mu, min_theta_mu
                   ],
                   data=[Iterator(valid_data, batch_size)]),
        Picklize(freq=monitoring_freq, path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data, batch_size),
                        model=model,
                        optimizer=optimizer,
                        cost=recon_term,
                        outputs=[recon_term],
                        extension=extension)
    mainloop.run()

    # NOTE(review): no channel in this script is named 'nll_upper_bound'
    # or 'recon_term', and mse/mae are not in the Monitoring ddout; these
    # keys look copied from the VRNN script — confirm what
    # trainlog.monitor actually contains.
    monitor_log = mainloop.trainlog.monitor
    with open(save_path + '/output.csv', 'w') as fLog:
        fLog.write("log,mse,mae\n")
        for i, _ in enumerate(monitor_log['nll_upper_bound']):
            a = monitor_log['recon_term'][i]
            d = monitor_log['mse'][i]
            e = monitor_log['mae'][i]
            fLog.write("{},{},{}\n".format(a, d, e))
Example #6
0
def main(args):
    """Train a variational RNN (VRNN) with a Gaussian output on UK-DALE data.

    Per time step, x_t is embedded by a ReLU layer (x_1). An encoder
    (phi_1 -> phi_mu / phi_sig), conditioned on x_1 and the previous LSTM
    state, and a prior network (prior_1 -> prior_mu / prior_sig),
    conditioned on the previous state, each produce Gaussian parameters for
    z_t. z_t is sampled from the encoder's parameters, embedded by z_1, and
    fed with x_1 into the LSTM. A decoder (theta_1 -> theta_mu / theta_sig),
    fed with z_1 and the previous state, gives the Gaussian cost of x.

    The training objective 'nll_upper_bound' is the reconstruction term plus
    the KLGaussianGaussian term, each summed over time and averaged over the
    batch.

    Args:
        args: mapping of option names to values; numeric options are
            converted with int()/float(). Keys read: 'trial', 'data_path',
            'save_path', 'period', 'n_steps', 'stride_train',
            'stride_test', 'monitoring_freq', 'epoch', 'batch_size',
            'x_dim', 'z_dim', 'rnn_dim', 'lr', 'debug'.

    Side effects:
        Sets Theano config options, prints the run configuration and runs
        the Training main loop; save_path is passed to the Picklize and
        EarlyStopping extensions.
    """
    # Debug-oriented Theano settings. The original assigned
    # `theano.optimizer`, which is not a config option and had no effect;
    # the graph optimizer is selected through theano.config.
    theano.config.optimizer = 'fast_compile'
    theano.config.exception_verbosity = 'high'
    trial = int(args['trial'])
    pkl_name = 'vrnn_gauss_%d' % trial
    channel_name = 'valid_nll_upper_bound'

    data_path = args['data_path']
    save_path = args['save_path']
    period = int(args['period'])
    n_steps = int(args['n_steps'])
    stride_train = int(args['stride_train'])
    stride_test = int(args['stride_test'])

    monitoring_freq = int(args['monitoring_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print("trial no. %d" % trial)
    print("batch size %d" % batch_size)
    print("learning rate %f" % lr)
    print("saving pkl file '%s'" % pkl_name)
    print("to the save path '%s'" % save_path)

    # Hidden widths of the encoder (q_z), prior (p_z) and decoder (p_x)
    # networks and of the x / z embeddings.
    q_z_dim = 150
    p_z_dim = 150
    p_x_dim = 250
    x2s_dim = 10  # 250
    z2s_dim = 10  # 150
    target_dim = x_dim  # (x_dim-1)

    model = Model()
    train_data = UKdale(name='train',
                        prep='none',  # normalize
                        cond=False,
                        path=data_path,
                        period=period,
                        n_steps=n_steps,
                        x_dim=x_dim,
                        stride_train=stride_train,
                        stride_test=stride_test)

    X_mean = train_data.X_mean
    X_std = train_data.X_std

    # Use the same framing options as the training set. The original
    # omitted them here, so validation presumably fell back to UKdale's
    # defaults and could be framed differently from training.
    valid_data = UKdale(name='valid',
                        prep='none',  # normalize
                        cond=False,
                        path=data_path,
                        X_mean=X_mean,
                        X_std=X_std,
                        period=period,
                        n_steps=n_steps,
                        x_dim=x_dim,
                        stride_train=stride_train,
                        stride_test=stride_test)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x, mask = train_data.theano_vars()

    if debug:
        # The leading axis (length 15) is time: scan iterates over axis 0.
        x.tag.test_value = np.zeros((15, batch_size, x_dim), dtype=np.float32)
        temp = np.ones((15, batch_size), dtype=np.float32)
        temp[:, -2:] = 0.
        mask.tag.test_value = temp

    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_1 = FullyConnectedLayer(name='z_1',
                              parent=['z_t'],
                              parent_dim=[z_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_1', 'z_1'],
               parent_dim=[x2s_dim, z2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    # Encoder: (x_1, s_tm1) -> phi_1 -> (phi_mu, phi_sig).
    phi_1 = FullyConnectedLayer(name='phi_1',
                                parent=['x_1', 's_tm1'],
                                parent_dim=[x2s_dim, rnn_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_mu = FullyConnectedLayer(name='phi_mu',
                                 parent=['phi_1'],
                                 parent_dim=[q_z_dim],
                                 nout=z_dim,
                                 unit='linear',
                                 init_W=init_W,
                                 init_b=init_b)

    phi_sig = FullyConnectedLayer(name='phi_sig',
                                  parent=['phi_1'],
                                  parent_dim=[q_z_dim],
                                  nout=z_dim,
                                  unit='softplus',
                                  cons=1e-4,
                                  init_W=init_W,
                                  init_b=init_b_sig)

    # Prior: s_tm1 -> prior_1 -> (prior_mu, prior_sig).
    prior_1 = FullyConnectedLayer(name='prior_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_mu = FullyConnectedLayer(name='prior_mu',
                                   parent=['prior_1'],
                                   parent_dim=[p_z_dim],
                                   nout=z_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    prior_sig = FullyConnectedLayer(name='prior_sig',
                                    parent=['prior_1'],
                                    parent_dim=[p_z_dim],
                                    nout=z_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # Decoder: (z_1, s_tm1) -> theta_1 -> (theta_mu, theta_sig).
    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['z_1', 's_tm1'],
                                  parent_dim=[z2s_dim, rnn_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_1'],
                                   parent_dim=[p_x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_1'],
                                    parent_dim=[p_x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # corr (rho) and binary are constructed but disabled: they are not in
    # `nodes`, so their parameters are never initialised or used.
    corr = FullyConnectedLayer(name='corr',
                               parent=['theta_1'],
                               parent_dim=[p_x_dim],
                               nout=1,
                               unit='tanh',
                               init_W=init_W,
                               init_b=init_b)

    binary = FullyConnectedLayer(name='binary',
                                 parent=['theta_1'],
                                 parent_dim=[p_x_dim],
                                 nout=1,
                                 unit='sigmoid',
                                 init_W=init_W,
                                 init_b=init_b)

    nodes = [rnn,
             x_1, z_1,
             phi_1, phi_mu, phi_sig,
             prior_1, prior_mu, prior_sig,
             theta_1, theta_mu, theta_sig]  # , corr, binary

    params = OrderedDict()

    # Collect each node's initial parameters (W sized from parent_dim).
    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())

    params = init_tparams(params)

    s_0 = rnn.get_init_state(batch_size)

    x_1_temp = x_1.fprop([x], params)

    def inner_fn(x_t, s_tm1):
        """One VRNN step; x_t is the x_1 embedding of the input at step t."""
        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_mu_t = phi_mu.fprop([phi_1_t], params)
        phi_sig_t = phi_sig.fprop([phi_1_t], params)

        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_mu_t = prior_mu.fprop([prior_1_t], params)
        prior_sig_t = prior_sig.fprop([prior_1_t], params)

        # Sample z_t from the encoder's parameters, then embed it.
        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)
        z_1_t = z_1.fprop([z_t], params)

        s_t = rnn.fprop([[x_t, z_1_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_1_t

    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp,
      z_1_temp), updates) = theano.scan(
          fn=inner_fn,
          sequences=[x_1_temp],
          outputs_info=[s_0, None, None, None, None, None])

    # Attach scan's updates (presumably the random-stream updates from
    # Gaussian_sample) so they run whenever the graph is evaluated.
    for upd_var, upd_val in updates.items():
        upd_var.default_update = upd_val

    # Shift states by one step so the decoder for time t sees the state
    # before x_t was consumed (s_0 for t = 0).
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    s_temp.name = 'h_1'
    z_1_temp.name = 'z_1'
    theta_1_temp = theta_1.fprop([z_1_temp, s_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_1_temp], params)
    theta_mu_temp.name = 'theta_mu'
    theta_sig_temp = theta_sig.fprop([theta_1_temp], params)
    theta_sig_temp.name = 'theta_sig'

    kl_temp = KLGaussianGaussian(phi_mu_temp, phi_sig_temp,
                                 prior_mu_temp, prior_sig_temp)

    # Flatten (time, batch) into rows for the Gaussian cost.
    x_shape = x.shape
    x_in = x.reshape((x_shape[0]*x_shape[1], -1))
    theta_mu_in = theta_mu_temp.reshape((x_shape[0]*x_shape[1], -1))
    theta_sig_in = theta_sig_temp.reshape((x_shape[0]*x_shape[1], -1))

    # Reconstruction term (BiGauss with corr/binary is disabled).
    recon = Gaussian(x_in, theta_mu_in, theta_sig_in)
    recon = recon.reshape((x_shape[0], x_shape[1]))
    #recon = recon * mask
    # Sum over time (axis 0), then average over the batch.
    recon_term = recon.sum(axis=0).mean()
    recon_term.name = 'recon_term'

    #kl_temp = kl_temp * mask
    kl_term = kl_temp.sum(axis=0).mean()
    kl_term.name = 'kl_term'

    nll_upper_bound = recon_term + kl_term
    nll_upper_bound.name = 'nll_upper_bound'

    max_x = x.max()
    mean_x = x.mean()
    min_x = x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = theta_mu_in.max()
    mean_theta_mu = theta_mu_in.mean()
    min_theta_mu = theta_mu_in.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = theta_sig_in.max()
    mean_theta_sig = theta_sig_in.mean()
    min_theta_sig = theta_sig_in.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    max_phi_sig = phi_sig_temp.max()
    mean_phi_sig = phi_sig_temp.mean()
    min_phi_sig = phi_sig_temp.min()
    max_phi_sig.name = 'max_phi_sig'
    mean_phi_sig.name = 'mean_phi_sig'
    min_phi_sig.name = 'min_phi_sig'

    max_prior_sig = prior_sig_temp.max()
    mean_prior_sig = prior_sig_temp.mean()
    min_prior_sig = prior_sig_temp.min()
    max_prior_sig.name = 'max_prior_sig'
    mean_prior_sig.name = 'mean_prior_sig'
    min_prior_sig.name = 'min_prior_sig'

    model.inputs = [x, mask]
    model.params = params
    model.nodes = nodes

    optimizer = Adam(
        lr=lr
    )

    extension = [
        GradientClipping(batch_size=batch_size),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   ddout=[nll_upper_bound, recon_term, kl_term,
                          max_phi_sig, mean_phi_sig, min_phi_sig,
                          max_prior_sig, mean_prior_sig, min_prior_sig,
                          max_theta_sig, mean_theta_sig, min_theta_sig,
                          max_x, mean_x, min_x,
                          max_theta_mu, mean_theta_mu, min_theta_mu,  # 0-16
                          # Full tensors, added to explore distributions.
                          theta_mu_temp, theta_sig_temp,  # 17-20
                          s_temp, z_1_temp
                          ],
                   data=[Iterator(valid_data, batch_size)]),
        Picklize(freq=monitoring_freq, path=save_path),
        EarlyStopping(freq=monitoring_freq, path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(
        name=pkl_name,
        data=Iterator(train_data, batch_size),
        model=model,
        optimizer=optimizer,
        cost=nll_upper_bound,
        outputs=[nll_upper_bound],
        extension=extension
    )
    mainloop.run()
Example #7
0
 def cost(self, readouts, outputs):
     """Score ``outputs`` with ``Gaussian`` under the predicted parameters.

     ``self.components(readouts)`` must yield exactly two values (the
     ``mu`` and ``sigma`` arguments passed to ``Gaussian``); the
     unevaluated ``Gaussian`` result is returned as-is.
     """
     location, scale = self.components(readouts)
     return Gaussian(outputs, location, scale)
def main(args):
    """Build and train a Gaussian-output VRNN on the Blizzard TBPTT data.

    Every hyper-parameter is read from ``args``. Each value is passed through
    int()/float(), so string values are accepted. The function builds:
    the feature extractors (x_*, z_*), the LSTM, the approximate posterior
    (phi_*), the prior (prior_*) and the decoder (theta_*). It then compiles
    a monitoring function and runs the training main loop.

    Args:
        args: mapping with the keys 'trial', 'data_path', 'save_path',
            'monitoring_freq', 'force_saving_freq', 'reset_freq', 'epoch',
            'batch_size', 'm_batch_size', 'x_dim', 'z_dim', 'rnn_dim',
            'lr' and 'debug'.

    Returns:
        None. Saving and early stopping are delegated to the Picklize and
        EarlyStopping extensions, both of which receive ``save_path``.
    """

    trial = int(args['trial'])
    pkl_name = 'vrnn_gauss_%d' % trial
    # Monitoring channel watched by the EarlyStopping extension below.
    channel_name = 'valid_nll_upper_bound'

    data_path = args['data_path']
    save_path = args['save_path']

    monitoring_freq = int(args['monitoring_freq'])
    force_saving_freq = int(args['force_saving_freq'])
    # Number of step_count increments (presumably training updates) between
    # resets of the carried-over LSTM state.
    reset_freq = int(args['reset_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    # Batch size of the monitoring graph (m_x) and its iterators.
    m_batch_size = int(args['m_batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    # Hidden sizes: posterior MLP (q_z_dim), prior MLP (p_z_dim), decoder MLP
    # (p_x_dim), and the x / z feature extractors feeding the LSTM
    # (x2s_dim / z2s_dim).
    q_z_dim = 500
    p_z_dim = 500
    p_x_dim = 600
    x2s_dim = 600
    z2s_dim = 500
    # The decoder models the full input frame.
    target_dim = x_dim

    # Normalization statistics shared by both splits. data_path is
    # concatenated directly, so it presumably has to end with a path
    # separator.
    file_name = 'blizzard_unseg_tbptt'
    normal_params = np.load(data_path + file_name + '_normal.npz')
    X_mean = normal_params['X_mean']
    X_std = normal_params['X_std']

    model = Model()
    train_data = Blizzard_tbptt(name='train',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    valid_data = Blizzard_tbptt(name='valid',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    # Symbolic inputs: x feeds the training graph, m_x the monitoring graph.
    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        # Dummy test values (15 time steps) for Theano's test-value
        # debugging; presumably (time, batch, x_dim) ordering — TODO confirm.
        x.tag.test_value = np.zeros((15, batch_size, x_dim),
                                    dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim),
                                      dtype=theano.config.floatX)

    # Weight / bias initializers. The *_sig layers use init_b_sig, a
    # constant 0.6 bias, before their softplus.
    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    # Feature extractor for the input frame x_t: four ReLU layers.
    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_2 = FullyConnectedLayer(name='x_2',
                              parent=['x_1'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_3 = FullyConnectedLayer(name='x_3',
                              parent=['x_2'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_4 = FullyConnectedLayer(name='x_4',
                              parent=['x_3'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    # Feature extractor for the latent sample z_t: four ReLU layers.
    z_1 = FullyConnectedLayer(name='z_1',
                              parent=['z_t'],
                              parent_dim=[z_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_2 = FullyConnectedLayer(name='z_2',
                              parent=['z_1'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_3 = FullyConnectedLayer(name='z_3',
                              parent=['z_2'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_4 = FullyConnectedLayer(name='z_4',
                              parent=['z_3'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    # Recurrent core, driven by the x features and the z features.
    rnn = LSTM(name='rnn',
               parent=['x_4', 'z_4'],
               parent_dim=[x2s_dim, z2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    # Approximate posterior over z_t, conditioned on the x features and the
    # previous state s_tm1. It has four ReLU layers, then a linear mean and
    # a softplus standard deviation. cons=1e-4 presumably adds a small floor
    # to the softplus output (confirm in FullyConnectedLayer).
    phi_1 = FullyConnectedLayer(name='phi_1',
                                parent=['x_4', 's_tm1'],
                                parent_dim=[x2s_dim, rnn_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_2 = FullyConnectedLayer(name='phi_2',
                                parent=['phi_1'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_3 = FullyConnectedLayer(name='phi_3',
                                parent=['phi_2'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_4 = FullyConnectedLayer(name='phi_4',
                                parent=['phi_3'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_mu = FullyConnectedLayer(name='phi_mu',
                                 parent=['phi_4'],
                                 parent_dim=[q_z_dim],
                                 nout=z_dim,
                                 unit='linear',
                                 init_W=init_W,
                                 init_b=init_b)

    phi_sig = FullyConnectedLayer(name='phi_sig',
                                  parent=['phi_4'],
                                  parent_dim=[q_z_dim],
                                  nout=z_dim,
                                  unit='softplus',
                                  cons=1e-4,
                                  init_W=init_W,
                                  init_b=init_b_sig)

    # Prior over z_t, conditioned only on the previous state s_tm1. It has
    # the same structure as the posterior.
    prior_1 = FullyConnectedLayer(name='prior_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_2 = FullyConnectedLayer(name='prior_2',
                                  parent=['prior_1'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_3 = FullyConnectedLayer(name='prior_3',
                                  parent=['prior_2'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_4 = FullyConnectedLayer(name='prior_4',
                                  parent=['prior_3'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_mu = FullyConnectedLayer(name='prior_mu',
                                   parent=['prior_4'],
                                   parent_dim=[p_z_dim],
                                   nout=z_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    prior_sig = FullyConnectedLayer(name='prior_sig',
                                    parent=['prior_4'],
                                    parent_dim=[p_z_dim],
                                    nout=z_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # Decoder for x_t, conditioned on the z features and the previous state.
    # It has four ReLU layers, then a linear mean and a softplus standard
    # deviation.
    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['z_4', 's_tm1'],
                                  parent_dim=[z2s_dim, rnn_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_2 = FullyConnectedLayer(name='theta_2',
                                  parent=['theta_1'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_3 = FullyConnectedLayer(name='theta_3',
                                  parent=['theta_2'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_4 = FullyConnectedLayer(name='theta_4',
                                  parent=['theta_3'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_4'],
                                   parent_dim=[p_x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_4'],
                                    parent_dim=[p_x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    # Every layer whose parameters are collected below.
    nodes = [
        rnn, x_1, x_2, x_3, x_4, z_1, z_2, z_3, z_4, phi_1, phi_2, phi_3,
        phi_4, phi_mu, phi_sig, prior_1, prior_2, prior_3, prior_4, prior_mu,
        prior_sig, theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig
    ]

    params = OrderedDict()

    # NOTE(review): initialize() is called twice per node. The first call
    # is only a None check. If it draws random values, that first draw is
    # discarded and only the second one is stored.
    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())

    params = init_tparams(params)

    # State carried between consecutive training batches (truncated BPTT).
    # step_count is incremented through shared_updates. rnn_tm1 holds the
    # LSTM state left by the previous batch; it has rnn_dim * 2 columns,
    # presumably hidden state plus cell state.
    step_count = sharedX(0, name='step_count')
    last_rnn = np.zeros((batch_size, rnn_dim * 2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    # Initial LSTM state for this batch. Whenever step_count is a multiple
    # of reset_freq, start again from rnn.get_init_state(batch_size).
    # Otherwise continue from the state left by the previous batch (rnn_tm1).
    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
                   rnn.get_init_state(batch_size), rnn_tm1)

    # Forward-propagate the input to get more complex features for
    # every time step.
    x_1_temp = x_1.fprop([x], params)
    x_2_temp = x_2.fprop([x_1_temp], params)
    x_3_temp = x_3.fprop([x_2_temp], params)
    x_4_temp = x_4.fprop([x_3_temp], params)

    def inner_fn(x_t, s_tm1):
        # One scan step. x_t is the x_4 feature vector at time t; s_tm1 is
        # the previous LSTM state.

        # Generate the mean and standard deviation of the latent
        # variables Z_t | X_t for every time step of the LSTM. They are a
        # function of the input and of the hidden state from the previous
        # time step.
        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_2_t = phi_2.fprop([phi_1_t], params)
        phi_3_t = phi_3.fprop([phi_2_t], params)
        phi_4_t = phi_4.fprop([phi_3_t], params)
        phi_mu_t = phi_mu.fprop([phi_4_t], params)
        phi_sig_t = phi_sig.fprop([phi_4_t], params)

        # Prior on the latent variables at every time step.
        # Depends only on the previous hidden state.
        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_2_t = prior_2.fprop([prior_1_t], params)
        prior_3_t = prior_3.fprop([prior_2_t], params)
        prior_4_t = prior_4.fprop([prior_3_t], params)
        prior_mu_t = prior_mu.fprop([prior_4_t], params)
        prior_sig_t = prior_sig.fprop([prior_4_t], params)

        # Sample from the latent distribution with mean phi_mu_t
        # and standard deviation phi_sig_t.
        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)

        # Latent features, which feed both the LSTM and (outside the scan)
        # the decoder.
        z_1_t = z_1.fprop([z_t], params)
        z_2_t = z_2.fprop([z_1_t], params)
        z_3_t = z_3.fprop([z_2_t], params)
        z_4_t = z_4.fprop([z_3_t], params)

        # s_t = f(s_(t-1), x_t, z_t): advance the LSTM.
        s_t = rnn.fprop([[x_t, z_4_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_4_t

    # Iterate over every time step. Only the LSTM state is fed back; the
    # other outputs have None in outputs_info.
    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_4_temp), updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[x_4_temp],
                    outputs_info=[s_0, None, None, None, None, None])

    # Register the scan updates (presumably the random-stream state used by
    # Gaussian_sample) as default updates of their shared variables.
    for k, v in updates.iteritems():
        k.default_update = v

    # Carry the last state of this batch over to the next one.
    shared_updates[rnn_tm1] = s_temp[-1]
    # Shift the states by one step so that position t holds s_(t-1), the
    # state the decoder is conditioned on.
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)

    # Generate the output distribution at every time step.
    # It is a function of the latent variables and of the hidden state at
    # every time step.
    theta_1_temp = theta_1.fprop([z_4_temp, s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)

    # Training cost: Gaussian reconstruction term plus the KL divergence
    # between posterior and prior. Each term is averaged over all of its
    # elements before the two are summed.
    kl_temp = KLGaussianGaussian(phi_mu_temp, phi_sig_temp, prior_mu_temp,
                                 prior_sig_temp)

    recon = Gaussian(x, theta_mu_temp, theta_sig_temp)
    recon_term = recon.mean()
    kl_term = kl_temp.mean()
    nll_upper_bound = recon_term + kl_term
    nll_upper_bound.name = 'nll_upper_bound'

    # Forward-propagation of the validation data.
    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    # The monitoring graph always starts from a fresh initial state; there
    # is no carry-over between monitoring batches.
    m_s_0 = rnn.get_init_state(m_batch_size)

    # Get the hidden states, the posterior mean and standard deviation, and
    # the prior mean and standard deviation of the latent variables at
    # every time step.
    ((m_s_temp, m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp, m_z_4_temp), m_updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[m_x_4_temp],
                    outputs_info=[m_s_0, None, None, None, None, None])

    for k, v in m_updates.iteritems():
        k.default_update = v

    # Get the inferred mean (X_t | Z_t) at every time step of the validation
    # data, using the same one-step state shift as the training graph.
    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_z_4_temp, m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

    # Compute the data log-likelihood term plus the KL divergence on the
    # validation data.
    m_kl_temp = KLGaussianGaussian(m_phi_mu_temp, m_phi_sig_temp,
                                   m_prior_mu_temp, m_prior_sig_temp)

    m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
    m_recon_term = m_recon.mean()
    m_kl_term = m_kl_temp.mean()
    m_nll_upper_bound = m_recon_term + m_kl_term
    # These names carry no split prefix. The Monitoring extension presumably
    # prepends the dataset name, which would yield channel_name
    # ('valid_nll_upper_bound') — TODO confirm.
    m_nll_upper_bound.name = 'nll_upper_bound'
    m_recon_term.name = 'recon_term'
    m_kl_term.name = 'kl_term'

    # Summary statistics of the monitoring graph, reported by Monitoring.
    max_x = m_x.max()
    mean_x = m_x.mean()
    min_x = m_x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = m_theta_mu_temp.max()
    mean_theta_mu = m_theta_mu_temp.mean()
    min_theta_mu = m_theta_mu_temp.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = m_theta_sig_temp.max()
    mean_theta_sig = m_theta_sig_temp.mean()
    min_theta_sig = m_theta_sig_temp.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    max_phi_sig = m_phi_sig_temp.max()
    mean_phi_sig = m_phi_sig_temp.mean()
    min_phi_sig = m_phi_sig_temp.min()
    max_phi_sig.name = 'max_phi_sig'
    mean_phi_sig.name = 'mean_phi_sig'
    min_phi_sig.name = 'min_phi_sig'

    max_prior_sig = m_prior_sig_temp.max()
    mean_prior_sig = m_prior_sig_temp.mean()
    min_prior_sig = m_prior_sig_temp.min()
    max_prior_sig.name = 'max_prior_sig'
    mean_prior_sig.name = 'mean_prior_sig'
    min_prior_sig.name = 'min_prior_sig'

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    # Registers the step_count / rnn_tm1 updates built above. They are
    # presumably applied on every training step.
    model.set_updates(shared_updates)

    optimizer = Adam(lr=lr)

    # Compiled monitoring function of m_x. Its outputs are listed in the
    # same order as the Monitoring ddout list below.
    monitor_fn = theano.function(
        inputs=[m_x],
        outputs=[
            m_nll_upper_bound, m_recon_term, m_kl_term, max_phi_sig,
            mean_phi_sig, min_phi_sig, max_prior_sig, mean_prior_sig,
            min_prior_sig, max_theta_sig, mean_theta_sig, min_theta_sig, max_x,
            mean_x, min_x, max_theta_mu, mean_theta_mu, min_theta_mu
        ],
        on_unused_input='ignore')

    # The example ranges are hard-coded:
    #   - monitoring uses [0, 112640) of train_data and
    #     [2040064, 2152704) of valid_data;
    #   - training uses [0, 2040064) of train_data.
    # The valid range starts exactly where the training range ends.
    # NOTE(review): presumably both splits index the same underlying array
    # — confirm in Blizzard_tbptt.
    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   monitor_fn=monitor_fn,
                   ddout=[
                       m_nll_upper_bound, m_recon_term, m_kl_term, max_phi_sig,
                       mean_phi_sig, min_phi_sig, max_prior_sig,
                       mean_prior_sig, min_prior_sig, max_theta_sig,
                       mean_theta_sig, min_theta_sig, max_x, mean_x, min_x,
                       max_theta_mu, mean_theta_mu, min_theta_mu
                   ],
                   data=[
                       Iterator(train_data, m_batch_size, start=0, end=112640),
                       Iterator(valid_data,
                                m_batch_size,
                                start=2040064,
                                end=2152704)
                   ]),
        Picklize(freq=monitoring_freq,
                 force_save_freq=force_saving_freq,
                 path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      force_save_freq=force_saving_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data,
                                      batch_size,
                                      start=0,
                                      end=2040064),
                        model=model,
                        optimizer=optimizer,
                        cost=nll_upper_bound,
                        outputs=[nll_upper_bound],
                        extension=extension)
    mainloop.run()
def main(args):

    theano.optimizer = 'fast_compile'
    theano.config.exception_verbosity = 'high'
    trial = int(args['trial'])
    pkl_name = 'vrnn_gauss_%d' % trial
    channel_name = 'valid_nll_upper_bound'

    data_path = args['data_path']
    save_path = args['save_path']
    save_path = args['save_path']
    period = int(args['period'])
    n_steps = int(args['n_steps'])
    stride_train = int(args['stride_train'])
    stride_test = int(args['stride_test'])

    monitoring_freq = int(args['monitoring_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    q_z_dim = 150
    p_z_dim = 150
    p_x_dim = 250
    x2s_dim = 10  #250
    z2s_dim = 10  #150
    target_dim = x_dim  #(x_dim-1)

    model = Model()
    Xtrain, ytrain, Xval, yval = fetch_ukdale(data_path,
                                              windows,
                                              appliances,
                                              numApps=flgAgg,
                                              period=period,
                                              n_steps=n_steps,
                                              stride_train=stride_train,
                                              stride_test=stride_test)

    train_data = UKdale(
        name='train',
        prep='normalize',
        cond=True,  # False
        #path=data_path,
        inputX=Xtrain,
        labels=ytrain)

    X_mean = train_data.X_mean
    X_std = train_data.X_std

    valid_data = UKdale(
        name='valid',
        prep='normalize',
        cond=True,  # False
        #path=data_path,
        X_mean=X_mean,
        X_std=X_std,
        inputX=Xval,
        labels=yval)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x, y = train_data.theano_vars()

    if debug:
        x.tag.test_value = np.zeros((15, batch_size, x_dim), dtype=np.float32)
        temp = np.ones((15, batch_size), dtype=np.float32)
        temp[:, -2:] = 0.
        mask.tag.test_value = temp

    x_1 = FullyConnectedLayer(
        name='x_1',
        parent=['x_t'],  #OrderDict parent['x_t'] = x_dim
        parent_dim=[x_dim],
        nout=x2s_dim,
        unit='relu',
        init_W=init_W,
        init_b=init_b)

    z_1 = FullyConnectedLayer(name='z_1',
                              parent=['z_t'],
                              parent_dim=[z_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_1', 'z_1'],
               parent_dim=[x2s_dim, z2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    phi_1 = FullyConnectedLayer(
        name='phi_1',  ## encoder
        parent=['x_1', 's_tm1'],
        parent_dim=[x2s_dim, rnn_dim],
        nout=q_z_dim,
        unit='relu',
        init_W=init_W,
        init_b=init_b)

    phi_mu = FullyConnectedLayer(name='phi_mu',
                                 parent=['phi_1'],
                                 parent_dim=[q_z_dim],
                                 nout=z_dim,
                                 unit='linear',
                                 init_W=init_W,
                                 init_b=init_b)

    phi_sig = FullyConnectedLayer(name='phi_sig',
                                  parent=['phi_1'],
                                  parent_dim=[q_z_dim],
                                  nout=z_dim,
                                  unit='softplus',
                                  cons=1e-4,
                                  init_W=init_W,
                                  init_b=init_b_sig)

    prior_1 = FullyConnectedLayer(name='prior_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_mu = FullyConnectedLayer(name='prior_mu',
                                   parent=['prior_1'],
                                   parent_dim=[p_z_dim],
                                   nout=z_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    prior_sig = FullyConnectedLayer(name='prior_sig',
                                    parent=['prior_1'],
                                    parent_dim=[p_z_dim],
                                    nout=z_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    theta_1 = FullyConnectedLayer(
        name='theta_1',  ### decoder
        parent=['z_1', 's_tm1'],
        parent_dim=[z2s_dim, rnn_dim],
        nout=p_x_dim,
        unit='relu',
        init_W=init_W,
        init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_1'],
                                   parent_dim=[p_x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_1'],
                                    parent_dim=[p_x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    corr = FullyConnectedLayer(
        name='corr',  ## rho
        parent=['theta_1'],
        parent_dim=[p_x_dim],
        nout=1,
        unit='tanh',
        init_W=init_W,
        init_b=init_b)

    binary = FullyConnectedLayer(name='binary',
                                 parent=['theta_1'],
                                 parent_dim=[p_x_dim],
                                 nout=1,
                                 unit='sigmoid',
                                 init_W=init_W,
                                 init_b=init_b)

    nodes = [
        rnn, x_1, z_1, phi_1, phi_mu, phi_sig, prior_1, prior_mu, prior_sig,
        theta_1, theta_mu, theta_sig
    ]  #, corr, binary

    params = OrderedDict()

    for node in nodes:
        if node.initialize() is not None:
            params.update(
                node.initialize()
            )  #Initialize values of the W matrices according to dim of parents

    params = init_tparams(params)

    s_0 = rnn.get_init_state(batch_size)

    x_1_temp = x_1.fprop([x], params)

    def inner_fn(x_t, s_tm1):

        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_mu_t = phi_mu.fprop([phi_1_t], params)
        phi_sig_t = phi_sig.fprop([phi_1_t], params)

        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_mu_t = prior_mu.fprop([prior_1_t], params)
        prior_sig_t = prior_sig.fprop([prior_1_t], params)

        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)
        z_1_t = z_1.fprop([z_t], params)

        theta_1_t = theta_1.fprop([z_1_t, s_tm1], params)
        theta_mu_t = theta_mu.fprop([theta_1_t], params)
        theta_sig_t = theta_sig.fprop([theta_1_t], params)

        pred = Gaussian_sample(theta_mu_t, theta_sig_t)

        s_t = rnn.fprop([[x_t, z_1_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_t, z_1_t, theta_1_t, theta_mu_t, theta_sig_t, pred

    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_temp, z_1_temp, theta_1_temp, theta_mu_temp, theta_sig_temp, pred_temp), updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[x_1_temp], #non_sequences unchanging variables
                    #The tensor(s) to be looped over should be provided to scan using the sequence keyword argument
                    outputs_info=[s_0, None, None, None, None, None, None, None, None, None, None])#Initialization occurs in outputs_info
    #=None This indicates to scan that it does not need to pass the prior result to _fn
    '''
    The general order of function parameters to:
    sequences (if any), prior result(s) (if needed), non-sequences (if any)
    '''
    for k, v in updates.iteritems():
        print("Update")
        k.default_update = v

    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    s_temp.name = 'h_1'  #gisse
    z_temp.name = 'z'
    z_1_temp.name = 'z_1'  #gisse
    #theta_1_temp = theta_1.fprop([z_1_temp, s_temp], params)
    #theta_mu_temp = theta_mu.fprop([theta_1_temp], params)
    theta_mu_temp.name = 'theta_mu'
    #theta_sig_temp = theta_sig.fprop([theta_1_temp], params)
    theta_sig_temp.name = 'theta_sig'
    x_pred_temp.name = 'x_reconstructed'
    #corr_temp = corr.fprop([theta_1_temp], params)
    #corr_temp.name = 'corr'
    #binary_temp = binary.fprop([theta_1_temp], params)
    #binary_temp.name = 'binary'

    if (flgAgg == -1):
        prediction.name = 'x_reconstructed'
        mse = T.mean((prediction - x)**2)  # CHECK RESHAPE with an assertion
        mae = T.mean(T.abs(prediction - x))
        mse.name = 'mse'
        pred_in = x.reshape((x_shape[0] * x_shape[1], -1))
    else:
        prediction.name = 'pred_' + str(flgAgg)
        mse = T.mean((prediction - y[:, :, flgAgg].reshape(
            (y.shape[0], y.shape[1],
             1)))**2)  # CHECK RESHAPE with an assertion
        mae = T.mean(
            T.abs_(prediction -
                   y[:, :, flgAgg].reshape((y.shape[0], y.shape[1], 1))))
        mse.name = 'mse'
        mae.name = 'mae'
        pred_in = y[:, :, flgAgg].reshape((x.shape[0] * x.shape[1], -1),
                                          ndim=2)

    # Reconstruction term: Gaussian cost of the target `pred_in` under the
    # decoder parameters, with everything flattened to
    # x_shape[0] * x_shape[1] rows and then folded back to that 2-D grid.
    theta_mu_in = theta_mu_temp.reshape((x_shape[0] * x_shape[1], -1))
    theta_sig_in = theta_sig_temp.reshape((x_shape[0] * x_shape[1], -1))
    recon = Gaussian(pred_in, theta_mu_in,
                     theta_sig_in).reshape((x_shape[0], x_shape[1]))
    # Sum over axis 0, then average whatever remains.
    # NOTE(review): no mask is applied to recon or KL -- confirm sequences
    # are never padded.
    recon_term = recon.sum(axis=0).mean()
    recon_term.name = 'recon_term'

    # KL divergence between the posterior (phi_*) and prior (prior_*)
    # Gaussians, reduced the same way as the reconstruction term.
    kl_temp = KLGaussianGaussian(phi_mu_temp, phi_sig_temp, prior_mu_temp,
                                 prior_sig_temp)
    kl_term = kl_temp.sum(axis=0).mean()
    kl_term.name = 'kl_term'

    # Training objective: reconstruction cost plus KL.
    nll_upper_bound = recon_term + kl_term
    nll_upper_bound.name = 'nll_upper_bound'

    # Range summaries (max / mean / min) of the input and of each
    # distribution parameter, exposed as named scalar monitoring channels.
    max_x, mean_x, min_x = x.max(), x.mean(), x.min()
    max_x.name, mean_x.name, min_x.name = 'max_x', 'mean_x', 'min_x'

    max_theta_mu, mean_theta_mu, min_theta_mu = (
        theta_mu_in.max(), theta_mu_in.mean(), theta_mu_in.min())
    max_theta_mu.name, mean_theta_mu.name, min_theta_mu.name = (
        'max_theta_mu', 'mean_theta_mu', 'min_theta_mu')

    max_theta_sig, mean_theta_sig, min_theta_sig = (
        theta_sig_in.max(), theta_sig_in.mean(), theta_sig_in.min())
    max_theta_sig.name, mean_theta_sig.name, min_theta_sig.name = (
        'max_theta_sig', 'mean_theta_sig', 'min_theta_sig')

    max_phi_sig, mean_phi_sig, min_phi_sig = (
        phi_sig_temp.max(), phi_sig_temp.mean(), phi_sig_temp.min())
    max_phi_sig.name, mean_phi_sig.name, min_phi_sig.name = (
        'max_phi_sig', 'mean_phi_sig', 'min_phi_sig')

    max_prior_sig, mean_prior_sig, min_prior_sig = (
        prior_sig_temp.max(), prior_sig_temp.mean(), prior_sig_temp.min())
    max_prior_sig.name, mean_prior_sig.name, min_prior_sig.name = (
        'max_prior_sig', 'mean_prior_sig', 'min_prior_sig')

    # Un-reduced sigma tensors. These bind the SAME objects as
    # prior_sig_temp / phi_sig_temp (not copies), so setting .name here
    # renames those variables as well.
    prior_sig_output, phi_sig_output = prior_sig_temp, phi_sig_temp
    prior_sig_output.name = 'prior_sig_o'
    phi_sig_output.name = 'phi_sig_o'

    # Wire the symbolic graph into the model container.
    # NOTE(review): when flgAgg != -1 the cost reads `y`, which is not listed
    # here; confirm `y` is derived from `x`/`mask` earlier in this function,
    # otherwise graph compilation will report a missing input.
    model.inputs = [x, mask]
    model.params = params
    model.nodes = nodes

    optimizer = Adam(lr=lr)

    # Training-loop extensions: gradient clipping, epoch limit, validation
    # monitoring/plotting, checkpointing, early stopping on `channel_name`,
    # and weight normalisation.
    extension = [
        GradientClipping(batch_size=batch_size),
        EpochCount(epoch),
        Monitoring(
            freq=monitoring_freq,
            # ddout is positional: indexSep / indexDDoutPlot below refer to
            # positions in this list, so do not reorder it casually.
            ddout=[
                nll_upper_bound,
                recon_term,
                kl_term,
                mse,
                mae,
                max_phi_sig,
                mean_phi_sig,
                min_phi_sig,
                max_prior_sig,
                mean_prior_sig,
                min_prior_sig,
                max_theta_sig,
                mean_theta_sig,
                min_theta_sig,
                max_x,
                mean_x,
                min_x,
                max_theta_mu,
                mean_theta_mu,
                min_theta_mu,  # indices 0-19: scalar channels
                #binary_temp, corr_temp,
                theta_mu_temp,
                theta_sig_temp,  # indices 20-21
                s_temp,  # index 22 (first entry at indexSep)
                z_temp,
                z_1_temp,
                x_pred_temp
                #phi_sig_output,phi_sig_output
            ],  ## added in order to explore the distributions
            # Presumably the ddout index where the plotted tensors start
            # (s_temp is at position 22) -- confirm against Monitoring.
            indexSep=22,
            # NOTE(review): z_t_temp is not defined in this part of the
            # function (only z_temp is) -- confirm it exists earlier.
            indexDDoutPlot=[(0, theta_mu_temp), (2, z_t_temp),
                            (3, prediction)],
            instancesPlot=[0, 150],  #, 80,150
            savedFolder=save_path,
            data=[Iterator(valid_data, batch_size)]),
        Picklize(freq=monitoring_freq, path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    # Minimise nll_upper_bound on the training iterator and run to completion.
    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data, batch_size),
                        model=model,
                        optimizer=optimizer,
                        cost=nll_upper_bound,
                        outputs=[nll_upper_bound],
                        extension=extension)
    mainloop.run()
    # Export the monitored metrics to <save_path>/output.csv, one row per
    # monitoring step. Note that the "log" column holds recon_term.
    # NOTE(review): fLog is not closed within this chunk -- confirm it is
    # closed further down (or switch to a `with open(...)` block).
    fLog = open(save_path + '/output.csv', 'w')
    fLog.write("log,kl,nll_upper_bound,mse,mae\n")
    # `item` is unused; `c` re-reads the same nll_upper_bound value.
    for i, item in enumerate(mainloop.trainlog.monitor['nll_upper_bound']):
        a = mainloop.trainlog.monitor['recon_term'][i]
        b = mainloop.trainlog.monitor['kl_term'][i]
        c = mainloop.trainlog.monitor['nll_upper_bound'][i]
        d = mainloop.trainlog.monitor['mse'][i]
        e = mainloop.trainlog.monitor['mae'][i]
        fLog.write("{},{},{},{},{}\n".format(a, b, c, d, e))