示例#1
0
    def get_updates(self, grads):
        """
        Build the symbolic Adam update dictionary for ``grads``.

        The bias correction of both moment estimates is folded into one
        per-step learning rate ``lr * sqrt(1 - b2**t) / (1 - b1**t)`` where
        ``t`` is the shared step counter plus one.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable, it exposes
            ``get_value``) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            First-moment, second-moment, parameter and counter shared
            variables mapped to their new symbolic values.
        """
        updates = OrderedDict()
        # One step counter shared by every parameter.
        step = sharedX(0., 'counter')
        step_next = step + 1.
        bias_b1 = self.b1**step_next
        bias_b2 = self.b2**step_next
        step_lr = self.lr * T.sqrt(1. - bias_b2) / (1 - bias_b1)

        for param, grad in grads.items():
            # Optional per-parameter multiplier, keyed by str(param).
            scale = self.lr_scalers.get(str(param), 1.)
            first_mom = sharedX(param.get_value() * 0.)
            second_mom = sharedX(param.get_value() * 0.)
            first_mom_t = self.b1 * first_mom + (1 - self.b1) * grad
            second_mom_t = self.b2 * second_mom + (1 - self.b2) * grad**2
            direction = first_mom_t / (T.sqrt(second_mom_t) + self.eps)
            updates[first_mom] = first_mom_t
            updates[second_mom] = second_mom_t
            updates[param] = param - scale * step_lr * direction

        updates[step] = step_next
        return updates
示例#2
0
文件: opt.py 项目: Beronx86/cle
    def get_updates(self, grads):
        """
        Return symbolic Adam updates for every (parameter, gradient) pair.

        Bias correction is applied through the effective learning rate
        ``lr * sqrt(1 - b2**t) / (1 - b1**t)``, with ``t`` equal to the
        shared counter plus one.

        Parameters
        ----------
        grads : dict
            Parameter shared variable -> symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Shared variable -> new symbolic value (moments, parameters and
            the step counter).
        """
        def _zeros_like(var):
            # New shared accumulator initialised to var's value times zero.
            return sharedX(var.get_value() * 0.)

        updates = OrderedDict()
        counter = sharedX(0., 'counter')
        t = counter + 1.
        effective_lr = (self.lr * T.sqrt(1. - self.b2**t) /
                        (1 - self.b1**t))

        for param, grad in grads.items():
            multiplier = self.lr_scalers.get(str(param), 1.)
            mean_acc = _zeros_like(param)
            sq_acc = _zeros_like(param)
            new_mean = self.b1 * mean_acc + (1 - self.b1) * grad
            new_sq = self.b2 * sq_acc + (1 - self.b2) * grad**2
            step_dir = new_mean / (T.sqrt(new_sq) + self.eps)
            new_param = param - multiplier * effective_lr * step_dir
            updates[mean_acc] = new_mean
            updates[sq_acc] = new_sq
            updates[param] = new_param

        updates[counter] = t
        return updates
示例#3
0
文件: opt.py 项目: heeyoulchoi/cle
    def get_updates(self, grads):
        """
        Build symbolic Adam updates with a decaying first-moment coefficient.

        ``step`` is a shared counter starting at 0. The first-moment decay
        used at each step is ``b1 * lambd**step``. Both moment estimates are
        bias-corrected with the constant ``b1`` / ``b2`` raised to
        ``step + 1``.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Moment, parameter and counter shared variables mapped to their
            new symbolic values.
        """
        updates = OrderedDict()
        step = sharedX(0, 'counter')

        for param, grad in grads.items():
            scale = self.lr_scalers.get(str(param), 1.)
            mean_acc = sharedX(param.get_value() * 0.)
            var_acc = sharedX(param.get_value() * 0.)
            decay = self.b1 * self.lambd**step
            mean_new = decay * mean_acc + (1 - decay) * grad
            var_new = self.b2 * var_acc + (1 - self.b2) * grad**2
            mean_hat = mean_new / (1. - self.b1**(step + 1))
            var_hat = var_new / (1. - self.b2**(step + 1))
            direction = mean_hat / (T.sqrt(var_hat) + self.e)
            updates[mean_acc] = mean_new
            updates[var_acc] = var_new
            updates[param] = param - scale * self.lr * direction

        updates[step] = step + 1
        return updates
示例#4
0
 def __init__(self, rho=0.5, eps=1e-6, **kwargs):
     """
     Store batch-norm hyper-parameters and create the ``mu``/``sigma`` shared variables.

     Parameters
     ----------
     rho : float
         Stored as ``self.rho``; presumably the running-average momentum
         (TODO confirm against the layer's fprop).
     eps : float
         Stored as ``self.eps``; presumably a numerical stabilizer.
     **kwargs
         Forwarded unchanged to the parent constructor, which must set the
         ``self.nout`` and ``self.name`` attributes read below.
     """
     super(BatchNormLayer, self).__init__(**kwargs)
     self.rho, self.eps = rho, eps
     # mu starts at zeros, sigma at ones, both of length nout.
     mu_init = InitCell('zeros').get(self.nout)
     self.mu = sharedX(mu_init, name='mu_' + self.name)
     sigma_init = InitCell('ones').get(self.nout)
     self.sigma = sharedX(sigma_init, name='sigma_' + self.name)
示例#5
0
    def __init__(self, lr, lr_scalers=None):
        """
        Store the learning rate and optional per-parameter multipliers.

        Parameters
        ----------
        lr : float
            Base learning rate, wrapped in a shared variable via ``sharedX``.
        lr_scalers : dict, optional
            Multipliers looked up by ``str(param)`` when building updates.
            A new empty ``OrderedDict`` is used when omitted.
        """
        self.lr = sharedX(lr)
        self.lr_scalers = OrderedDict() if lr_scalers is None else lr_scalers
示例#6
0
文件: opt.py 项目: orangelpai/cle
    def __init__(self, lr, lr_scalers=None):
        """
        Keep the (shared) learning rate and the per-parameter lr multipliers.

        Parameters
        ----------
        lr : float
            Base learning rate; stored as a shared variable via ``sharedX``.
        lr_scalers : dict, optional
            Mapping from ``str(param)`` to a learning-rate multiplier.
            Defaults to a fresh, empty ``OrderedDict``.
        """
        self.lr = sharedX(lr)
        if lr_scalers is None:
            lr_scalers = OrderedDict()
        self.lr_scalers = lr_scalers
示例#7
0
文件: opt.py 项目: orangelpai/cle
    def get_updates(self, grads):
        """
        Build symbolic RMSProp updates (centred variant) with momentum.

        For each parameter, running averages of the gradient and the squared
        gradient are kept with decay ``sec_mom``. The gradient is normalised
        by ``sqrt(E[g^2] - E[g]^2 + e)`` and fed into a momentum velocity.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Running averages, velocity and parameter shared variables mapped
            to their new symbolic values.
        """
        updates = OrderedDict()
        decay = self.sec_mom
        for param, grad in grads.items():
            scale = self.lr_scalers.get(str(param), 1.)
            velocity = sharedX(param.get_value() * 0.)
            mean_g = sharedX(param.get_value() * 0.)
            mean_sq = sharedX(param.get_value() * 0.)
            mean_g_new = decay * mean_g + (1 - decay) * grad
            mean_sq_new = decay * mean_sq + (1 - decay) * grad**2
            normed = grad / T.sqrt(mean_sq_new - mean_g_new**2 + self.e)
            velocity_new = self.mom * velocity - scale * self.lr * normed
            updates.update([
                (mean_g, mean_g_new),
                (mean_sq, mean_sq_new),
                (velocity, velocity_new),
                (param, param + velocity_new),
            ])
        return updates
示例#8
0
    def get_updates(self, grads):
        """
        Return centred-RMSProp-with-momentum updates for ``grads``.

        Keeps decayed (factor ``sec_mom``) averages of ``g`` and ``g**2``,
        divides the gradient by ``sqrt(avg(g**2) - avg(g)**2 + e)`` and
        accumulates the result into a momentum velocity.

        Parameters
        ----------
        grads : dict
            Parameter shared variable -> symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Shared variable -> new symbolic value.
        """
        updates = OrderedDict()
        for param, grad in grads.items():
            lr_mult = self.lr_scalers.get(str(param), 1.)
            vel, g_avg, g2_avg = [sharedX(param.get_value() * 0.)
                                  for _ in range(3)]
            g_avg_t = self.sec_mom * g_avg + (1 - self.sec_mom) * grad
            g2_avg_t = self.sec_mom * g2_avg + (1 - self.sec_mom) * grad**2
            denom = T.sqrt(g2_avg_t - g_avg_t**2 + self.e)
            vel_t = self.mom * vel - lr_mult * self.lr * (grad / denom)
            updates[g_avg] = g_avg_t
            updates[g2_avg] = g2_avg_t
            updates[vel] = vel_t
            updates[param] = param + vel_t
        return updates
示例#9
0
文件: opt.py 项目: lipengyu/cle
    def get_updates(self, grads):
        """
        Build symbolic Adam updates with optional joint post-clipping.

        ``cnt`` is a shared step counter starting at 0. The first-moment
        decay is ``b1 * lambd**cnt``; both moment estimates are
        bias-corrected with the constant ``b1`` / ``b2`` raised to
        ``cnt + 1``.

        When ``self.post_clip`` is truthy, the Adam directions ``g_t`` of all
        parameters are rescaled by ``scaler / max(scaler, ||g_t / batch_size||)``,
        where the norm is the global L2 norm over every parameter. The
        parameter updates are then rebuilt with that factor.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Moment, parameter and counter shared variables mapped to their
            new symbolic values.
        """
        updates = OrderedDict()
        g_tt = OrderedDict()
        cnt = sharedX(0, 'counter')
        for p, g in grads.items():
            lr_scaler = self.lr_scalers.get(str(p), 1.)
            m = sharedX(p.get_value() * 0.)
            v = sharedX(p.get_value() * 0.)
            b1 = self.b1 * self.lambd**cnt
            m_t = b1 * m + (1 - b1) * g
            v_t = self.b2 * v + (1 - self.b2) * g**2
            m_t_hat = m_t / (1. - self.b1**(cnt + 1))
            v_t_hat = v_t / (1. - self.b2**(cnt + 1))
            g_t = m_t_hat / (T.sqrt(v_t_hat) + self.e)
            p_t = p - lr_scaler * self.lr * g_t
            g_tt[p] = g_t
            updates[m] = m_t
            updates[v] = v_t
            # Insert now so key order matches the moments; overwritten
            # below when post-clipping is enabled.
            updates[p] = p_t
        if self.post_clip:
            # Global L2 norm of all Adam directions, each divided by
            # batch_size before squaring.
            # NOTE(review): a non-finite norm is not guarded against here.
            g_norm = T.sqrt(sum(T.sqr(x / self.batch_size).sum()
                                for x in g_tt.values()))
            # Equals 1 while the norm is within self.scaler, otherwise
            # shrinks the directions proportionally.
            scaler = self.scaler / T.maximum(self.scaler, g_norm)
            for p, g in g_tt.items():
                lr_scaler = self.lr_scalers.get(str(p), 1.)
                updates[p] = p - lr_scaler * self.lr * g * scaler
        updates[cnt] = cnt + 1
        return updates
示例#10
0
文件: opt.py 项目: lipengyu/cle
    def __init__(self, lr, lr_scalers=None, post_clip=0,
                 scaler=5, batch_size=1):
        """
        Store learning-rate settings and the post-clipping options.

        Parameters
        ----------
        lr : float
            Base learning rate, wrapped in a shared variable via ``sharedX``.
        lr_scalers : dict, optional
            Per-parameter multipliers keyed by ``str(param)``; a new empty
            ``OrderedDict`` is used when omitted.
        post_clip : int or bool
            Stored as ``self.post_clip``; a truthy value enables post-clipping
            of the update directions.
        scaler : float
            Norm threshold used by post-clipping.
        batch_size : int
            Divisor applied to each direction before its norm is measured.
        """
        self.lr = sharedX(lr)
        self.lr_scalers = (lr_scalers if lr_scalers is not None
                           else OrderedDict())
        self.post_clip, self.scaler, self.batch_size = (
            post_clip, scaler, batch_size)
示例#11
0
    def get_updates(self, grads):
        """
        Return symbolic Adam updates whose beta1 decays as ``b1 * lambd**t``.

        A shared counter ``t`` (starting at 0) drives the first-moment decay.
        Bias corrections use ``1 - b1**(t + 1)`` and ``1 - b2**(t + 1)``.

        Parameters
        ----------
        grads : dict
            Parameter shared variable -> symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Shared variable -> new symbolic value (moments, parameters and
            the counter).
        """
        updates = OrderedDict()
        t = sharedX(0, 'counter')
        for param, grad in grads.items():
            lr_mult = self.lr_scalers.get(str(param), 1.)
            first = sharedX(param.get_value() * 0.)
            second = sharedX(param.get_value() * 0.)
            beta1_t = self.b1 * self.lambd**t
            first_t = beta1_t * first + (1 - beta1_t) * grad
            second_t = self.b2 * second + (1 - self.b2) * grad**2
            first_corr = 1. - self.b1**(t + 1)
            second_corr = 1. - self.b2**(t + 1)
            step_dir = (first_t / first_corr) / (
                T.sqrt(second_t / second_corr) + self.e)
            updates[first] = first_t
            updates[second] = second_t
            updates[param] = param - lr_mult * self.lr * step_dir
        updates[t] = t + 1
        return updates
示例#12
0
文件: opt.py 项目: orangelpai/cle
    def get_updates(self, grads):
        """
        Build symbolic momentum-SGD updates, optionally with Nesterov momentum.

        For each parameter ``p`` with gradient ``g``, a zero-initialised
        velocity ``u`` is created and updated as

            u_t = mom * u - lr_scaler * lr * g

        With ``self.nesterov`` set, the parameter moves by the look-ahead
        increment ``mom * u_t - lr_scaler * lr * g``. The velocity is still
        stored as ``u_t``.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Velocity and parameter shared variables mapped to their new
            symbolic values.
        """
        updates = OrderedDict()
        for p, g in grads.items():
            lr_scaler = self.lr_scalers.get(str(p), 1.)
            # Apply the per-parameter multiplier on both the plain and the
            # Nesterov path.
            scaled_lr = lr_scaler * self.lr
            u = sharedX(p.get_value() * 0.)
            u_t = self.mom * u - scaled_lr * g
            inc = u_t
            if self.nesterov:
                # Nesterov only changes the parameter increment; storing the
                # look-ahead value as the velocity would compound momentum.
                inc = self.mom * u_t - scaled_lr * g
            updates[u] = u_t
            updates[p] = p + inc
        return updates
示例#13
0
    def get_updates(self, grads):
        """
        Build symbolic momentum-SGD updates, optionally with Nesterov momentum.

        For each parameter ``p`` with gradient ``g``, a zero-initialised
        velocity ``u`` is created and updated as

            u_t = mom * u - lr_scaler * lr * g

        With ``self.nesterov`` set, the parameter moves by the look-ahead
        increment ``mom * u_t - lr_scaler * lr * g``. The velocity is still
        stored as ``u_t``.

        Parameters
        ----------
        grads : dict
            Maps each parameter (a shared variable) to its symbolic gradient.

        Returns
        -------
        updates : OrderedDict
            Velocity and parameter shared variables mapped to their new
            symbolic values.
        """
        updates = OrderedDict()
        for p, g in grads.items():
            lr_scaler = self.lr_scalers.get(str(p), 1.)
            # Apply the per-parameter multiplier on both the plain and the
            # Nesterov path.
            scaled_lr = lr_scaler * self.lr
            u = sharedX(p.get_value() * 0.)
            u_t = self.mom * u - scaled_lr * g
            inc = u_t
            if self.nesterov:
                # Nesterov only changes the parameter increment; storing the
                # look-ahead value as the velocity would compound momentum.
                inc = self.mom * u_t - scaled_lr * g
            updates[u] = u_t
            updates[p] = p + inc
        return updates
示例#14
0
                                init_W=init_W,
                                init_b=init_b_sig)

# Every node whose parameters take part in training.
nodes = [lstm_1, lstm_2, lstm_3, prior, kl,
         x_1, x_2, x_3, x_4,
         z_1, z_2, z_3, z_4,
         phi_1, phi_2, phi_3, phi_4, phi_mu, phi_sig,
         prior_1, prior_2, prior_3, prior_4, prior_mu, prior_sig,
         theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig]

for node in nodes:
    node.initialize()

# Gather the parameters of all nodes into one flat list.
params = flatten([node.get_params().values() for node in nodes])

# State carried between minibatches: a step counter and the last state of
# each LSTM. The state buffers are 2 * dim wide -- presumably hidden and
# cell state concatenated (TODO confirm against the LSTM class).
step_count = sharedX(0, name='step_count')
last_lstm_1 = np.zeros((batch_size, lstm_1_dim*2), dtype=theano.config.floatX)
last_lstm_2 = np.zeros((batch_size, lstm_2_dim*2), dtype=theano.config.floatX)
last_lstm_3 = np.zeros((batch_size, lstm_3_dim*2), dtype=theano.config.floatX)
lstm_1_tm1 = sharedX(last_lstm_1, name='lstm_1_tm1')
lstm_2_tm1 = sharedX(last_lstm_2, name='lstm_2_tm1')
lstm_3_tm1 = sharedX(last_lstm_3, name='lstm_3_tm1')
# The shared variables to be updated; presumably paired with their new
# values further down (not visible here).
update_list = [step_count, lstm_1_tm1, lstm_2_tm1, lstm_3_tm1]

# NOTE: from here on `step_count` names the symbolic next counter value, not
# the shared variable (which update_list still holds). The next value is
# step_count + 1 while step_count <= reset_freq, and 0 otherwise.
step_count = T.switch(T.le(step_count, reset_freq), step_count + 1, 0)
# Start from the LSTM's initial state when the counter is 0 or the carried
# state sums to exactly 0 (e.g. the initial zero buffer); otherwise continue
# from the carried state.
s_1_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                       T.cast(T.eq(T.sum(lstm_1_tm1), 0.), 'int32')),
                 lstm_1.get_init_state(batch_size), lstm_1_tm1)
s_2_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                       T.cast(T.eq(T.sum(lstm_2_tm1), 0.), 'int32')),
                 lstm_2.get_init_state(batch_size), lstm_2_tm1)
示例#15
0
 def setX(self, x, name=None):
     """Wrap ``x`` in a shared variable via ``sharedX``, passing ``name`` through."""
     shared = sharedX(x, name)
     return shared
示例#16
0
 def getX(self, shape, name=None):
     """Create a shared variable initialised with ``self.init_param(shape)``."""
     initial_value = self.init_param(shape)
     return sharedX(initial_value, name)
示例#17
0
                            nout=k,
                            unit='softmax',
                            init_W=init_W,
                            init_b=init_b)

nodes = [rnn,
         x_1, x_2, x_3, x_4,
         theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig, coeff]

# Collect each node's initial parameter values; nodes whose initialize()
# returns None contribute nothing. initialize() is called once per node:
# calling it a second time just to test for None repeats the initializer
# work (including any random draws).
params = OrderedDict()
for node in nodes:
    node_params = node.initialize()
    if node_params is not None:
        params.update(node_params)
params = init_tparams(params)

# State carried across minibatches: a step counter and the last RNN state.
step_count = sharedX(0, name='step_count')
last_rnn = np.zeros((batch_size, rnn_dim*2), dtype=theano.config.floatX)
rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
shared_updates = OrderedDict()
shared_updates[step_count] = step_count + 1

# Every reset_freq steps (counter divisible by reset_freq) restart from the
# RNN's initial state; otherwise continue from the carried state.
s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
               rnn.get_init_state(batch_size), rnn_tm1)

# Input feature extractor applied to the whole sequence.
x_1_temp = x_1.fprop([x], params)
x_2_temp = x_2.fprop([x_1_temp], params)
x_3_temp = x_3.fprop([x_2_temp], params)
x_4_temp = x_4.fprop([x_3_temp], params)

def inner_fn(x_t, s_tm1):
示例#18
0
    phi_mu,
    phi_sig,
    theta_1,
    theta_2,
    theta_3,
    theta_4,
    theta_mu,
    theta_sig,
]

for node in nodes:
    node.initialize()

# Gather the parameters of all nodes into one flat list.
params = flatten([node.get_params().values() for node in nodes])

# State carried between minibatches: a step counter and the last encoder /
# decoder states. The buffers are 2 * dim wide -- presumably hidden and cell
# state concatenated (TODO confirm against the recurrent layer class).
step_count = sharedX(0, name="step_count")
last_encoder = np.zeros((batch_size, encoder_dim * 2), dtype=theano.config.floatX)
last_decoder = np.zeros((batch_size, decoder_dim * 2), dtype=theano.config.floatX)
encoder_tm1 = sharedX(last_encoder, name="encoder_tm1")
decoder_tm1 = sharedX(last_decoder, name="decoder_tm1")
# The shared variables to be updated; presumably paired with their new
# values further down (not visible here).
update_list = [step_count, encoder_tm1, decoder_tm1]

# NOTE: from here on `step_count` names the symbolic next counter value, not
# the shared variable (which update_list still holds). The next value is
# step_count + 1 while step_count <= reset_freq, and 0 otherwise.
step_count = T.switch(T.le(step_count, reset_freq), step_count + 1, 0)
# Start from the encoder's initial state when the counter is 0 or the carried
# state sums to exactly 0; otherwise continue from the carried state.
enc_0 = T.switch(
    T.or_(T.cast(T.eq(step_count, 0), "int32"), T.cast(T.eq(T.sum(encoder_tm1), 0.0), "int32")),
    encoder.get_init_state(batch_size),
    encoder_tm1,
)
dec_0 = T.switch(
    T.or_(T.cast(T.eq(step_count, 0), "int32"), T.cast(T.eq(T.sum(decoder_tm1), 0.0), "int32")),
    decoder.get_init_state(batch_size),
示例#19
0
def main(args):

    trial = int(args['trial'])
    pkl_name = 'vrnn_gauss_%d' % trial
    channel_name = 'valid_nll_upper_bound'

    data_path = args['data_path']
    save_path = args['save_path']

    monitoring_freq = int(args['monitoring_freq'])
    force_saving_freq = int(args['force_saving_freq'])
    reset_freq = int(args['reset_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    m_batch_size = int(args['m_batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    q_z_dim = 500
    p_z_dim = 500
    p_x_dim = 600
    x2s_dim = 600
    z2s_dim = 500
    target_dim = x_dim

    file_name = 'blizzard_unseg_tbptt'
    normal_params = np.load(data_path + file_name + '_normal.npz')
    X_mean = normal_params['X_mean']
    X_std = normal_params['X_std']

    model = Model()
    train_data = Blizzard_tbptt(name='train',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    valid_data = Blizzard_tbptt(name='valid',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        x.tag.test_value = np.zeros((15, batch_size, x_dim),
                                    dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim),
                                      dtype=theano.config.floatX)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_2 = FullyConnectedLayer(name='x_2',
                              parent=['x_1'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_3 = FullyConnectedLayer(name='x_3',
                              parent=['x_2'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_4 = FullyConnectedLayer(name='x_4',
                              parent=['x_3'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_1 = FullyConnectedLayer(name='z_1',
                              parent=['z_t'],
                              parent_dim=[z_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_2 = FullyConnectedLayer(name='z_2',
                              parent=['z_1'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_3 = FullyConnectedLayer(name='z_3',
                              parent=['z_2'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_4 = FullyConnectedLayer(name='z_4',
                              parent=['z_3'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_4', 'z_4'],
               parent_dim=[x2s_dim, z2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    phi_1 = FullyConnectedLayer(name='phi_1',
                                parent=['x_4', 's_tm1'],
                                parent_dim=[x2s_dim, rnn_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_2 = FullyConnectedLayer(name='phi_2',
                                parent=['phi_1'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_3 = FullyConnectedLayer(name='phi_3',
                                parent=['phi_2'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_4 = FullyConnectedLayer(name='phi_4',
                                parent=['phi_3'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_mu = FullyConnectedLayer(name='phi_mu',
                                 parent=['phi_4'],
                                 parent_dim=[q_z_dim],
                                 nout=z_dim,
                                 unit='linear',
                                 init_W=init_W,
                                 init_b=init_b)

    phi_sig = FullyConnectedLayer(name='phi_sig',
                                  parent=['phi_4'],
                                  parent_dim=[q_z_dim],
                                  nout=z_dim,
                                  unit='softplus',
                                  cons=1e-4,
                                  init_W=init_W,
                                  init_b=init_b_sig)

    prior_1 = FullyConnectedLayer(name='prior_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_2 = FullyConnectedLayer(name='prior_2',
                                  parent=['prior_1'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_3 = FullyConnectedLayer(name='prior_3',
                                  parent=['prior_2'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_4 = FullyConnectedLayer(name='prior_4',
                                  parent=['prior_3'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_mu = FullyConnectedLayer(name='prior_mu',
                                   parent=['prior_4'],
                                   parent_dim=[p_z_dim],
                                   nout=z_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    prior_sig = FullyConnectedLayer(name='prior_sig',
                                    parent=['prior_4'],
                                    parent_dim=[p_z_dim],
                                    nout=z_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['z_4', 's_tm1'],
                                  parent_dim=[z2s_dim, rnn_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_2 = FullyConnectedLayer(name='theta_2',
                                  parent=['theta_1'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_3 = FullyConnectedLayer(name='theta_3',
                                  parent=['theta_2'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_4 = FullyConnectedLayer(name='theta_4',
                                  parent=['theta_3'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_4'],
                                   parent_dim=[p_x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_4'],
                                    parent_dim=[p_x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    nodes = [
        rnn, x_1, x_2, x_3, x_4, z_1, z_2, z_3, z_4, phi_1, phi_2, phi_3,
        phi_4, phi_mu, phi_sig, prior_1, prior_2, prior_3, prior_4, prior_mu,
        prior_sig, theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig
    ]

    params = OrderedDict()

    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())

    params = init_tparams(params)

    step_count = sharedX(0, name='step_count')
    last_rnn = np.zeros((batch_size, rnn_dim * 2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    # Resets / Initializes the cell-state or the memory-state of each LSTM to
    # zero.
    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
                   rnn.get_init_state(batch_size), rnn_tm1)

    # Forward Propagate the input to get more complex features for
    # every time step.
    x_1_temp = x_1.fprop([x], params)
    x_2_temp = x_2.fprop([x_1_temp], params)
    x_3_temp = x_3.fprop([x_2_temp], params)
    x_4_temp = x_4.fprop([x_3_temp], params)

    def inner_fn(x_t, s_tm1):

        # Generate the mean and standard deviation of the
        # latent variables Z_t | X_t for every time-step of the LSTM.
        # This is a function of the input and the hidden state of the previous
        # time step.
        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_2_t = phi_2.fprop([phi_1_t], params)
        phi_3_t = phi_3.fprop([phi_2_t], params)
        phi_4_t = phi_4.fprop([phi_3_t], params)
        phi_mu_t = phi_mu.fprop([phi_4_t], params)
        phi_sig_t = phi_sig.fprop([phi_4_t], params)

        # Prior on the latent variables at every time-step
        # Dependent only on the hidden-step.
        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_2_t = prior_2.fprop([prior_1_t], params)
        prior_3_t = prior_3.fprop([prior_2_t], params)
        prior_4_t = prior_4.fprop([prior_3_t], params)
        prior_mu_t = prior_mu.fprop([prior_4_t], params)
        prior_sig_t = prior_sig.fprop([prior_4_t], params)

        # Sample from the latent distibution with mean phi_mu_t
        # and std phi_sig_t
        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)

        # h_t = f(h_(t-1)), z_t, x_t)
        z_1_t = z_1.fprop([z_t], params)
        z_2_t = z_2.fprop([z_1_t], params)
        z_3_t = z_3.fprop([z_2_t], params)
        z_4_t = z_4.fprop([z_3_t], params)

        s_t = rnn.fprop([[x_t, z_4_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_4_t

    # Iterate over every time-step
    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_4_temp), updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[x_4_temp],
                    outputs_info=[s_0, None, None, None, None, None])

    for k, v in updates.iteritems():
        k.default_update = v

    shared_updates[rnn_tm1] = s_temp[-1]
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)

    # Generate the output distribution at every time-step.
    # This is as a function of the latent variables and the hidden-state at
    # every time-step.
    theta_1_temp = theta_1.fprop([z_4_temp, s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)

    kl_temp = KLGaussianGaussian(phi_mu_temp, phi_sig_temp, prior_mu_temp,
                                 prior_sig_temp)

    recon = Gaussian(x, theta_mu_temp, theta_sig_temp)
    recon_term = recon.mean()
    kl_term = kl_temp.mean()
    nll_upper_bound = recon_term + kl_term
    nll_upper_bound.name = 'nll_upper_bound'

    # Forward-propagation of the validation data.
    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    m_s_0 = rnn.get_init_state(m_batch_size)

    # Get the hidden-states, conditional mean, standard deviation, prior mean
    # and prior standard deviation of the latent variables at every time-step.
    ((m_s_temp, m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp, m_z_4_temp), m_updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[m_x_4_temp],
                    outputs_info=[m_s_0, None, None, None, None, None])

    for k, v in m_updates.iteritems():
        k.default_update = v

    # Get the inferred mean (X_t | Z_t) at every time-step of the validation
    # data.
    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_z_4_temp, m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

    # Compute the data log-likelihood + KL-divergence on the validation data.
    m_kl_temp = KLGaussianGaussian(m_phi_mu_temp, m_phi_sig_temp,
                                   m_prior_mu_temp, m_prior_sig_temp)

    m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
    m_recon_term = m_recon.mean()
    m_kl_term = m_kl_temp.mean()
    m_nll_upper_bound = m_recon_term + m_kl_term
    m_nll_upper_bound.name = 'nll_upper_bound'
    m_recon_term.name = 'recon_term'
    m_kl_term.name = 'kl_term'

    max_x = m_x.max()
    mean_x = m_x.mean()
    min_x = m_x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = m_theta_mu_temp.max()
    mean_theta_mu = m_theta_mu_temp.mean()
    min_theta_mu = m_theta_mu_temp.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = m_theta_sig_temp.max()
    mean_theta_sig = m_theta_sig_temp.mean()
    min_theta_sig = m_theta_sig_temp.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    max_phi_sig = m_phi_sig_temp.max()
    mean_phi_sig = m_phi_sig_temp.mean()
    min_phi_sig = m_phi_sig_temp.min()
    max_phi_sig.name = 'max_phi_sig'
    mean_phi_sig.name = 'mean_phi_sig'
    min_phi_sig.name = 'min_phi_sig'

    max_prior_sig = m_prior_sig_temp.max()
    mean_prior_sig = m_prior_sig_temp.mean()
    min_prior_sig = m_prior_sig_temp.min()
    max_prior_sig.name = 'max_prior_sig'
    mean_prior_sig.name = 'mean_prior_sig'
    min_prior_sig.name = 'min_prior_sig'

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    model.set_updates(shared_updates)

    optimizer = Adam(lr=lr)

    monitor_fn = theano.function(
        inputs=[m_x],
        outputs=[
            m_nll_upper_bound, m_recon_term, m_kl_term, max_phi_sig,
            mean_phi_sig, min_phi_sig, max_prior_sig, mean_prior_sig,
            min_prior_sig, max_theta_sig, mean_theta_sig, min_theta_sig, max_x,
            mean_x, min_x, max_theta_mu, mean_theta_mu, min_theta_mu
        ],
        on_unused_input='ignore')

    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   monitor_fn=monitor_fn,
                   ddout=[
                       m_nll_upper_bound, m_recon_term, m_kl_term, max_phi_sig,
                       mean_phi_sig, min_phi_sig, max_prior_sig,
                       mean_prior_sig, min_prior_sig, max_theta_sig,
                       mean_theta_sig, min_theta_sig, max_x, mean_x, min_x,
                       max_theta_mu, mean_theta_mu, min_theta_mu
                   ],
                   data=[
                       Iterator(train_data, m_batch_size, start=0, end=112640),
                       Iterator(valid_data,
                                m_batch_size,
                                start=2040064,
                                end=2152704)
                   ]),
        Picklize(freq=monitoring_freq,
                 force_save_freq=force_saving_freq,
                 path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      force_save_freq=force_saving_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data,
                                      batch_size,
                                      start=0,
                                      end=2040064),
                        model=model,
                        optimizer=optimizer,
                        cost=nll_upper_bound,
                        outputs=[nll_upper_bound],
                        extension=extension)
    mainloop.run()
示例#20
0
 def setX(self, x, name=None):
     """Return a shared variable holding ``x``, built by ``sharedX`` with ``name``."""
     result = sharedX(x, name)
     return result
示例#21
0
                                parent_dim=[decoder_dim],
                                nout=target_size,
                                unit='softplus',
                                cons=1e-4,
                                init_W=init_W,
                                init_b=init_b_sig)

# Every node whose parameters take part in training.
nodes = [encoder, decoder, prior, kl,
         phi_mu, phi_sig, theta_mu, theta_sig]

for node in nodes:
    node.initialize()

# Gather the parameters of all nodes into one flat list.
params = flatten([node.get_params().values() for node in nodes])

# State carried between minibatches: a step counter and the last encoder /
# decoder states. The buffers are 2 * dim wide -- presumably hidden and cell
# state concatenated (TODO confirm against the recurrent layer class).
step_count = sharedX(0, name='step_count')
last_encoder = np.zeros((batch_size, encoder_dim*2), dtype=theano.config.floatX)
last_decoder = np.zeros((batch_size, decoder_dim*2), dtype=theano.config.floatX)
encoder_tm1 = sharedX(last_encoder, name='encoder_tm1')
decoder_tm1 = sharedX(last_decoder, name='decoder_tm1')
# The shared variables to be updated; presumably paired with their new
# values further down (not visible here).
update_list = [step_count, encoder_tm1, decoder_tm1]

# NOTE: from here on `step_count` names the symbolic next counter value, not
# the shared variable (which update_list still holds). The next value is
# step_count + 1 while step_count <= reset_freq, and 0 otherwise.
step_count = T.switch(T.le(step_count, reset_freq), step_count + 1, 0)
# Start from each recurrent layer's initial state when the counter is 0 or the
# carried state sums to exactly 0; otherwise continue from the carried state.
enc_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                       T.cast(T.eq(T.sum(encoder_tm1), 0.), 'int32')),
                 encoder.get_init_state(batch_size), encoder_tm1)
dec_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                       T.cast(T.eq(T.sum(decoder_tm1), 0.), 'int32')),
                 decoder.get_init_state(batch_size), decoder_tm1)

x_shape = x.shape
示例#22
0
def main(args):

    trial = int(args['trial'])
    pkl_name = 'rnn_gauss_%d' % trial
    channel_name = 'valid_nll'

    data_path = args['data_path']
    save_path = args['save_path']

    monitoring_freq = int(args['monitoring_freq'])
    force_saving_freq = int(args['force_saving_freq'])
    reset_freq = int(args['reset_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    m_batch_size = int(args['m_batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    x2s_dim = 800
    s2x_dim = 800
    target_dim = x_dim

    file_name = 'blizzard_unseg_tbptt'
    normal_params = np.load(data_path + file_name + '_normal.npz')
    X_mean = normal_params['X_mean']
    X_std = normal_params['X_std']

    model = Model()
    train_data = Blizzard_tbptt(name='train',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    valid_data = Blizzard_tbptt(name='valid',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        x.tag.test_value = np.zeros((15, batch_size, x_dim),
                                    dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim),
                                      dtype=theano.config.floatX)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_2 = FullyConnectedLayer(name='x_2',
                              parent=['x_1'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_3 = FullyConnectedLayer(name='x_3',
                              parent=['x_2'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_4 = FullyConnectedLayer(name='x_4',
                              parent=['x_3'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_4'],
               parent_dim=[x2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_2 = FullyConnectedLayer(name='theta_2',
                                  parent=['theta_1'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_3 = FullyConnectedLayer(name='theta_3',
                                  parent=['theta_2'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_4 = FullyConnectedLayer(name='theta_4',
                                  parent=['theta_3'],
                                  parent_dim=[s2x_dim],
                                  nout=s2x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_4'],
                                   parent_dim=[s2x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_4'],
                                    parent_dim=[s2x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    nodes = [
        rnn, x_1, x_2, x_3, x_4, theta_1, theta_2, theta_3, theta_4, theta_mu,
        theta_sig
    ]

    params = OrderedDict()
    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())
    params = init_tparams(params)

    step_count = sharedX(0, name='step_count')
    last_rnn = np.zeros((batch_size, rnn_dim * 2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
                   rnn.get_init_state(batch_size), rnn_tm1)

    x_1_temp = x_1.fprop([x], params)
    x_2_temp = x_2.fprop([x_1_temp], params)
    x_3_temp = x_3.fprop([x_2_temp], params)
    x_4_temp = x_4.fprop([x_3_temp], params)

    def inner_fn(x_t, s_tm1):

        s_t = rnn.fprop([[x_t], [s_tm1]], params)

        return s_t

    (s_temp, updates) = theano.scan(fn=inner_fn,
                                    sequences=[x_4_temp],
                                    outputs_info=[s_0])

    for k, v in updates.iteritems():
        k.default_update = v

    shared_updates[rnn_tm1] = s_temp[-1]
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    theta_1_temp = theta_1.fprop([s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)

    recon = Gaussian(x, theta_mu_temp, theta_sig_temp)
    recon_term = recon.mean()
    recon_term.name = 'nll'

    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    m_s_0 = rnn.get_init_state(m_batch_size)

    (m_s_temp, m_updates) = theano.scan(fn=inner_fn,
                                        sequences=[m_x_4_temp],
                                        outputs_info=[m_s_0])

    for k, v in m_updates.iteritems():
        k.default_update = v

    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

    m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
    m_recon_term = m_recon.mean()
    m_recon_term.name = 'nll'

    max_x = m_x.max()
    mean_x = m_x.mean()
    min_x = m_x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = m_theta_mu_temp.max()
    mean_theta_mu = m_theta_mu_temp.mean()
    min_theta_mu = m_theta_mu_temp.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = m_theta_sig_temp.max()
    mean_theta_sig = m_theta_sig_temp.mean()
    min_theta_sig = m_theta_sig_temp.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    model.set_updates(shared_updates)

    optimizer = Adam(lr=lr)

    monitor_fn = theano.function(inputs=[m_x],
                                 outputs=[
                                     m_recon_term, max_theta_sig,
                                     mean_theta_sig, min_theta_sig, max_x,
                                     mean_x, min_x, max_theta_mu,
                                     mean_theta_mu, min_theta_mu
                                 ],
                                 on_unused_input='ignore')

    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   monitor_fn=monitor_fn,
                   ddout=[
                       m_recon_term, max_theta_sig, mean_theta_sig,
                       min_theta_sig, max_x, mean_x, min_x, max_theta_mu,
                       mean_theta_mu, min_theta_mu
                   ],
                   data=[
                       Iterator(train_data, m_batch_size, start=0, end=112640),
                       Iterator(valid_data,
                                m_batch_size,
                                start=2040064,
                                end=2152704)
                   ]),
        Picklize(freq=monitoring_freq,
                 force_save_freq=force_saving_freq,
                 path=save_path),
        EarlyStopping(freq=monitoring_freq,
                      force_save_freq=force_saving_freq,
                      path=save_path,
                      channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(name=pkl_name,
                        data=Iterator(train_data,
                                      batch_size,
                                      start=0,
                                      end=2040064),
                        model=model,
                        optimizer=optimizer,
                        cost=recon_term,
                        outputs=[recon_term],
                        extension=extension)
    mainloop.run()
示例#23
0
                                init_W=init_W,
                                init_b=init_b_sig)

# Every layer/cell of the model; each is initialised and its parameters are
# collected into one flat list below.
nodes = [main_lstm, prior, kl,
         x_1, x_2, x_3, x_4,
         z_1, z_2, z_3, z_4,
         phi_1, phi_2, phi_3, phi_4, phi_mu, phi_sig,
         prior_1, prior_2, prior_3, prior_4, prior_mu, prior_sig,
         theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig]

for node in nodes:
    node.initialize()

params = flatten([node.get_params().values() for node in nodes])

# Shared (persistent) state carried from one minibatch to the next.
step_count = sharedX(0, name='step_count')
# NOTE(review): width 2 * main_lstm_dim, presumably hidden and cell state
# concatenated -- confirm against main_lstm.get_init_state.
last_main_lstm = np.zeros((batch_size, main_lstm_dim*2), dtype=theano.config.floatX)
main_lstm_tm1 = sharedX(last_main_lstm, name='main_lstm_tm1')
# Captures the shared variables themselves; `step_count` is rebound to a
# symbolic expression below, so this list is the remaining handle on the
# shared counter (presumably used later to build updates).
update_list = [step_count, main_lstm_tm1]

# Symbolic next counter value: count + 1 while count <= reset_freq, else 0.
step_count = T.switch(T.le(step_count, reset_freq), step_count + 1, 0)
# Use the fresh initial state when the counter wraps to 0 or the carried
# state sums to 0 (e.g. the zeros buffer above); otherwise keep carrying.
s_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                     T.cast(T.eq(T.sum(main_lstm_tm1), 0.), 'int32')),
               main_lstm.get_init_state(batch_size), main_lstm_tm1)

# Flatten the first two axes of x (presumably time and batch -- confirm) so
# the frame encoder x_1..x_4 runs on a 2-D matrix in one pass.
x_shape = x.shape
x_in = x.reshape((x_shape[0]*x_shape[1], -1))
x_1_in = x_1.fprop([x_in])
x_2_in = x_2.fprop([x_1_in])
x_3_in = x_3.fprop([x_2_in])
x_4_in = x_4.fprop([x_3_in])
示例#24
0
def main(args):

    trial = int(args['trial'])
    pkl_name = 'vrnn_gmm_%d' % trial
    channel_name = 'valid_nll_upper_bound'

    data_path = args['data_path']
    save_path = args['save_path']

    monitoring_freq = int(args['monitoring_freq'])
    force_saving_freq = int(args['force_saving_freq'])
    reset_freq = int(args['reset_freq'])
    epoch = int(args['epoch'])
    batch_size = int(args['batch_size'])
    m_batch_size = int(args['m_batch_size'])
    x_dim = int(args['x_dim'])
    z_dim = int(args['z_dim'])
    rnn_dim = int(args['rnn_dim'])
    k = int(args['num_k'])
    lr = float(args['lr'])
    debug = int(args['debug'])

    print "trial no. %d" % trial
    print "batch size %d" % batch_size
    print "learning rate %f" % lr
    print "saving pkl file '%s'" % pkl_name
    print "to the save path '%s'" % save_path

    q_z_dim = 500
    p_z_dim = 500
    p_x_dim = 500
    x2s_dim = 500
    z2s_dim = 500
    target_dim = x_dim * k

    file_name = 'blizzard_unseg_tbptt'
    normal_params = np.load(data_path + file_name + '_normal.npz')
    X_mean = normal_params['X_mean']
    X_std = normal_params['X_std']

    model = Model()
    train_data = Blizzard_tbptt(name='train',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    valid_data = Blizzard_tbptt(name='valid',
                                path=data_path,
                                frame_size=x_dim,
                                file_name=file_name,
                                X_mean=X_mean,
                                X_std=X_std)

    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        x.tag.test_value = np.zeros((15, batch_size, x_dim), dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim), dtype=theano.config.floatX)

    init_W = InitCell('rand')
    init_U = InitCell('ortho')
    init_b = InitCell('zeros')
    init_b_sig = InitCell('const', mean=0.6)

    x_1 = FullyConnectedLayer(name='x_1',
                              parent=['x_t'],
                              parent_dim=[x_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_2 = FullyConnectedLayer(name='x_2',
                              parent=['x_1'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_3 = FullyConnectedLayer(name='x_3',
                              parent=['x_2'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    x_4 = FullyConnectedLayer(name='x_4',
                              parent=['x_3'],
                              parent_dim=[x2s_dim],
                              nout=x2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_1 = FullyConnectedLayer(name='z_1',
                              parent=['z_t'],
                              parent_dim=[z_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_2 = FullyConnectedLayer(name='z_2',
                              parent=['z_1'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_3 = FullyConnectedLayer(name='z_3',
                              parent=['z_2'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    z_4 = FullyConnectedLayer(name='z_4',
                              parent=['z_3'],
                              parent_dim=[z2s_dim],
                              nout=z2s_dim,
                              unit='relu',
                              init_W=init_W,
                              init_b=init_b)

    rnn = LSTM(name='rnn',
               parent=['x_4', 'z_4'],
               parent_dim=[x2s_dim, z2s_dim],
               nout=rnn_dim,
               unit='tanh',
               init_W=init_W,
               init_U=init_U,
               init_b=init_b)

    phi_1 = FullyConnectedLayer(name='phi_1',
                                parent=['x_4', 's_tm1'],
                                parent_dim=[x2s_dim, rnn_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_2 = FullyConnectedLayer(name='phi_2',
                                parent=['phi_1'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_3 = FullyConnectedLayer(name='phi_3',
                                parent=['phi_2'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_4 = FullyConnectedLayer(name='phi_4',
                                parent=['phi_3'],
                                parent_dim=[q_z_dim],
                                nout=q_z_dim,
                                unit='relu',
                                init_W=init_W,
                                init_b=init_b)

    phi_mu = FullyConnectedLayer(name='phi_mu',
                                 parent=['phi_4'],
                                 parent_dim=[q_z_dim],
                                 nout=z_dim,
                                 unit='linear',
                                 init_W=init_W,
                                 init_b=init_b)

    phi_sig = FullyConnectedLayer(name='phi_sig',
                                  parent=['phi_4'],
                                  parent_dim=[q_z_dim],
                                  nout=z_dim,
                                  unit='softplus',
                                  cons=1e-4,
                                  init_W=init_W,
                                  init_b=init_b_sig)

    prior_1 = FullyConnectedLayer(name='prior_1',
                                  parent=['s_tm1'],
                                  parent_dim=[rnn_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_2 = FullyConnectedLayer(name='prior_2',
                                  parent=['prior_1'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_3 = FullyConnectedLayer(name='prior_3',
                                  parent=['prior_2'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_4 = FullyConnectedLayer(name='prior_4',
                                  parent=['prior_3'],
                                  parent_dim=[p_z_dim],
                                  nout=p_z_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    prior_mu = FullyConnectedLayer(name='prior_mu',
                                   parent=['prior_4'],
                                   parent_dim=[p_z_dim],
                                   nout=z_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    prior_sig = FullyConnectedLayer(name='prior_sig',
                                    parent=['prior_4'],
                                    parent_dim=[p_z_dim],
                                    nout=z_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    theta_1 = FullyConnectedLayer(name='theta_1',
                                  parent=['z_4', 's_tm1'],
                                  parent_dim=[z2s_dim, rnn_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_2 = FullyConnectedLayer(name='theta_2',
                                  parent=['theta_1'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_3 = FullyConnectedLayer(name='theta_3',
                                  parent=['theta_2'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_4 = FullyConnectedLayer(name='theta_4',
                                  parent=['theta_3'],
                                  parent_dim=[p_x_dim],
                                  nout=p_x_dim,
                                  unit='relu',
                                  init_W=init_W,
                                  init_b=init_b)

    theta_mu = FullyConnectedLayer(name='theta_mu',
                                   parent=['theta_4'],
                                   parent_dim=[p_x_dim],
                                   nout=target_dim,
                                   unit='linear',
                                   init_W=init_W,
                                   init_b=init_b)

    theta_sig = FullyConnectedLayer(name='theta_sig',
                                    parent=['theta_4'],
                                    parent_dim=[p_x_dim],
                                    nout=target_dim,
                                    unit='softplus',
                                    cons=1e-4,
                                    init_W=init_W,
                                    init_b=init_b_sig)

    coeff = FullyConnectedLayer(name='coeff',
                                parent=['theta_4'],
                                parent_dim=[p_x_dim],
                                nout=k,
                                unit='softmax',
                                init_W=init_W,
                                init_b=init_b)

    nodes = [rnn,
             x_1, x_2, x_3, x_4,
             z_1, z_2, z_3, z_4,
             phi_1, phi_2, phi_3, phi_4, phi_mu, phi_sig,
             prior_1, prior_2, prior_3, prior_4, prior_mu, prior_sig,
             theta_1, theta_2, theta_3, theta_4, theta_mu, theta_sig, coeff]

    params = OrderedDict()
    for node in nodes:
        if node.initialize() is not None:
            params.update(node.initialize())
    params = init_tparams(params)

    step_count = sharedX(0, name='step_count')
    last_rnn = np.zeros((batch_size, rnn_dim*2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name='rnn_tm1')
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
                   rnn.get_init_state(batch_size), rnn_tm1)

    x_shape = x.shape
    x_in = x.reshape((x_shape[0]*x_shape[1], -1))
    x_1_in = x_1.fprop([x_in], params)
    x_2_in = x_2.fprop([x_1_in], params)
    x_3_in = x_3.fprop([x_2_in], params)
    x_4_in = x_4.fprop([x_3_in], params)
    x_4_in = x_4_in.reshape((x_shape[0], x_shape[1], -1))


    def inner_fn(x_t, s_tm1):

        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_2_t = phi_2.fprop([phi_1_t], params)
        phi_3_t = phi_3.fprop([phi_2_t], params)
        phi_4_t = phi_4.fprop([phi_3_t], params)
        phi_mu_t = phi_mu.fprop([phi_4_t], params)
        phi_sig_t = phi_sig.fprop([phi_4_t], params)

        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_2_t = prior_2.fprop([prior_1_t], params)
        prior_3_t = prior_3.fprop([prior_2_t], params)
        prior_4_t = prior_4.fprop([prior_3_t], params)
        prior_mu_t = prior_mu.fprop([prior_4_t], params)
        prior_sig_t = prior_sig.fprop([prior_4_t], params)

        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)

        z_1_t = z_1.fprop([z_t], params)
        z_2_t = z_2.fprop([z_1_t], params)
        z_3_t = z_3.fprop([z_2_t], params)
        z_4_t = z_4.fprop([z_3_t], params)

        s_t = rnn.fprop([[x_t, z_4_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_4_t

    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_4_temp), updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[x_4_in],
                    outputs_info=[s_0, None, None, None, None, None])

    for k, v in updates.iteritems():
        k.default_update = v

    shared_updates[rnn_tm1] = s_temp[-1]
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    theta_1_temp = theta_1.fprop([z_4_temp, s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)
    coeff_temp = coeff.fprop([theta_4_temp], params)

    kl_temp = KLGaussianGaussian(phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp)

    x_shape = x.shape
    x_in = x.reshape((x_shape[0]*x_shape[1], -1))
    theta_mu_in = theta_mu_temp.reshape((x_shape[0]*x_shape[1], -1))
    theta_sig_in = theta_sig_temp.reshape((x_shape[0]*x_shape[1], -1))
    coeff_in = coeff_temp.reshape((x_shape[0]*x_shape[1], -1))

    recon = GMM(x_in, theta_mu_in, theta_sig_in, coeff_in)
    recon_term = recon.mean()
    kl_term = kl_temp.mean()
    nll_upper_bound = recon_term + kl_term
    nll_upper_bound.name = 'nll_upper_bound'

    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    m_s_0 = rnn.get_init_state(m_batch_size)

    ((m_s_temp, m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp, m_z_4_temp), m_updates) =\
        theano.scan(fn=inner_fn,
                    sequences=[m_x_4_temp],
                    outputs_info=[m_s_0, None, None, None, None, None])

    for k, v in m_updates.iteritems():
        k.default_update = v

    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_z_4_temp, m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)
    m_coeff_temp = coeff.fprop([m_theta_4_temp], params)

    m_kl_temp = KLGaussianGaussian(m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp)

    m_x_shape = m_x.shape
    m_x_in = m_x.reshape((m_x_shape[0]*m_x_shape[1], -1))
    m_theta_mu_in = m_theta_mu_temp.reshape((m_x_shape[0]*m_x_shape[1], -1))
    m_theta_sig_in = m_theta_sig_temp.reshape((m_x_shape[0]*m_x_shape[1], -1))
    m_coeff_in = m_coeff_temp.reshape((m_x_shape[0]*m_x_shape[1], -1))

    m_recon = GMM(m_x_in, m_theta_mu_in, m_theta_sig_in, m_coeff_in)
    m_recon_term = m_recon.mean()
    m_kl_term = m_kl_temp.mean()
    m_nll_upper_bound = m_recon_term + m_kl_term
    m_nll_upper_bound.name = 'nll_upper_bound'
    m_recon_term.name = 'recon_term'
    m_kl_term.name = 'kl_term'

    max_x = m_x.max()
    mean_x = m_x.mean()
    min_x = m_x.min()
    max_x.name = 'max_x'
    mean_x.name = 'mean_x'
    min_x.name = 'min_x'

    max_theta_mu = m_theta_mu_in.max()
    mean_theta_mu = m_theta_mu_in.mean()
    min_theta_mu = m_theta_mu_in.min()
    max_theta_mu.name = 'max_theta_mu'
    mean_theta_mu.name = 'mean_theta_mu'
    min_theta_mu.name = 'min_theta_mu'

    max_theta_sig = m_theta_sig_in.max()
    mean_theta_sig = m_theta_sig_in.mean()
    min_theta_sig = m_theta_sig_in.min()
    max_theta_sig.name = 'max_theta_sig'
    mean_theta_sig.name = 'mean_theta_sig'
    min_theta_sig.name = 'min_theta_sig'

    max_phi_sig = m_phi_sig_temp.max()
    mean_phi_sig = m_phi_sig_temp.mean()
    min_phi_sig = m_phi_sig_temp.min()
    max_phi_sig.name = 'max_phi_sig'
    mean_phi_sig.name = 'mean_phi_sig'
    min_phi_sig.name = 'min_phi_sig'

    max_prior_sig = m_prior_sig_temp.max()
    mean_prior_sig = m_prior_sig_temp.mean()
    min_prior_sig = m_prior_sig_temp.min()
    max_prior_sig.name = 'max_prior_sig'
    mean_prior_sig.name = 'mean_prior_sig'
    min_prior_sig.name = 'min_prior_sig'

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    model.set_updates(shared_updates)

    optimizer = Adam(
        lr=lr
    )

    monitor_fn = theano.function(inputs=[m_x],
                                 outputs=[m_nll_upper_bound, m_recon_term, m_kl_term,
                                          max_phi_sig, mean_phi_sig, min_phi_sig,
                                          max_prior_sig, mean_prior_sig, min_prior_sig,
                                          max_theta_sig, mean_theta_sig, min_theta_sig,
                                          max_x, mean_x, min_x,
                                          max_theta_mu, mean_theta_mu, min_theta_mu],
                                 on_unused_input='ignore')

    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(freq=monitoring_freq,
                   monitor_fn=monitor_fn,
                   ddout=[m_nll_upper_bound, m_recon_term, m_kl_term,
                          max_phi_sig, mean_phi_sig, min_phi_sig,
                          max_prior_sig, mean_prior_sig, min_prior_sig,
                          max_theta_sig, mean_theta_sig, min_theta_sig,
                          max_x, mean_x, min_x,
                          max_theta_mu, mean_theta_mu, min_theta_mu],
                   data=[Iterator(train_data, m_batch_size, start=0, end=112640),
                         Iterator(valid_data, m_batch_size, start=2040064, end=2152704)]),
        Picklize(freq=monitoring_freq, force_save_freq=force_saving_freq, path=save_path),
        EarlyStopping(freq=monitoring_freq, force_save_freq=force_saving_freq, path=save_path, channel=channel_name),
        WeightNorm()
    ]

    mainloop = Training(
        name=pkl_name,
        data=Iterator(train_data, batch_size, start=0, end=2040064),
        model=model,
        optimizer=optimizer,
        cost=nll_upper_bound,
        outputs=[nll_upper_bound],
        extension=extension
    )
    mainloop.run()
示例#25
0
文件: layer.py 项目: Beronx86/cle
 def __init__(self, rho=0.5, eps=1e-6, **kwargs):
     """Construct a batch-normalisation layer.

     :param rho: stored as ``self.rho`` (presumably the momentum of the
         running statistics; not used in this constructor).
     :param eps: stored as ``self.eps`` (presumably a small stabilising
         constant; not used in this constructor).
     :param kwargs: forwarded to the parent constructor, which must set
         ``self.nout`` and ``self.name`` -- both are read below.
     """
     super(BatchNormLayer, self).__init__(**kwargs)
     self.rho, self.eps = rho, eps
     # Shared statistics of width nout: mu starts at zeros, sigma at ones.
     zero_init = InitCell('zeros').get(self.nout)
     self.mu = sharedX(zero_init, name='mu_' + self.name)
     one_init = InitCell('ones').get(self.nout)
     self.sigma = sharedX(one_init, name='sigma_' + self.name)
示例#26
0
 def getX(self, shape, name=None):
     """Build a ``sharedX`` variable from ``self.init_param(shape)``.

     ``name`` is passed through positionally as the shared variable's label.
     """
     initial_value = self.init_param(shape)
     return sharedX(initial_value, name)
示例#27
0
# Output layer: softmax over 205 classes, fed by all three recurrent layers
# (200 units each).
output = FullyConnectedLayer(name='output',
                             parent=['h1', 'h2', 'h3'],
                             parent_dim=[200, 200, 200],
                             nout=205,
                             unit='softmax',
                             init_W=init_W,
                             init_b=init_b)

nodes = [h1, h2, h3, output]

for node in nodes:
    node.initialize()

params = flatten([node.get_params().values() for node in nodes])

# Shared (persistent) state carried from one minibatch to the next.
step_count = sharedX(0, name='step_count')
# NOTE(review): width 400 = 2 * 200, presumably hidden and cell state
# concatenated. dtype is fixed to float32 rather than theano.config.floatX.
last_h = np.zeros((batch_size, 400), dtype=np.float32)
h1_tm1 = sharedX(last_h, name='h1_tm1')
h2_tm1 = sharedX(last_h, name='h2_tm1')
h3_tm1 = sharedX(last_h, name='h3_tm1')
# Captures the shared variables themselves; `step_count` is rebound to a
# symbolic expression below, so this list is the remaining handle on the
# shared counter (presumably used later to build updates).
update_list = [step_count, h1_tm1, h2_tm1, h3_tm1]

# Symbolic next counter value: count + 1 while count <= reset_freq, else 0.
step_count = T.switch(T.le(step_count, reset_freq), step_count + 1, 0)

# Use h1's initial state when the counter wraps to 0 or the carried state
# sums to 0 (e.g. the zeros buffer above); otherwise keep carrying h1_tm1.
s1_0 = T.switch(
    T.or_(T.cast(T.eq(step_count, 0), 'int32'),
          T.cast(T.eq(T.sum(h1_tm1), 0.), 'int32')), h1.get_init_state(),
    h1_tm1)
s2_0 = T.switch(
    T.or_(T.cast(T.eq(step_count, 0), 'int32'),
          T.cast(T.eq(T.sum(h2_tm1), 0.), 'int32')), h2.get_init_state(),
示例#28
0
文件: enwiki.py 项目: Beronx86/cle
# Output layer: softmax over 205 classes (presumably the character
# vocabulary of this enwiki script -- confirm), fed by all three recurrent
# layers (200 units each).
output = FullyConnectedLayer(name='output',
                             parent=['h1', 'h2', 'h3'],
                             parent_dim=[200, 200, 200],
                             nout=205,
                             unit='softmax',
                             init_W=init_W,
                             init_b=init_b)

nodes = [h1, h2, h3, output]

for node in nodes:
    node.initialize()

params = flatten([node.get_params().values() for node in nodes])

# Shared (persistent) state carried from one minibatch to the next.
step_count = sharedX(0, name='step_count')
# NOTE(review): width 400 = 2 * 200, presumably hidden and cell state
# concatenated. dtype is fixed to float32 rather than theano.config.floatX.
last_h = np.zeros((batch_size, 400), dtype=np.float32)
h1_tm1 = sharedX(last_h, name='h1_tm1')
h2_tm1 = sharedX(last_h, name='h2_tm1')
h3_tm1 = sharedX(last_h, name='h3_tm1')
# Captures the shared variables themselves; `step_count` is rebound to a
# symbolic expression below, so this list is the remaining handle on the
# shared counter (presumably used later to build updates).
update_list = [step_count, h1_tm1, h2_tm1, h3_tm1]

# Symbolic next counter value: count + 1 while count <= reset_freq, else 0.
step_count = T.switch(T.le(step_count, reset_freq),
                      step_count + 1, 0)

# For each layer: use its initial state when the counter wraps to 0 or the
# carried state sums to 0 (e.g. the zeros buffer above); otherwise carry on.
s1_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                      T.cast(T.eq(T.sum(h1_tm1), 0.), 'int32')),
                h1.get_init_state(), h1_tm1)
s2_0 = T.switch(T.or_(T.cast(T.eq(step_count, 0), 'int32'),
                      T.cast(T.eq(T.sum(h2_tm1), 0.), 'int32')),
                h2.get_init_state(), h2_tm1)
示例#29
0
def main(args):
    """Build, train and monitor the ``vrnn_gauss`` model on Blizzard_tbptt.

    ``args`` is a mapping of options; each value used is converted
    explicitly (``int``/``float``) and the two paths are ``~``-expanded.
    Keys read: ``trial``, ``data_path``, ``save_path``, ``monitoring_freq``,
    ``force_saving_freq``, ``reset_freq``, ``epoch``, ``batch_size``,
    ``m_batch_size``, ``x_dim``, ``z_dim``, ``rnn_dim``, ``lr``, ``debug``.

    Returns nothing; the work happens in ``mainloop.run()`` and in the
    ``Picklize`` / ``EarlyStopping`` extensions configured with ``save_path``.
    """

    trial = int(args["trial"])
    pkl_name = "vrnn_gauss_%d" % trial
    channel_name = "valid_nll_upper_bound"

    data_path = os.path.expanduser(args["data_path"])
    save_path = os.path.expanduser(args["save_path"])
    monitoring_freq = int(args["monitoring_freq"])
    force_saving_freq = int(args["force_saving_freq"])
    reset_freq = int(args["reset_freq"])
    epoch = int(args["epoch"])
    batch_size = int(args["batch_size"])
    m_batch_size = int(args["m_batch_size"])
    x_dim = int(args["x_dim"])
    z_dim = int(args["z_dim"])
    rnn_dim = int(args["rnn_dim"])
    lr = float(args["lr"])
    debug = int(args["debug"])

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("trial no. %d" % trial)
    print("batch size %d" % batch_size)
    print("learning rate %f" % lr)
    print("saving pkl file '%s'" % pkl_name)
    print("to the save path '%s'" % save_path)

    # Hidden widths of the inference (q), prior (p_z), generation (p_x) and
    # feature-extractor (x2s / z2s) sub-networks.
    q_z_dim = 500
    p_z_dim = 500
    p_x_dim = 600
    x2s_dim = 600
    z2s_dim = 500
    target_dim = x_dim

    file_name = "blizzard_tbptt"
    normal_params = np.load(data_path + file_name + "_normal.npz")
    X_mean = normal_params["X_mean"]
    X_std = normal_params["X_std"]

    model = Model()
    train_data = Blizzard_tbptt(
        name="train", path=data_path, frame_size=x_dim, file_name=file_name, X_mean=X_mean, X_std=X_std
    )

    valid_data = Blizzard_tbptt(
        name="valid", path=data_path, frame_size=x_dim, file_name=file_name, X_mean=X_mean, X_std=X_std
    )

    x = train_data.theano_vars()
    m_x = valid_data.theano_vars()

    if debug:
        x.tag.test_value = np.zeros((15, batch_size, x_dim), dtype=theano.config.floatX)
        m_x.tag.test_value = np.zeros((15, m_batch_size, x_dim), dtype=theano.config.floatX)

    init_W = InitCell("rand")
    init_U = InitCell("ortho")
    init_b = InitCell("zeros")
    init_b_sig = InitCell("const", mean=0.6)

    # ---- Feature extractor for x_t: four stacked ReLU layers. ----
    x_1 = FullyConnectedLayer(
        name="x_1", parent=["x_t"], parent_dim=[x_dim], nout=x2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    x_2 = FullyConnectedLayer(
        name="x_2", parent=["x_1"], parent_dim=[x2s_dim], nout=x2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    x_3 = FullyConnectedLayer(
        name="x_3", parent=["x_2"], parent_dim=[x2s_dim], nout=x2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    x_4 = FullyConnectedLayer(
        name="x_4", parent=["x_3"], parent_dim=[x2s_dim], nout=x2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    # ---- Feature extractor for z_t: four stacked ReLU layers. ----
    z_1 = FullyConnectedLayer(
        name="z_1", parent=["z_t"], parent_dim=[z_dim], nout=z2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    z_2 = FullyConnectedLayer(
        name="z_2", parent=["z_1"], parent_dim=[z2s_dim], nout=z2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    z_3 = FullyConnectedLayer(
        name="z_3", parent=["z_2"], parent_dim=[z2s_dim], nout=z2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    z_4 = FullyConnectedLayer(
        name="z_4", parent=["z_3"], parent_dim=[z2s_dim], nout=z2s_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    # ---- Recurrent core fed with both feature streams. ----
    rnn = LSTM(
        name="rnn",
        parent=["x_4", "z_4"],
        parent_dim=[x2s_dim, z2s_dim],
        nout=rnn_dim,
        unit="tanh",
        init_W=init_W,
        init_U=init_U,
        init_b=init_b,
    )

    # ---- Inference network q(z_t | x_t, s_tm1). ----
    phi_1 = FullyConnectedLayer(
        name="phi_1",
        parent=["x_4", "s_tm1"],
        parent_dim=[x2s_dim, rnn_dim],
        nout=q_z_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    phi_2 = FullyConnectedLayer(
        name="phi_2", parent=["phi_1"], parent_dim=[q_z_dim], nout=q_z_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    phi_3 = FullyConnectedLayer(
        name="phi_3", parent=["phi_2"], parent_dim=[q_z_dim], nout=q_z_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    phi_4 = FullyConnectedLayer(
        name="phi_4", parent=["phi_3"], parent_dim=[q_z_dim], nout=q_z_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    phi_mu = FullyConnectedLayer(
        name="phi_mu", parent=["phi_4"], parent_dim=[q_z_dim], nout=z_dim, unit="linear", init_W=init_W, init_b=init_b
    )

    phi_sig = FullyConnectedLayer(
        name="phi_sig",
        parent=["phi_4"],
        parent_dim=[q_z_dim],
        nout=z_dim,
        unit="softplus",
        cons=1e-4,
        init_W=init_W,
        init_b=init_b_sig,
    )

    # ---- Prior network p(z_t | s_tm1). ----
    prior_1 = FullyConnectedLayer(
        name="prior_1", parent=["s_tm1"], parent_dim=[rnn_dim], nout=p_z_dim, unit="relu", init_W=init_W, init_b=init_b
    )

    prior_2 = FullyConnectedLayer(
        name="prior_2",
        parent=["prior_1"],
        parent_dim=[p_z_dim],
        nout=p_z_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    prior_3 = FullyConnectedLayer(
        name="prior_3",
        parent=["prior_2"],
        parent_dim=[p_z_dim],
        nout=p_z_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    prior_4 = FullyConnectedLayer(
        name="prior_4",
        parent=["prior_3"],
        parent_dim=[p_z_dim],
        nout=p_z_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    prior_mu = FullyConnectedLayer(
        name="prior_mu",
        parent=["prior_4"],
        parent_dim=[p_z_dim],
        nout=z_dim,
        unit="linear",
        init_W=init_W,
        init_b=init_b,
    )

    prior_sig = FullyConnectedLayer(
        name="prior_sig",
        parent=["prior_4"],
        parent_dim=[p_z_dim],
        nout=z_dim,
        unit="softplus",
        cons=1e-4,
        init_W=init_W,
        init_b=init_b_sig,
    )

    # ---- Generation network p(x_t | z_t, s_tm1). ----
    theta_1 = FullyConnectedLayer(
        name="theta_1",
        parent=["z_4", "s_tm1"],
        parent_dim=[z2s_dim, rnn_dim],
        nout=p_x_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    theta_2 = FullyConnectedLayer(
        name="theta_2",
        parent=["theta_1"],
        parent_dim=[p_x_dim],
        nout=p_x_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    theta_3 = FullyConnectedLayer(
        name="theta_3",
        parent=["theta_2"],
        parent_dim=[p_x_dim],
        nout=p_x_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    theta_4 = FullyConnectedLayer(
        name="theta_4",
        parent=["theta_3"],
        parent_dim=[p_x_dim],
        nout=p_x_dim,
        unit="relu",
        init_W=init_W,
        init_b=init_b,
    )

    theta_mu = FullyConnectedLayer(
        name="theta_mu",
        parent=["theta_4"],
        parent_dim=[p_x_dim],
        nout=target_dim,
        unit="linear",
        init_W=init_W,
        init_b=init_b,
    )

    theta_sig = FullyConnectedLayer(
        name="theta_sig",
        parent=["theta_4"],
        parent_dim=[p_x_dim],
        nout=target_dim,
        unit="softplus",
        cons=1e-4,
        init_W=init_W,
        init_b=init_b_sig,
    )

    nodes = [
        rnn,
        x_1,
        x_2,
        x_3,
        x_4,
        z_1,
        z_2,
        z_3,
        z_4,
        phi_1,
        phi_2,
        phi_3,
        phi_4,
        phi_mu,
        phi_sig,
        prior_1,
        prior_2,
        prior_3,
        prior_4,
        prior_mu,
        prior_sig,
        theta_1,
        theta_2,
        theta_3,
        theta_4,
        theta_mu,
        theta_sig,
    ]

    params = OrderedDict()

    for node in nodes:
        # Call initialize() exactly once per node. The previous code called it
        # twice (once to test for None, once to use the result), which redid
        # the initialization work and discarded the first result.
        node_params = node.initialize()
        if node_params is not None:
            params.update(node_params)

    params = init_tparams(params)

    # Truncated-BPTT state: rnn_tm1 carries the final recurrent state of one
    # training batch into the next; it is replaced by the initial state
    # whenever step_count is a multiple of reset_freq.
    step_count = sharedX(0, name="step_count")
    last_rnn = np.zeros((batch_size, rnn_dim * 2), dtype=theano.config.floatX)
    rnn_tm1 = sharedX(last_rnn, name="rnn_tm1")
    shared_updates = OrderedDict()
    shared_updates[step_count] = step_count + 1

    s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0), rnn.get_init_state(batch_size), rnn_tm1)

    x_1_temp = x_1.fprop([x], params)
    x_2_temp = x_2.fprop([x_1_temp], params)
    x_3_temp = x_3.fprop([x_2_temp], params)
    x_4_temp = x_4.fprop([x_3_temp], params)

    def inner_fn(x_t, s_tm1):
        """One scan step: posterior/prior stats, a sampled z_t, next state."""

        phi_1_t = phi_1.fprop([x_t, s_tm1], params)
        phi_2_t = phi_2.fprop([phi_1_t], params)
        phi_3_t = phi_3.fprop([phi_2_t], params)
        phi_4_t = phi_4.fprop([phi_3_t], params)
        phi_mu_t = phi_mu.fprop([phi_4_t], params)
        phi_sig_t = phi_sig.fprop([phi_4_t], params)

        prior_1_t = prior_1.fprop([s_tm1], params)
        prior_2_t = prior_2.fprop([prior_1_t], params)
        prior_3_t = prior_3.fprop([prior_2_t], params)
        prior_4_t = prior_4.fprop([prior_3_t], params)
        prior_mu_t = prior_mu.fprop([prior_4_t], params)
        prior_sig_t = prior_sig.fprop([prior_4_t], params)

        # z_t is drawn from the inference network's distribution.
        z_t = Gaussian_sample(phi_mu_t, phi_sig_t)

        z_1_t = z_1.fprop([z_t], params)
        z_2_t = z_2.fprop([z_1_t], params)
        z_3_t = z_3.fprop([z_2_t], params)
        z_4_t = z_4.fprop([z_3_t], params)

        s_t = rnn.fprop([[x_t, z_4_t], [s_tm1]], params)

        return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_4_t, z_t

    ((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_4_temp, z_t), updates) = theano.scan(
        fn=inner_fn, sequences=[x_4_temp], outputs_info=[s_0, None, None, None, None, None, None]
    )

    # Attach scan's own updates (e.g. random-state advances from sampling)
    # so they are applied whenever a compiled function uses this graph.
    for k, v in updates.items():
        k.default_update = v

    # Must read the raw scan output *before* s_temp is shifted below.
    shared_updates[rnn_tm1] = s_temp[-1]
    # Shift states by one step so step t is decoded from s_{t-1}.
    s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)
    theta_1_temp = theta_1.fprop([z_4_temp, s_temp], params)
    theta_2_temp = theta_2.fprop([theta_1_temp], params)
    theta_3_temp = theta_3.fprop([theta_2_temp], params)
    theta_4_temp = theta_4.fprop([theta_3_temp], params)
    theta_mu_temp = theta_mu.fprop([theta_4_temp], params)
    theta_sig_temp = theta_sig.fprop([theta_4_temp], params)

    # Training cost. Assuming Gaussian(...) returns a negative log-density
    # (TODO confirm), this is the single-sample estimate
    #     -log p(x|z) + log q(z|x) - log p(z)
    # at the z_t sampled in inner_fn, i.e. reconstruction plus a sampled KL.
    # The analytic KL (KLGaussianGaussian) is used only for validation below.
    # NOTE(review): the 1/5 scaling is unexplained and makes the training
    # cost differ in scale from the monitored validation bound.
    recon = Gaussian(x, theta_mu_temp, theta_sig_temp) - Gaussian(z_t, phi_mu_temp, phi_sig_temp)
    recon += Gaussian(z_t, prior_mu_temp, prior_sig_temp)
    recon_term = recon.mean() / 5.0
    nll_upper_bound = recon_term
    nll_upper_bound.name = "nll_upper_bound"

    # ---- Validation graph: fresh initial state, analytic KL. ----
    m_x_1_temp = x_1.fprop([m_x], params)
    m_x_2_temp = x_2.fprop([m_x_1_temp], params)
    m_x_3_temp = x_3.fprop([m_x_2_temp], params)
    m_x_4_temp = x_4.fprop([m_x_3_temp], params)

    m_s_0 = rnn.get_init_state(m_batch_size)

    (
        (m_s_temp, m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp, m_z_4_temp, m_z_t),
        m_updates,
    ) = theano.scan(fn=inner_fn, sequences=[m_x_4_temp], outputs_info=[m_s_0, None, None, None, None, None, None])

    for k, v in m_updates.items():
        k.default_update = v

    m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
    m_theta_1_temp = theta_1.fprop([m_z_4_temp, m_s_temp], params)
    m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
    m_theta_3_temp = theta_3.fprop([m_theta_2_temp], params)
    m_theta_4_temp = theta_4.fprop([m_theta_3_temp], params)
    m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
    m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

    m_kl_temp = KLGaussianGaussian(m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp)

    m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
    m_recon_term = m_recon.mean()
    m_kl_term = m_kl_temp.mean()
    m_nll_upper_bound = m_recon_term + m_kl_term
    m_nll_upper_bound.name = "nll_upper_bound"
    m_recon_term.name = "recon_term"
    m_kl_term.name = "kl_term"

    def _max_mean_min(var, label):
        """Return (max, mean, min) of ``var`` named ``max_<label>`` etc."""
        stats = (var.max(), var.mean(), var.min())
        for prefix, stat in zip(("max", "mean", "min"), stats):
            stat.name = "%s_%s" % (prefix, label)
        return stats

    max_x, mean_x, min_x = _max_mean_min(m_x, "x")
    max_theta_mu, mean_theta_mu, min_theta_mu = _max_mean_min(m_theta_mu_temp, "theta_mu")
    max_theta_sig, mean_theta_sig, min_theta_sig = _max_mean_min(m_theta_sig_temp, "theta_sig")
    max_phi_sig, mean_phi_sig, min_phi_sig = _max_mean_min(m_phi_sig_temp, "phi_sig")
    max_prior_sig, mean_prior_sig, min_prior_sig = _max_mean_min(m_prior_sig_temp, "prior_sig")

    model.inputs = [x]
    model.params = params
    model.nodes = nodes
    model.set_updates(shared_updates)

    optimizer = Adam(lr=lr)

    # Single source of truth for the monitored quantities: monitor_fn's
    # outputs and Monitoring's ddout must list the same values in order.
    monitor_outputs = [
        m_nll_upper_bound,
        m_recon_term,
        m_kl_term,
        max_phi_sig,
        mean_phi_sig,
        min_phi_sig,
        max_prior_sig,
        mean_prior_sig,
        min_prior_sig,
        max_theta_sig,
        mean_theta_sig,
        min_theta_sig,
        max_x,
        mean_x,
        min_x,
        max_theta_mu,
        mean_theta_mu,
        min_theta_mu,
    ]

    monitor_fn = theano.function(
        inputs=[m_x],
        outputs=list(monitor_outputs),
        on_unused_input="ignore",
    )

    extension = [
        GradientClipping(batch_size=batch_size, check_nan=1),
        EpochCount(epoch),
        Monitoring(
            freq=monitoring_freq,
            monitor_fn=monitor_fn,
            ddout=list(monitor_outputs),
            data=[
                Iterator(train_data, m_batch_size, start=0, end=112640),
                Iterator(valid_data, m_batch_size, start=2040064, end=2152704),
            ],
        ),
        Picklize(freq=monitoring_freq, force_save_freq=force_saving_freq, path=save_path),
        EarlyStopping(freq=monitoring_freq, force_save_freq=force_saving_freq, path=save_path, channel=channel_name),
        WeightNorm(),
    ]

    mainloop = Training(
        name=pkl_name,
        data=KIter(train_data, batch_size, start=0, end=2040064),
        model=model,
        optimizer=optimizer,
        cost=nll_upper_bound,
        outputs=[nll_upper_bound],
        extension=extension,
    )
    mainloop.run()