Example #1
def rnn(images, mask_true, num_layers, num_hidden, filter_size, stride=1,
        seq_length=20, input_length=10, tln=True):

    gen_images = []
    lstm = []
    cell = []
    hidden = []
    shape = images.get_shape().as_list()
    output_channels = shape[-1]

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers-1]
        else:
            num_hidden_in = num_hidden[i-1]
        new_cell = cslstm('lstm_'+str(i+1),
                          filter_size,
                          num_hidden_in,
                          num_hidden[i],
                          shape,
                          tln=tln)
        lstm.append(new_cell)
        cell.append(None)
        hidden.append(None)

    gradient_highway = ghu('highway', filter_size, num_hidden[0], tln=tln)

    mem = None
    z_t = None

    for t in range(seq_length-1):
        reuse = bool(gen_images)
        with tf.variable_scope('predrnn_pp', reuse=reuse):
            if t < input_length:
                inputs = images[:,t]
            else:
                inputs = mask_true[:,t-input_length]*images[:,t] + (1-mask_true[:,t-input_length])*x_gen

            hidden[0], cell[0], mem = lstm[0](inputs, hidden[0], cell[0], mem)
            z_t = gradient_highway(hidden[0], z_t)
            hidden[1], cell[1], mem = lstm[1](z_t, hidden[1], cell[1], mem)

            for i in range(2, num_layers):
                hidden[i], cell[i], mem = lstm[i](hidden[i-1], hidden[i], cell[i], mem)

            x_gen = tf.layers.conv2d(inputs=hidden[num_layers-1],
                                     filters=output_channels,
                                     kernel_size=1,
                                     strides=1,
                                     padding='same',
                                     name="back_to_pixel")
            gen_images.append(x_gen)

    gen_images = tf.stack(gen_images)
    # [batch_size, seq_length, height, width, channels]
    gen_images = tf.transpose(gen_images, [1,0,2,3,4])
    loss = tf.nn.l2_loss(gen_images - images[:,1:])
    #loss += tf.reduce_sum(tf.abs(gen_images - images[:,1:]))
    return [gen_images, loss]
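
A minimal usage sketch for Example #1, assuming the cslstm and ghu constructors from the PredRNN++ codebase are importable and the frames are already reshaped into patches; the placeholder shapes, hidden widths, and optimizer below are illustrative assumptions, not part of the original example.

import tensorflow as tf

# Illustrative shapes: batch 8, 20 frames, a 16x16 patched grid with 16 channels.
images_ph = tf.placeholder(tf.float32, [8, 20, 16, 16, 16])
# One mask per scheduled-sampling step: seq_length - input_length - 1 entries.
mask_ph = tf.placeholder(tf.float32, [8, 9, 16, 16, 16])

gen_images, loss = rnn(images_ph, mask_ph,
                       num_layers=4,
                       num_hidden=[64, 64, 64, 64],
                       filter_size=5,
                       seq_length=20,
                       input_length=10)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)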
Example #2
def __init__(self, batch, img_width, img_height, channels, filters, kernel_size, num_layers,
             seqlength, inputlength, patch_size):
    super(predrnn, self).__init__()
    self.num_layers = num_layers
    self.seqlength = seqlength
    self.inputlength = inputlength
    self.filters = filters
    self.batch = batch
    self.img_width = img_width
    self.img_height = img_height
    self.channels = channels
    self.patch_size = patch_size
    self.kernel_size = kernel_size
    self.layer1 = cslstm(batch, img_width, img_height, channels, filters[0], kernel_size)
    self.layer2 = ghu(filters[0], kernel_size)
    self.layer3 = cslstm(batch, img_width, img_height, channels, filters[1], kernel_size)
    self.layer4 = cslstm(batch, img_width, img_height, channels, filters[2], kernel_size)
    self.layer5 = cslstm(batch, img_width, img_height, channels, filters[3], kernel_size)
    # modify here
    self.layer6 = layers.Conv2D(patch_size*patch_size, kernel_size=kernel_size, padding='same')
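
An instantiation sketch for Example #2, assuming the surrounding predrnn class is a Keras Model/Layer subclass and that the Keras-style cslstm and ghu layers used in its constructor are defined; all argument values are illustrative.

# Hypothetical configuration: 4 stacked layers on 16x16 patched frames.
model = predrnn(batch=8,
                img_width=16,
                img_height=16,
                channels=16,
                filters=[64, 64, 64, 64],
                kernel_size=5,
                num_layers=4,
                seqlength=20,
                inputlength=10,
                patch_size=4)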
Example #3
def rnn(images,
        images_bw,
        mask_true,
        num_layers,
        num_hidden,
        filter_size,
        stride=1,
        seq_length=11,
        input_length=5,
        tln=True):

    gen_images = []
    lstm_fw = []
    lstm_bw = []
    cell_fw = []
    cell_bw = []
    hidden_fw = []
    hidden_bw = []
    shape = images.get_shape().as_list()
    output_channels = shape[-1]
    # "Time machine": store the hidden state and memory of every layer at every
    # time step (the rest of this function assumes num_layers == 4).
    tm_hidden_fw = [[None for i in range(seq_length)] for k in range(4)]
    tm_hidden_bw = [[None for i in range(seq_length)] for k in range(4)]
    tm_mem_fw = [[None for i in range(seq_length)] for k in range(4)]
    tm_mem_bw = [[None for i in range(seq_length)] for k in range(4)]

    ## Create causal lstm unit
    # Create forward causal lstm unit
    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers - 1]
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell = cslstm('lstm_fw_' + str(i + 1),
                          filter_size,
                          num_hidden_in,
                          num_hidden[i],
                          shape,
                          tln=tln)
        lstm_fw.append(new_cell)
        cell_fw.append(None)
        hidden_fw.append(None)
    # Create backward causal lstm unit
    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers - 1]
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell = cslstm('lstm_bw_' + str(i + 1),
                          filter_size,
                          num_hidden_in,
                          num_hidden[i],
                          shape,
                          tln=tln)
        lstm_bw.append(new_cell)
        cell_bw.append(None)
        hidden_bw.append(None)

    ## Create GHU unit
    # Create forward and backward GHU unit
    gradient_highway_fw = ghu('highway_fw',
                              filter_size,
                              num_hidden[0],
                              tln=tln)
    gradient_highway_bw = ghu('highway_bw',
                              filter_size,
                              num_hidden[0],
                              tln=tln)

    ## Create lstm memory output and GHU output
    # Create forward memory output and GHU output
    mem_fw = None
    z_t_fw = None
    # Create backward memory output and GHU output
    mem_bw = None
    z_t_bw = None

    print("seq_length:{}".format(seq_length))

    # Layer 1
    for t in range(seq_length):
        print("Layer 1")
        print("t:{}".format(t))

        #         reuse = bool(gen_images)
        # Layer 1 Forward
        with tf.variable_scope('bi_cslstm_l1', reuse=tf.AUTO_REUSE):
            # according to mask_true, replace masked pixels with random noise
            inputs_fw = mask_true[:, t] * images[:, t] + (
                1 - mask_true[:, t]) * sample_Z((1 - mask_true[:, t]))

            tf.summary.image('masktrue_fw',
                             reshape_patch_back_gen(mask_true[:, t], 4), 11)
            tf.summary.image('input_fw', reshape_patch_back_gen(inputs_fw, 4),
                             11)

            hidden_fw[0], cell_fw[0], mem_fw = lstm_fw[0](inputs_fw,
                                                          hidden_fw[0],
                                                          cell_fw[0], mem_fw)
            z_t_fw = gradient_highway_fw(hidden_fw[0], z_t_fw)

            tm_hidden_fw[0][t] = z_t_fw
            tm_mem_fw[0][t] = mem_fw
        # Layer 1 Backward
        with tf.variable_scope('bi_cslstm_l1', reuse=tf.AUTO_REUSE):
            # according to mask_true, replace masked pixels with random noise
            inputs_bw = mask_true[:, seq_length - t - 1] * images_bw[:, t] + (
                1 - mask_true[:, seq_length - t - 1]) * sample_Z(
                    (1 - mask_true[:, seq_length - t - 1]))

            hidden_bw[0], cell_bw[0], mem_bw = lstm_bw[0](inputs_bw,
                                                          hidden_bw[0],
                                                          cell_bw[0], mem_bw)
            z_t_bw = gradient_highway_bw(hidden_bw[0], z_t_bw)

            tm_hidden_bw[0][t] = z_t_bw
            tm_mem_bw[0][t] = mem_bw

    # Layer 2 only has seq_length // 2 time steps (5 when seq_length == 11)
    hiddenConcatConv_l2 = [None for i in range(seq_length // 2)]
    memConcatConv_l2 = [None for i in range(seq_length // 2)]
    for t in range(seq_length // 2):
        print("Layer 2")
        print("t:{}".format(t))

        # Merge forward and backward output from layer 1
        with tf.variable_scope('merge_l2', reuse=tf.AUTO_REUSE):
            if t < (seq_length // 2 // 2):
                hiddenConcat = tf.concat([
                    tm_hidden_fw[0][t * 2],
                    tm_hidden_bw[0][(seq_length // 2 - t - 1) * 2]
                ],
                                         axis=-1)
                hiddenConcatConv_l2[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[0][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_h_merge_l2')
                memConcat = tf.concat([
                    tm_mem_fw[0][t * 2],
                    tm_mem_bw[0][(seq_length // 2 - t - 1) * 2]
                ],
                                      axis=-1)
                memConcatConv_l2[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[0][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_m_merge_l2')
            else:
                hiddenConcat = tf.concat([
                    tm_hidden_bw[0][(seq_length // 2 - t - 1) * 2],
                    tm_hidden_fw[0][t * 2]
                ],
                                         axis=-1)
                hiddenConcatConv_l2[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[0][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_h_merge_l2')
                memConcat = tf.concat([
                    tm_mem_bw[0][(seq_length // 2 - t - 1) * 2],
                    tm_mem_fw[0][t * 2]
                ],
                                      axis=-1)
                memConcatConv_l2[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[0][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_m_merge_l2')

    for t in range(seq_length // 2):
        # Layer 2 Forward
        with tf.variable_scope('bi_cslstm_l2', reuse=tf.AUTO_REUSE):
            hidden_fw[1], cell_fw[1], mem_fw = lstm_fw[1](
                hiddenConcatConv_l2[t], hidden_fw[1], cell_fw[1],
                memConcatConv_l2[t])

            tm_hidden_fw[1][t] = hidden_fw[1]
            tm_mem_fw[1][t] = mem_fw
        # Layer 2 Backward
        with tf.variable_scope('bi_cslstm_l2', reuse=tf.AUTO_REUSE):
            hidden_bw[1], cell_bw[1], mem_bw = lstm_bw[1](
                hiddenConcatConv_l2[seq_length // 2 - t - 1], hidden_bw[1],
                cell_bw[1], memConcatConv_l2[seq_length // 2 - t - 1])
            tm_hidden_bw[1][t] = hidden_bw[1]
            tm_mem_bw[1][t] = mem_bw

    # Layer 3 only has seq_length // 2 time steps (5 when seq_length == 11)
    hiddenConcatConv_l3 = [None for i in range(seq_length // 2)]
    memConcatConv_l3 = [None for i in range(seq_length // 2)]
    for t in range(seq_length // 2):
        print("Layer 3")
        print("t:{}".format(t))

        # Merge forward and backward output from layer 2
        with tf.variable_scope('merge_l3', reuse=tf.AUTO_REUSE):
            if t < (seq_length // 2 // 2):
                hiddenConcat = tf.concat([
                    tm_hidden_fw[1][t],
                    tm_hidden_bw[1][seq_length // 2 - t - 1]
                ],
                                         axis=-1)
                hiddenConcatConv_l3[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[1][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_h_merge_l3')
                memConcat = tf.concat(
                    [tm_mem_fw[1][t], tm_mem_bw[1][seq_length // 2 - t - 1]],
                    axis=-1)
                memConcatConv_l3[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[1][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_m_merge_l3')
            else:
                hiddenConcat = tf.concat([
                    tm_hidden_bw[1][seq_length // 2 - t - 1],
                    tm_hidden_fw[1][t]
                ],
                                         axis=-1)
                hiddenConcatConv_l3[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[1][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_h_merge_l3')
                memConcat = tf.concat(
                    [tm_mem_bw[1][seq_length // 2 - t - 1], tm_mem_fw[1][t]],
                    axis=-1)
                memConcatConv_l3[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[1][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_m_merge_l3')

    for t in range(seq_length // 2):
        # Layer 3 Forward
        with tf.variable_scope('bi_cslstm_l3', reuse=tf.AUTO_REUSE):
            hidden_fw[2], cell_fw[2], mem_fw = lstm_fw[2](
                hiddenConcatConv_l3[t], hidden_fw[2], cell_fw[2],
                memConcatConv_l3[t])

            tm_hidden_fw[2][t] = hidden_fw[2]
            tm_mem_fw[2][t] = mem_fw
        # Layer 3 Backward
        with tf.variable_scope('bi_cslstm_l3', reuse=tf.AUTO_REUSE):
            hidden_bw[2], cell_bw[2], mem_bw = lstm_bw[2](
                hiddenConcatConv_l3[seq_length // 2 - t - 1], hidden_bw[2],
                cell_bw[2], memConcatConv_l3[seq_length // 2 - t - 1])
            tm_hidden_bw[2][t] = hidden_bw[2]
            tm_mem_bw[2][t] = mem_bw

    # Layer 4 only has seq_length // 2 time steps (5 when seq_length == 11)
    hiddenConcatConv_l4 = [None for i in range(seq_length // 2)]
    memConcatConv_l4 = [None for i in range(seq_length // 2)]
    for t in range(seq_length // 2):
        print("Layer 4")
        print("t:{}".format(t))

        # Merge forward and backward output from layer 3
        with tf.variable_scope('merge_l4', reuse=tf.AUTO_REUSE):
            if t < (seq_length // 2 // 2):
                hiddenConcat = tf.concat([
                    tm_hidden_fw[2][t],
                    tm_hidden_bw[2][seq_length // 2 - t - 1]
                ],
                                         axis=-1)
                hiddenConcatConv_l4[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[2][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_h_merge_l4')
                memConcat = tf.concat(
                    [tm_mem_fw[2][t], tm_mem_bw[2][seq_length // 2 - t - 1]],
                    axis=-1)
                memConcatConv_l4[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[2][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_m_merge_l4')
            else:
                hiddenConcat = tf.concat([
                    tm_hidden_bw[2][seq_length // 2 - t - 1],
                    tm_hidden_fw[2][t]
                ],
                                         axis=-1)
                hiddenConcatConv_l4[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[2][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_h_merge_l4')
                memConcat = tf.concat(
                    [tm_mem_bw[2][seq_length // 2 - t - 1], tm_mem_fw[2][t]],
                    axis=-1)
                memConcatConv_l4[t] = tf.layers.conv2d(
                    inputs=memConcat,
                    filters=tm_mem_fw[2][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_m_merge_l4')

    for t in range(seq_length // 2):
        # Layer 4 Forward
        with tf.variable_scope('bi_cslstm_l4', reuse=tf.AUTO_REUSE):
            hidden_fw[3], cell_fw[3], mem_fw = lstm_fw[3](
                hiddenConcatConv_l4[t], hidden_fw[3], cell_fw[3],
                memConcatConv_l4[t])

            tm_hidden_fw[3][t] = hidden_fw[3]
            tm_mem_fw[3][t] = mem_fw
        # Layer 4 Backward
        with tf.variable_scope('bi_cslstm_l4', reuse=tf.AUTO_REUSE):
            hidden_bw[3], cell_bw[3], mem_bw = lstm_bw[3](
                hiddenConcatConv_l4[seq_length // 2 - t - 1], hidden_bw[3],
                cell_bw[3], memConcatConv_l4[seq_length // 2 - t - 1])
            tm_hidden_bw[3][t] = hidden_bw[3]
            tm_mem_bw[3][t] = mem_bw

    # generate output image
    hiddenConcatConv = [None for i in range(seq_length // 2)]
    x_gen = [None for i in range(seq_length // 2)]
    for t in range(seq_length // 2):
        with tf.variable_scope('bi_merge', reuse=tf.AUTO_REUSE):
            if t < (seq_length // 2 // 2):
                hiddenConcat = tf.concat([
                    tm_hidden_fw[3][t],
                    tm_hidden_bw[3][seq_length // 2 - t - 1]
                ],
                                         axis=-1)
                hiddenConcatConv[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[3][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='F_h_merge')
            else:
                hiddenConcat = tf.concat([
                    tm_hidden_bw[3][seq_length // 2 - t - 1],
                    tm_hidden_fw[3][t]
                ],
                                         axis=-1)
                hiddenConcatConv[t] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_bw[3][t].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='B_h_merge')
    for t in range(seq_length // 2):
        with tf.variable_scope('generate', reuse=tf.AUTO_REUSE):
            x_gen[t] = tf.layers.conv2d(inputs=hiddenConcatConv[t],
                                        filters=output_channels,
                                        kernel_size=1,
                                        strides=1,
                                        padding='same',
                                        name='bi_back_to_pixel')
            gen_images.append(x_gen[t])
            print("generate t: %d" % t)
            tf.summary.image('x_gen', reshape_patch_back_gen(x_gen[t], 4),
                             seq_length)

    gen_images = tf.stack(gen_images)
    # [batch_size, seq_length, height, width, channels]
    gen_images = tf.transpose(gen_images, [1, 0, 2, 3, 4])

    gt_images = [images[:, i * 2 + 1] for i in range(seq_length // 2)]
    gt_images = tf.stack(gt_images)
    gt_images = tf.transpose(gt_images, [1, 0, 2, 3, 4])

    l2Loss = tf.nn.l2_loss(gen_images - gt_images)
    l1Loss = tf.losses.absolute_difference(gen_images, gt_images)
    gdlLoss = cal_gdl(gen_images, gt_images)
    loss = l2Loss + l1Loss + gdlLoss

    tf.summary.scalar('l2_Loss', l2Loss)
    tf.summary.scalar('l1_Loss', l1Loss)
    tf.summary.scalar('gdl_Loss', gdlLoss)
    tf.summary.scalar('loss', loss)

    return [gen_images, loss, images, images_bw]
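
Example #3 calls cal_gdl, which is not shown in these examples. Below is a minimal sketch of a gradient difference loss (GDL) over the same 5-D [batch, time, height, width, channels] tensors, assuming the standard first-order formulation; the exact definition used by the original code may differ.

def cal_gdl(gen, gt):
    # Gradient difference loss: penalize mismatches between the absolute
    # spatial gradients of the generated and ground-truth frames.
    def spatial_grads(x):
        dh = x[:, :, 1:, :, :] - x[:, :, :-1, :, :]   # vertical differences
        dw = x[:, :, :, 1:, :] - x[:, :, :, :-1, :]   # horizontal differences
        return dh, dw

    gen_dh, gen_dw = spatial_grads(gen)
    gt_dh, gt_dw = spatial_grads(gt)
    return (tf.reduce_mean(tf.abs(tf.abs(gen_dh) - tf.abs(gt_dh))) +
            tf.reduce_mean(tf.abs(tf.abs(gen_dw) - tf.abs(gt_dw))))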
Example #4
def rnn(images,
        mask_true,
        num_layers,
        num_hidden,
        filter_size,
        stride=1,
        seq_length=20,
        input_length=10,
        tln=True,
        batch_size=None):

    gen_images = []
    lstm = []
    cell = []
    hidden = []
    shape = images.get_shape().as_list()
    output_channels = shape[-1]

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers - 1]
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell = cslstm('lstm_' + str(i + 1),
                          filter_size,
                          num_hidden_in,
                          num_hidden[i],
                          shape,
                          tln=tln,
                          batch_size=batch_size)
        lstm.append(new_cell)
        cell.append(None)
        hidden.append(None)

    gradient_highway = ghu('highway', filter_size, num_hidden[0], tln=tln)

    mem = None
    z_t = None

    for t in range(seq_length - 1):
        reuse = bool(gen_images)
        with tf.variable_scope('predrnn_pp', reuse=reuse):
            if t < input_length:
                inputs = images[:, t, ...]
            else:
                inputs = mask_true[:, t - input_length,
                                   ...] * images[:, t, ...] + (
                                       1 - mask_true[:, t - input_length,
                                                     ...]) * x_gen

            hidden[0], cell[0], mem = lstm[0](inputs, hidden[0], cell[0], mem)
            z_t = gradient_highway(hidden[0], z_t, batch_size)
            hidden[1], cell[1], mem = lstm[1](z_t, hidden[1], cell[1], mem)

            # The hidden output here is the result of tanh * tanh, which falls into the range [-1, 1]
            for i in range(2, num_layers):
                hidden[i], cell[i], mem = lstm[i](hidden[i - 1], hidden[i],
                                                  cell[i], mem)

            x_gen = tf.layers.conv2d(
                inputs=hidden[num_layers - 1],
                filters=output_channels,
                kernel_size=1,
                strides=1,
                # activation=tf.nn.tanh,
                padding='same',
                name="back_to_pixel")

            # squash
            x_gen = tf.reshape(x_gen, [
                -1, FLAGS.img_height, FLAGS.img_width,
                FLAGS.patch_size_height * FLAGS.patch_size_width,
                FLAGS.img_channel
            ])
            x_gen = squash(x_gen, dim=-1)  # makes a unit vector

            x_gen = tf.reshape(x_gen, [
                -1, FLAGS.img_height, FLAGS.img_width,
                FLAGS.patch_size_height * FLAGS.patch_size_width *
                FLAGS.img_channel
            ])
            gen_images.append(x_gen)

    gen_images = tf.stack(gen_images, axis=1)

    gt_images = images[:, 1:]
    gt_images = tf.reshape(gt_images, [
        -1, FLAGS.seq_length - 1, FLAGS.img_height, FLAGS.img_width,
        FLAGS.patch_size_height * FLAGS.patch_size_width, FLAGS.img_channel
    ])
    gen_images = tf.reshape(gen_images, [
        -1, FLAGS.seq_length - 1, FLAGS.img_height, FLAGS.img_width,
        FLAGS.patch_size_height * FLAGS.patch_size_width, FLAGS.img_channel
    ])
    # loss on the magnitude of speed
    gt_speed = tf.sqrt(gt_images[..., 0]**2 + gt_images[..., 1]**2)
    gen_speed = tf.sqrt(gen_images[..., 0]**2 + gen_images[..., 1]**2)
    if FLAGS.loss_nan == 'nan':
        loss = masked_mse_tf(gen_speed, gt_speed, null_val=np.nan)
        loss += masked_mse_tf(gen_images, gt_images, null_val=np.nan)
    else:
        loss = masked_mse_tf(gen_speed, gt_speed, null_val=0.0)
        loss += masked_mse_tf(gen_images, gt_images, null_val=0.0)

    return [gen_images, loss]
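
Example #4 calls a squash helper that is not shown. A minimal sketch of the capsule-network style squash nonlinearity is given below, assuming that is what the helper implements (it rescales vectors along dim toward unit length while preserving direction).

def squash(x, dim=-1, epsilon=1e-9):
    # Capsule-style squash: short vectors shrink toward zero, long vectors
    # approach unit length. Assumed behaviour of the undefined helper above.
    squared_norm = tf.reduce_sum(tf.square(x), axis=dim, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + epsilon)
    return x * scale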
Example #5
def rnn_inference(images,
                  num_layers,
                  num_hidden,
                  filter_size,
                  stride=1,
                  pred_length=11,
                  input_length=1,
                  tln=True):

    lstm = []
    cell = []
    hidden = []
    shape = images.get_shape().as_list()
    output_channels = shape[-1]
    # input_length = tf.shape(images)[1]

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers - 1]
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell = cslstm('lstm_' + str(i + 1),
                          filter_size,
                          num_hidden_in,
                          num_hidden[i],
                          shape,
                          tln=tln)
        lstm.append(new_cell)

        with tf.variable_scope('states_layer%d' % i, reuse=False) as scope:
            try:
                c = tf.get_variable('c', trainable=False)
                cell.append(c)
            except ValueError:
                cell.append(None)
            try:
                h = tf.get_variable('h', trainable=False)
                hidden.append(h)
            except ValueError:
                hidden.append(None)

    gradient_highway = ghu('highway', filter_size, num_hidden[0], tln=tln)

    with tf.variable_scope('states_global', reuse=False) as scope:
        try:
            mem = [tf.get_variable('mem', trainable=False)]
        except ValueError:
            mem = [None]
        try:
            z_t = [tf.get_variable('z_t', trainable=False)]
        except ValueError:
            z_t = [None]
    x_gen = [None]

    def step_forward(inputs):
        with tf.variable_scope('predrnn_pp', reuse=tf.AUTO_REUSE):
            hidden[0], cell[0], mem[0] = lstm[0](inputs, hidden[0], cell[0],
                                                 mem[0])
            z_t[0] = gradient_highway(hidden[0], z_t[0])
            hidden[1], cell[1], mem[0] = lstm[1](z_t[0], hidden[1], cell[1],
                                                 mem[0])

            for i in range(2, num_layers):
                hidden[i], cell[i], mem[0] = lstm[i](hidden[i - 1], hidden[i],
                                                     cell[i], mem[0])

            x_gen[0] = tf.layers.conv2d(inputs=hidden[num_layers - 1],
                                        filters=output_channels,
                                        kernel_size=1,
                                        strides=1,
                                        padding='same',
                                        name="back_to_pixel")

    '''
    t = tf.constant(0)
    cond = lambda t: tf.less(t, input_length)
    def body(t):
        step_forward(images[:, t])
        t += 1
    tf.while_loop(cond, body, [t])
    '''
    for t in range(input_length):
        step_forward(images[:, t])

    # Persist the final states into non-trainable variables. Note that the
    # assign ops created below are not attached to any control dependency, so
    # they only run if fetched explicitly.
    for i in range(num_layers):
        with tf.variable_scope('states_layer%d' % i,
                               reuse=tf.AUTO_REUSE) as scope:
            h = tf.get_variable('h', hidden[i].get_shape(), trainable=False)
            h.assign(hidden[i])
            c = tf.get_variable('c', cell[i].get_shape(), trainable=False)
            c.assign(cell[i])
        with tf.variable_scope('states_global', reuse=tf.AUTO_REUSE) as scope:
            m = tf.get_variable('mem', mem[0].get_shape(), trainable=False)
            m.assign(mem[0])
            z = tf.get_variable('z_t', z_t[0].get_shape(), trainable=False)
            z.assign(z_t[0])

    for i in range(pred_length - 1):
        step_forward(x_gen[0])

    return x_gen[0]
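
# Usage sketch for rnn_inference (hypothetical shapes; assumes the cslstm and
# ghu definitions used by these examples are in scope):
#
#     frames = tf.placeholder(tf.float32, [1, 1, 16, 16, 16])  # [batch, input_length, h, w, c]
#     next_frame = rnn_inference(frames, num_layers=4,
#                                num_hidden=[64, 64, 64, 64],
#                                filter_size=5, pred_length=11, input_length=1)
#
# Only the frame generated on the last of the pred_length - 1 rollout steps is
# returned.
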
def rnn(images,
        images_bw,
        mask_true,
        num_layers,
        num_hidden,
        filter_size,
        stride=1,
        seq_length=20,
        input_length=5,
        tln=True):
    ###'num_hidden', '32,16,16,16', 4 layers

    # inp_images = []
    gen_images = []
    lstm_fw = []
    lstm_bw = []
    lstm_bi = []
    cell_fw = []
    cell_bw = []
    cell_bi = []
    hidden_fw = []
    hidden_bw = []
    hidden_bi = []
    # shapeConcat = tf.concat([images, mask_true],axis=-1)
    # shape = shapeConcat.get_shape().as_list()
    # output_channels = shape[-1]/2
    shape = images.get_shape().as_list()
    output_channels = shape[-1]
    # Time Machine
    tm_hidden_fw = [[None for i in range(seq_length - 2)] for k in range(4)]
    tm_hidden_bw = [[None for i in range(seq_length - 2)] for k in range(4)]
    tm_mem_fw = [[None for i in range(seq_length - 2)] for k in range(4)]
    tm_mem_bw = [[None for i in range(seq_length - 2)] for k in range(4)]
    # loss_gdl = GDL()

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[
                num_layers - 1]  ### [4-1]=3, i.e. num_hidden[-1], the width of the last layer
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell_fw = bjstlstm('lstm_fw_' + str(i + 1),
                               filter_size,
                               num_hidden_in,
                               num_hidden[i],
                               shape,
                               tln=tln)
        lstm_fw.append(new_cell_fw)
        cell_fw.append(None)
        hidden_fw.append(None)

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[
                num_layers - 1]  ### [4-1]=3, i.e. num_hidden[-1], the width of the last layer
        else:
            num_hidden_in = num_hidden[i - 1]
        new_cell_bw = bjstlstm('lstm_bw_' + str(i + 1),
                               filter_size,
                               num_hidden_in,
                               num_hidden[i],
                               shape,
                               tln=tln)
        lstm_bw.append(new_cell_bw)
        cell_bw.append(None)
        hidden_bw.append(None)

    # for i in xrange(num_layers-2, num_layers-1):
    #     num_hidden_in = num_hidden[i-1]
    #     new_bicell = bjstlstm('blstm_'+str(i+1),
    #                       filter_size,
    #                       num_hidden_in,
    #                       num_hidden[i],
    #                       shape,
    #                       tln=tln)
    #     lstm_bi.append(new_bicell)
    #     cell_bi.append(None)
    #     hidden_bi.append(None)

    gradient_highway_fw = ghu('highway_fw',
                              filter_size,
                              num_hidden[0],
                              tln=tln)
    gradient_highway_bw = ghu('highway_bw',
                              filter_size,
                              num_hidden[0],
                              tln=tln)

    mem_fw = None
    z_t_fw = None
    mem_bw = None
    z_t_bw = None

    for t_layer1 in range(seq_length - 2):  ### t_layer1 = time
        with tf.variable_scope('b_jstlstm_l1', reuse=tf.AUTO_REUSE):
            inputs_fw = mask_true[:, t_layer1] * images[:, t_layer1] + (
                1 - mask_true[:, t_layer1]) * sample_Z(
                    (1 - mask_true[:, t_layer1]))
            # inputs_fwConcat = tf.concat([inputs_fw, mask_true[:,t_layer1]],axis=-1)
            tf.summary.image('masktrue_fw',
                             reshape_patch_back_gen(mask_true[:, t_layer1], 4),
                             29)
            tf.summary.image('input_fw', reshape_patch_back_gen(inputs_fw, 4),
                             29)

            hidden_fw[0], cell_fw[0], mem_fw = lstm_fw[0](inputs_fw,
                                                          hidden_fw[0],
                                                          cell_fw[0], mem_fw)

            z_t_fw = gradient_highway_fw(hidden_fw[0], z_t_fw)

            tm_hidden_fw[0][t_layer1] = z_t_fw
            tm_mem_fw[0][t_layer1] = mem_fw

        with tf.variable_scope('b_jstlstm_l1', reuse=tf.AUTO_REUSE):
            inputs_bw = mask_true[:, seq_length - 1 -
                                  t_layer1] * images_bw[:, t_layer1] + (
                                      1 -
                                      mask_true[:, seq_length - 1 - t_layer1]
                                  ) * sample_Z((1 - mask_true[:, seq_length -
                                                              1 - t_layer1]))

            hidden_bw[0], cell_bw[0], mem_bw = lstm_bw[0](inputs_bw,
                                                          hidden_bw[0],
                                                          cell_bw[0], mem_bw)
            z_t_bw = gradient_highway_bw(hidden_bw[0], z_t_bw)

            tm_hidden_bw[0][t_layer1] = z_t_bw
            tm_mem_bw[0][t_layer1] = mem_bw

    hiddenConcatConv_l2 = [None for i in range(seq_length - 4)]
    memConcatConv_l2 = [None for i in range(seq_length - 4)]
    for t_layer2 in range(seq_length - 4):  ### t_layer2 = time
        with tf.variable_scope('merge_l2', reuse=tf.AUTO_REUSE):
            if t_layer2 < int((seq_length - 4) / 2):
                hiddenConcat_bw = tf.concat([
                    tm_hidden_bw[0][t_layer2], tm_hidden_fw[0][-1 - t_layer2]
                ],
                                            axis=-1)
                hiddenConcatConv_l2[-1 - t_layer2] = tf.layers.conv2d(
                    inputs=hiddenConcat_bw,
                    filters=tm_hidden_fw[0][t_layer2].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_h_merge_l2")
                memConcat_bw = tf.concat(
                    [tm_mem_bw[0][t_layer2], tm_mem_fw[0][-1 - t_layer2]],
                    axis=-1)
                memConcatConv_l2[-1 - t_layer2] = tf.layers.conv2d(
                    inputs=memConcat_bw,
                    filters=tm_mem_fw[0][t_layer2].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_m_merge_l2")

                hiddenConcat_fw = tf.concat([
                    tm_hidden_fw[0][t_layer2], tm_hidden_bw[0][-1 - t_layer2]
                ],
                                            axis=-1)
                hiddenConcatConv_l2[t_layer2] = tf.layers.conv2d(
                    inputs=hiddenConcat_fw,
                    filters=tm_hidden_fw[0][t_layer2].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_h_merge_l2")
                memConcat_fw = tf.concat(
                    [tm_mem_fw[0][t_layer2], tm_mem_bw[0][-1 - t_layer2]],
                    axis=-1)
                memConcatConv_l2[t_layer2] = tf.layers.conv2d(
                    inputs=memConcat_fw,
                    filters=tm_mem_fw[0][t_layer2].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_m_merge_l2")

        with tf.variable_scope('b_jstlstm_l2', reuse=tf.AUTO_REUSE):
            hidden_bw[1], cell_bw[1], mem_bw = lstm_bw[1](
                hiddenConcatConv_l2[-1 - t_layer2], hidden_bw[1], cell_bw[1],
                memConcatConv_l2[-1 - t_layer2])

            tm_hidden_bw[1][t_layer2 + 1] = hidden_bw[1]
            tm_mem_bw[1][t_layer2 + 1] = mem_bw

        with tf.variable_scope('b_jstlstm_l2', reuse=tf.AUTO_REUSE):
            hidden_fw[1], cell_fw[1], mem_fw = lstm_fw[1](
                hiddenConcatConv_l2[t_layer2], hidden_fw[1], cell_fw[1],
                memConcatConv_l2[t_layer2])

            tm_hidden_fw[1][t_layer2 + 1] = hidden_fw[1]
            tm_mem_fw[1][t_layer2 + 1] = mem_fw

    hiddenConcatConv_l3 = [None for i in range(seq_length - 6)]
    memConcatConv_l3 = [None for i in range(seq_length - 6)]
    for t_layer3 in range(seq_length - 6):  ### t_layer3 = time
        with tf.variable_scope('merge_l3', reuse=tf.AUTO_REUSE):
            if t_layer3 < int((seq_length - 6) / 2):
                hiddenConcat_fw = tf.concat([
                    tm_hidden_fw[1][(t_layer3 + 1)],
                    tm_hidden_bw[1][-1 - (t_layer3 + 1)]
                ],
                                            axis=-1)
                hiddenConcatConv_l3[t_layer3] = tf.layers.conv2d(
                    inputs=hiddenConcat_fw,
                    filters=tm_hidden_fw[1][(t_layer3 + 1)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_h_merge_l3")
                memConcat_fw = tf.concat([
                    tm_mem_fw[1][(t_layer3 + 1)], tm_mem_bw[1][-1 -
                                                               (t_layer3 + 1)]
                ],
                                         axis=-1)
                memConcatConv_l3[t_layer3] = tf.layers.conv2d(
                    inputs=memConcat_fw,
                    filters=tm_mem_fw[1][(t_layer3 + 1)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_m_merge_l3")

                hiddenConcat_bw = tf.concat([
                    tm_hidden_bw[1][(t_layer3 + 1)],
                    tm_hidden_fw[1][-1 - (t_layer3 + 1)]
                ],
                                            axis=-1)
                hiddenConcatConv_l3[-1 - t_layer3] = tf.layers.conv2d(
                    inputs=hiddenConcat_bw,
                    filters=tm_hidden_fw[1][(t_layer3 + 1)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_h_merge_l3")
                memConcat_bw = tf.concat([
                    tm_mem_bw[1][(t_layer3 + 1)], tm_mem_fw[1][-1 -
                                                               (t_layer3 + 1)]
                ],
                                         axis=-1)
                memConcatConv_l3[-1 - t_layer3] = tf.layers.conv2d(
                    inputs=memConcat_bw,
                    filters=tm_mem_fw[1][(t_layer3 + 1)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_m_merge_l3")

        with tf.variable_scope('b_jstlstm_l3', reuse=tf.AUTO_REUSE):
            hidden_fw[2], cell_fw[2], mem_fw = lstm_fw[2](
                hiddenConcatConv_l3[t_layer3], hidden_fw[2], cell_fw[2],
                memConcatConv_l3[t_layer3])

            tm_hidden_fw[2][t_layer3 + 2] = hidden_fw[2]
            tm_mem_fw[2][t_layer3 + 2] = mem_fw

        with tf.variable_scope('b_jstlstm_l3', reuse=tf.AUTO_REUSE):

            hidden_bw[2], cell_bw[2], mem_bw = lstm_bw[2](
                hiddenConcatConv_l3[-1 - t_layer3], hidden_bw[2], cell_bw[2],
                memConcatConv_l3[-1 - t_layer3])

            tm_hidden_bw[2][t_layer3 + 2] = hidden_bw[2]
            tm_mem_bw[2][t_layer3 + 2] = mem_bw

    hiddenConcatConv_l4 = [None for i in range(seq_length - 8)]
    memConcatConv_l4 = [None for i in range(seq_length - 8)]
    hiddenConcatConv = [None for i in range(seq_length - 8)]
    for t_layer4 in range(seq_length - 8):  ### t_layer4 = time
        with tf.variable_scope('merge_l4', reuse=tf.AUTO_REUSE):
            if t_layer4 < int((seq_length - 8) / 2):
                hiddenConcat_bw = tf.concat([
                    tm_hidden_bw[2][(t_layer4 + 2)],
                    tm_hidden_fw[2][-1 - (t_layer4 + 2)]
                ],
                                            axis=-1)
                hiddenConcatConv_l4[-1 - t_layer4] = tf.layers.conv2d(
                    inputs=hiddenConcat_bw,
                    filters=tm_hidden_fw[2][(t_layer4 + 2)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_h_merge_l4")
                memConcat_bw = tf.concat([
                    tm_mem_bw[2][(t_layer4 + 2)], tm_mem_fw[2][-1 -
                                                               (t_layer4 + 2)]
                ],
                                         axis=-1)
                memConcatConv_l4[-1 - t_layer4] = tf.layers.conv2d(
                    inputs=memConcat_bw,
                    filters=tm_mem_fw[2][(t_layer4 + 2)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_m_merge_l4")

                hiddenConcat_fw = tf.concat([
                    tm_hidden_fw[2][(t_layer4 + 2)],
                    tm_hidden_bw[2][-1 - (t_layer4 + 2)]
                ],
                                            axis=-1)
                hiddenConcatConv_l4[t_layer4] = tf.layers.conv2d(
                    inputs=hiddenConcat_fw,
                    filters=tm_hidden_fw[2][(t_layer4 + 2)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_h_merge_l4")
                memConcat_fw = tf.concat([
                    tm_mem_fw[2][(t_layer4 + 2)], tm_mem_bw[2][-1 -
                                                               (t_layer4 + 2)]
                ],
                                         axis=-1)
                memConcatConv_l4[t_layer4] = tf.layers.conv2d(
                    inputs=memConcat_fw,
                    filters=tm_mem_fw[2][(t_layer4 + 2)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_m_merge_l4")

        with tf.variable_scope('b_jstlstm_l4', reuse=tf.AUTO_REUSE):
            hidden_bw[3], cell_bw[3], mem_bw = lstm_bw[3](
                hiddenConcatConv_l4[-1 - t_layer4], hidden_bw[3], cell_bw[3],
                memConcatConv_l4[-1 - t_layer4])

            tm_hidden_bw[3][t_layer4 + 3] = hidden_bw[3]
            tm_mem_bw[3][t_layer4 + 3] = mem_bw

        with tf.variable_scope('b_jstlstm_l4', reuse=tf.AUTO_REUSE):
            hidden_fw[3], cell_fw[3], mem_fw = lstm_fw[3](
                hiddenConcatConv_l4[t_layer4], hidden_fw[3], cell_fw[3],
                memConcatConv_l4[t_layer4])

            tm_hidden_fw[3][t_layer4 + 3] = hidden_fw[3]
            tm_mem_fw[3][t_layer4 + 3] = mem_fw

    x_gen = [None for i in range(seq_length - 8)]
    for t_bi in range(seq_length - 8):  ### t_bi = time
        with tf.variable_scope('bi_merge', reuse=tf.AUTO_REUSE):
            if t_bi < int((seq_length - 8) / 2):
                hiddenConcat = tf.concat([
                    tm_hidden_fw[3][(t_bi + 3)], tm_hidden_bw[3][-1 -
                                                                 (t_bi + 3)]
                ],
                                         axis=-1)
                hiddenConcatConv[t_bi] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[3][(t_bi + 3)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="F_h_merge")
                hiddenConcat = tf.concat([
                    tm_hidden_bw[3][(t_bi + 3)], tm_hidden_fw[3][-1 -
                                                                 (t_bi + 3)]
                ],
                                         axis=-1)
                hiddenConcatConv[-1 - t_bi] = tf.layers.conv2d(
                    inputs=hiddenConcat,
                    filters=tm_hidden_fw[3][(t_bi + 3)].get_shape()[-1],
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="B_h_merge")

                x_gen[t_bi] = tf.layers.conv2d(inputs=hiddenConcatConv[t_bi],
                                               filters=output_channels,
                                               kernel_size=1,
                                               strides=1,
                                               padding='same',
                                               name="bi_back_to_pixel")
                x_gen[-1 - t_bi] = tf.layers.conv2d(
                    inputs=hiddenConcatConv[-1 - t_bi],
                    filters=output_channels,
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name="bi_back_to_pixel")
            gen_images.append(x_gen[t_bi])
            print("t_bi: %d" % t_bi)
            tf.summary.image('x_gen', reshape_patch_back_gen(x_gen[t_bi], 4),
                             seq_length - 2)

    # inp_images = tf.stack(inp_images)
    gen_images = tf.stack(gen_images)
    # # # gen_images_bw = tf.stack(gen_images_bw)

    # [batch_size, seq_length, height, width, channels]
    # inp_images = tf.transpose(inp_images, [1,0,2,3,4])
    gen_images = tf.transpose(gen_images, [1, 0, 2, 3, 4])
    # # # gen_images_bw = tf.transpose(gen_images_bw, [1,0,2,3,4])

    # Compare against the ground-truth frames images[:, 4:-4] (boundary frames dropped)
    l2Loss = tf.nn.l2_loss(gen_images - images[:, 4:-4])
    l1Loss = tf.losses.absolute_difference(gen_images, images[:, 4:-4])
    #hbloss = tf.losses.huber_loss(images[:,4:-4], gen_images, delta=1.5)
    gdlLoss = cal_gdl(gen_images, images[:, 4:-4])
    loss = l2Loss + l1Loss + gdlLoss
    tf.summary.scalar('l2_Loss', l2Loss)
    tf.summary.scalar('l1_Loss', l1Loss)
    #tf.summary.scalar('huber_Loss', hbloss)
    tf.summary.scalar('gdl_Loss', gdlLoss)
    tf.summary.scalar('loss', loss)

    return [gen_images, loss, images, images_bw]
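
Both Example #3 and Example #5 call sample_Z and reshape_patch_back_gen, which are not defined here. A minimal sketch of sample_Z is shown below, assuming it simply draws noise with the same shape as its argument (used where mask_true suppresses the real input); the actual helper may use a different distribution.

def sample_Z(template):
    # Uniform noise matching the (dynamic) shape of `template`.
    return tf.random_uniform(tf.shape(template), minval=0.0, maxval=1.0)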