Example #1
def network(inputs, hidden, lstm=True):
    conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1")
    # conv2
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    # conv3
    conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
    # conv4
    conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
    y_0 = conv4
    if lstm:
        # conv lstm cell
        with tf.variable_scope('conv_lstm',
                               initializer=tf.random_uniform_initializer(
                                   -.01, 0.1)):
            cell = BasicConvLSTMCell.BasicConvLSTMCell([8, 8], [3, 3], 4)
            if hidden is None:
                hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
            y_1, hidden = cell(y_0, hidden)
    else:
        y_1 = ld.conv_layer(y_0, 3, 1, 8, "encode_3")

    # conv5
    conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
    # conv6
    conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
    # conv7
    conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
    # x_1
    x_1 = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8",
                                  True)  # set activation to linear

    return x_1, hidden
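A minimal driver sketch (not part of the example; it assumes the FLAGS used above, a 5-D input tensor x of shape [batch, seq_length, height, width, channels], and FLAGS.seq_start >= 1): the hidden state returned by network is threaded through the timesteps and the variables are reused after the first call, mirroring the train() examples further down this page.

# Sketch only: unroll network() over a sequence, feeding its own prediction
# back in once FLAGS.seq_start ground-truth frames have been consumed.
hidden = None
predictions = []
for i in range(FLAGS.seq_length - 1):
    frame = x[:, i, :, :, :] if i < FLAGS.seq_start else x_1
    x_1, hidden = network(frame, hidden)
    if i >= FLAGS.seq_start:
        predictions.append(x_1)
    if i == 0:
        # share weights across all timesteps
        tf.get_variable_scope().reuse_variables()
# [batch, time, height, width, channels]
predictions = tf.transpose(tf.stack(predictions), [1, 0, 2, 3, 4])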
Example #2
        def conv_lstm_cell(inputs, hidden):

            # some convolutional layers before the convLSTM cell
            conv1 = ld.conv_layer(inputs, 3, 2, 4, "encode_1")
            conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
            conv3 = ld.conv_layer(conv2, 3, 1, 16, "encode_3")

            # take output from first conv layers as input to convLSTM cell
            with tf.variable_scope('conv_lstm',
                                   initializer=tf.random_uniform_initializer(
                                       -.01, 0.1)):
                cell = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3],
                                                           16)
                if hidden is None:
                    hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
                cell_output, hidden = cell(conv3, hidden)

            # some convolutional layers after the convLSTM cell
            conv5 = ld.transpose_conv_layer(cell_output, 1, 1, 16, "decode_5")
            conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6")
            conv7 = ld.transpose_conv_layer(conv6, 3, 1, 4, "decode_7")

            # the last convolutional layer will use linear activations
            x_1 = ld.transpose_conv_layer(conv7, 3, 2, 1, "decode_8", True)

            # return the output of the last conv layer, & the hidden cell state
            return x_1, hidden
Example #3
def network_gate(inputs, gating_num, moe, activation='softmax'):
    ch = inputs.get_shape()[-1]
    H = inputs.get_shape()[1]
    W = inputs.get_shape()[2]
    conv1 = ld.conv_layer(inputs, 4, 2, ch * 4, 6, linear=False)
    conv2 = ld.conv_layer(conv1, 4, 2, ch * 8, 7, linear=False)
    conv3 = ld.conv_layer(conv2, 4, 2, ch * 16, 8, linear=False)
    deconv1 = ld.transpose_conv_layer(conv3, 4, 2, ch * 8, 5, linear=False)
    deconv1 = tf.concat([deconv1, conv2], axis=-1)
    deconv2 = ld.transpose_conv_layer(deconv1, 4, 2, 16, 6, linear=False)
    deconv2 = tf.concat([deconv2, conv1], axis=-1)
    deconv3 = ld.transpose_conv_layer(deconv2, 4, 2, 8, 7, linear=False)
    final = ld.conv_layer(deconv3, 4, 1, gating_num, 10, linear=True)
    final = tf.nn.relu(final)
    sq1 = tf.nn.avg_pool(final,
                         ksize=[1, H, W, 1],
                         strides=[1, 1, 1, 1],
                         padding="VALID")
    sq2 = tf.layers.dense(sq1, units=gating_num // 2, activation=tf.nn.relu)
    if moe == 'moe0':
        sq3 = tf.layers.dense(sq2, units=gating_num)
        sq3 = tf.nn.softmax(sq3, axis=-1)
        weight = sq3[:, 0, 0, :]
    else:
        sq3 = tf.layers.dense(sq2, units=gating_num, activation=tf.nn.sigmoid)
        excitation = tf.reshape(sq3, [-1, 1, 1, gating_num])
        weight = final * excitation
        if activation == 'softmax':
            weight = tf.nn.softmax(weight, axis=-1)
    return weight
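A hedged usage sketch (the expert_out tensor below is hypothetical and not part of the example): with moe='moe0' the gate returns per-expert softmax weights of shape [batch, gating_num], which can be broadcast over the spatial dimensions to blend a stack of expert predictions.

# Sketch only: blend a hypothetical stack of expert outputs with the gate.
# expert_out is assumed to have shape [batch, H, W, gating_num].
gate = network_gate(inputs, gating_num=4, moe='moe0')   # [batch, 4]
gate = tf.reshape(gate, [-1, 1, 1, 4])                  # broadcast over H and W
blended = tf.reduce_sum(expert_out * gate, axis=-1, keepdims=True)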
Example #4
def RNN(x):
    x_dropout = tf.nn.dropout(x, keep_prob)

    x_unwrap = []
    # create network
    with tf.variable_scope('conv_lstm',
                           initializer=tf.random_uniform_initializer(
                               -.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell([FLAGS.hight, FLAGS.width],
                                                   [3, 3], 1)
        new_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # conv network
    for i in xrange(FLAGS.seq_length - 1):
        # conv1
        if i < FLAGS.seq_start:  #inputs, kernel_size, stride, num_features, idx
            #conv1 = ld.conv_layer(x_dropout[:,i,:,:,:], 3, 2, 8, "encode_1")
            conv1 = ld.transpose_conv_layer(x_dropout[:, i, :, :, :], 3, 1, 1,
                                            "decode_1")
        else:
            conv1 = ld.transpose_conv_layer(x_1, 3, 1, 1, "decode_1")
        y_0 = conv1
        # conv lstm cell
        y_1, new_state = cell(y_0, new_state)
        # x_1
        x_1 = ld.conv_layer(y_1, 3, 1, 1, "encode_1", True)
        if i >= FLAGS.seq_start:
            x_unwrap.append(x_1)
        # set reuse to true after first go
        if i == 0:
            tf.get_variable_scope().reuse_variables()
    # pack them all together
    x_unwrap = tf.stack(x_unwrap)
    x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])
    return x_unwrap
Example #5
        def forward():
            """Make forward pass """
            for frame_id in xrange(self.length):
                input_ = tf.concat(3, [self.input_frames[:, frame_id, :, :, :],
                                       self.input_frames_low_scale[:, frame_id, :, :, :]])
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_encode_state[lstm_layer_id] = lstm_encode[
                        lstm_layer_id](input_, lstm_encode_state[lstm_layer_id])

            for i in xrange(self.layer_num_lstm):
                lstm_predict_state[i] = lstm_encode_state[i]
                lstm_decode_state[i] = lstm_decode_state[i]
            predicts = []
            for frame_id in xrange(self.future_seq_length):
                if frame_id == 0:
                    input_ = tf.concat(3, [self.input_frames[:, -1, :, :, :],
                                           self.future_frames_low_scale[:, frame_id, :, :, :]])
                else:
                    #input = y_out
                    input_ = tf.concat(3, [y_out,
                                           self.future_frames_low_scale[:, frame_id, :, :, :]])
                    #input_ = tf.concat(3,[self.future_frames[:,frame_id-1,:,:,:],self.future_frames_low_scale[:,frame_id,:,:,:]])
                # adding all layer predictions together
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_predict_state[lstm_layer_id] = lstm_predict[
                        lstm_layer_id](input_, lstm_predict_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(3, lstm_pyramid)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "predict")
                predicts.append(y_out)
            # swap axis
            x_unwrap_gen = tf.pack(predicts)
            predicts = tf.transpose(x_unwrap_gen, [1, 0, 2, 3, 4])

            decodes_temp = []
            for frame_id in range(self.length, 0, -1):
                if frame_id == self.length:
                    input_ = tf.concat(3, [self.future_frames[:, 0, :, :, :],
                                           self.input_frames_low_scale[:, frame_id - 1, :, :, :]])
                else:
                    input_ = tf.concat(3, [self.input_frames[:, frame_id, :, :, :],
                                           self.input_frames_low_scale[:, frame_id - 1, :, :, :]])
                # adding all layer predictions together
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_decode_state[lstm_layer_id] = lstm_decode[
                        lstm_layer_id](input_, lstm_decode_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(3, lstm_pyramid)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "decode")
                decodes_temp.append(y_out)

            decodes = []
            for i in range(self.length):
                decodes.append(decodes_temp.pop())
            # swap axis
            x_unwrap_de = tf.pack(decodes)
            decodes = tf.transpose(x_unwrap_de, [1, 0, 2, 3, 4])

            return predicts, decodes
Example #6
def conv_model(x, keep_prob):
    # create network
    # encoding part first
    # conv1
    conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
    # conv2
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    # conv3
    conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
    # conv4
    conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
    # fc5
    fc5 = ld.fc_layer(conv4, 128, "encode_5", True, False)
    # dropout maybe
    fc5_dropout = tf.nn.dropout(fc5, keep_prob)
    # y
    y = ld.fc_layer(fc5_dropout, (FLAGS.hidden_size) * 2, "encode_6", False,
                    True)
    mean, stddev = tf.split(1, 2, y)
    stddev = tf.sqrt(tf.exp(stddev))
    # now decoding part
    # sample distribution
    epsilon = tf.random_normal(mean.get_shape())
    y_sampled = mean + epsilon * stddev
    # fc7
    fc7 = ld.fc_layer(y_sampled, 128, "decode_7", False, False)
    # fc8
    fc8 = ld.fc_layer(fc7, 4 * 8 * 8, "decode_8", False, False)
    conv9 = tf.reshape(fc8, [-1, 8, 8, 4])
    # conv10
    conv10 = ld.transpose_conv_layer(conv9, 1, 1, 8, "decode_9")
    # conv11
    conv11 = ld.transpose_conv_layer(conv10, 3, 2, 8, "decode_10")
    # conv12
    conv12 = ld.transpose_conv_layer(conv11, 3, 1, 8, "decode_11")
    # conv13
    conv13 = ld.transpose_conv_layer(conv12, 3, 2, 1, "decode_12", True)
    # x_prime
    x_prime = conv13
    x_prime = tf.nn.sigmoid(x_prime)

    return mean, stddev, y_sampled, x_prime
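The mean and stddev returned here are intended for a standard VAE objective; a sketch of that loss, following the same form the VAE train() example later on this page uses (it assumes x has the same shape as x_prime, with values in [0, 1]):

# Sketch only: KL term plus Bernoulli reconstruction term for conv_model.
eps = 1e-8
mean, stddev, y_sampled, x_prime = conv_model(x, keep_prob)
kl_loss = 0.5 * (tf.square(mean) + tf.square(stddev)
                 - 2.0 * tf.log(stddev + eps) - 1.0)
loss_vae = tf.reduce_sum(kl_loss)
loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + eps)
                                    - (1.0 - x) * tf.log(1.0 - x_prime + eps))
loss = loss_vae + loss_reconstruction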
Example #7
def network_rot(inputs):

    conv1 = ld.conv_layer(inputs, 4, 2, 16, 1, linear=False)

    conv2 = ld.conv_layer(conv1, 4, 2, 32, 2, linear=False)

    conv3 = ld.conv_layer(conv2, 4, 2, 64, 3, linear=False)

    conv4 = ld.conv_layer(conv3, 4, 2, 128, 4, linear=False)

    deconv1 = ld.transpose_conv_layer(conv4, 4, 2, 64, 1, linear=False)
    deconv1 = tf.concat([deconv1, conv3], axis=-1)

    deconv2 = ld.transpose_conv_layer(deconv1, 4, 2, 32, 2, linear=False)
    deconv2 = tf.concat([deconv2, conv2], axis=-1)

    deconv3 = ld.transpose_conv_layer(deconv2, 4, 2, 16, 3, linear=False)
    deconv3 = tf.concat([deconv3, conv1], axis=-1)

    deconv4 = ld.transpose_conv_layer(deconv3, 4, 2, 8, 4, linear=False)

    flow = ld.conv_layer(deconv4, 4, 1, 2, 5, linear=True)

    return flow
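The two output channels are commonly read as a dense flow field. A hedged warping sketch (the frame tensor is hypothetical, the channel ordering is an assumption, and it requires a TF 1.x build where tf.contrib.image.dense_image_warp is available):

# Sketch only: bilinear warp of a hypothetical frame ([batch, H, W, C])
# with the predicted flow ([batch, H, W, 2]).
flow = network_rot(inputs)
warped = tf.contrib.image.dense_image_warp(frame, flow)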
Example #8
def network(inputs, hidden):
  conv1 = ld.conv2d(inputs, (3,3), (1,1), 1, "encode_1")
  # conv2
  # conv2 = ld.conv2d(conv1, (3,3), (1,2), 8, "encode_2")
  # conv3
  # conv3 = ld.conv2d(conv2, (3,3), (1,2), 8, "encode_3")
  # conv4
  #conv4 = ld.conv2d(conv3, (1,1), (1,1), 4, "encode_4")
  #y_0 = conv4
  y_0 = conv1
  # conv lstm cell 
  with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
    cell = BasicConvLSTMCell.BasicConvLSTMCell([16,128], [8,8], 16)
    if hidden is None:
      hidden = cell.zero_state(FLAGS.batch_size, tf.float32) 
    y_1, hidden = cell(y_0, hidden)
  
  with tf.variable_scope('conv_lstm_2', initializer = tf.random_uniform_initializer(-.01, 0.1)):  
    cell2 = BasicConvLSTMCell.BasicConvLSTMCell([16,128], [8,8], 16)
    if hidden is None:
      hidden = cell2.zero_state(FLAGS.batch_size, tf.float32) 
    y_2, hidden = cell2(y_1, hidden)
    
  with tf.variable_scope('conv_lstm_3', initializer = tf.random_uniform_initializer(-.01, 0.1)):  
    cell3 = BasicConvLSTMCell.BasicConvLSTMCell([16,128], [8,8], 16)
    if hidden is None:
      hidden = cell3.zero_state(FLAGS.batch_size, tf.float32) 
    y_3, hidden = cell3(y_2, hidden)
 
  # conv5
  #conv5 = ld.transpose_conv_layer(y_3, (1,1), (1,1), 8, "decode_5")
  # conv6
  # conv6 = ld.transpose_conv_layer(conv5, (3,3), (1,2), 8, "decode_6")
  # conv7
  # conv7 = ld.transpose_conv_layer(conv6, (3,3), (1,2), 8, "decode_7")
  # x_1 
  conv7 = y_3
  x_1 = ld.transpose_conv_layer(conv7, (3,3), (1,1), 1, "decode_8", True) # set activation to linear

  return x_1, hidden
Example #9
def all_conv_model(x, keep_prob):
    # create network
    # encoding part first
    # conv1
    conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
    # conv2
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    # conv3
    conv3 = ld.conv_layer(conv2, 3, 2, 16, "encode_3")
    # conv4
    conv4 = ld.conv_layer(conv3, 3, 1, 8, "encode_4")
    # conv5
    conv5 = ld.conv_layer(conv4, 3, 1, 8, "encode_5")
    # conv6
    conv6 = ld.conv_layer(conv5, 3, 2, FLAGS.hidden_size * 2, "encode_6", True)
    mean_conv, stddev_conv = tf.split(3, 2, conv6)
    mean = tf.reshape(mean_conv, [-1, 4 * 4 * FLAGS.hidden_size])
    stddev = tf.reshape(stddev_conv, [-1, 4 * 4 * FLAGS.hidden_size])
    stddev = tf.sqrt(tf.exp(stddev))
    # now decoding part
    # sample distribution
    epsilon = tf.random_normal(mean.get_shape())
    y_sampled = mean + epsilon * stddev
    y_sampled_conv = tf.reshape(y_sampled, [-1, 4, 4, FLAGS.hidden_size])
    # conv7
    conv7 = ld.transpose_conv_layer(y_sampled_conv, 3, 2, 8, "decode_7")
    # conv8
    conv8 = ld.transpose_conv_layer(conv7, 3, 1, 8, "decode_8")
    # conv9
    conv9 = ld.transpose_conv_layer(conv8, 3, 1, 16, "decode_9")
    # conv10
    conv10 = ld.transpose_conv_layer(conv9, 3, 2, 8, "decode_10")
    # conv11
    conv11 = ld.transpose_conv_layer(conv10, 3, 1, 8, "decode_11")
    # conv12
    conv12 = ld.transpose_conv_layer(conv11, 3, 2, 1, "decode_12", True)
    # x_prime
    x_prime = conv12
    x_prime = tf.nn.sigmoid(x_prime)

    # reshape these just to make them look like other networks
    return mean, stddev, y_sampled, x_prime
Example #10
def network_2d(inputs, encoder_state, past_state, future_state):
    #inputs is 3D tensor (batch, )
    conv = ld.conv2d(inputs, (4, 4), (1, 2), 4, "encode")
    #conv = inputs
    # encoder convlstm
    with tf.variable_scope(
            'conv_lstm_encoder_1',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        cell1 = BasicConvLSTMCell2d([4, 256], [3, 8], 4, strides=(2, 1))
        if encoder_state is None:
            encoder_state = cell1.zero_state(FLAGS.batch_size, tf.float32)
        conv1, encoder_state = cell1(conv, encoder_state)
    with tf.variable_scope(
            'conv_lstm_encoder_2',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        cell2 = BasicConvLSTMCell2d([2, 256], [2, 8], 4, strides=(2, 1))
        conv2, encoder_state = cell2(conv1, encoder_state)
    with tf.variable_scope(
            'conv_lstm_encoder_3',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        cell3 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        conv3, encoder_state = cell3(conv2, encoder_state)

    # past decoder convlstm
    with tf.variable_scope(
            'past_decoder_1',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        pcell1 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        if past_state is None:
            past_state = pcell1.zero_state(FLAGS.batch_size, tf.float32)
        pconv1, past_state = pcell1(conv1, past_state)
    with tf.variable_scope(
            'past_decoder_2',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        pcell2 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        pconv2, past_state = pcell2(conv2, past_state)
    with tf.variable_scope(
            'past_decoder_3',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        pcell3 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        pconv3, past_state = pcell3(conv3, past_state)

    # future decoder convlstm
    with tf.variable_scope(
            'future_decoder_1',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        fcell1 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        if future_state is None:
            future_state = fcell1.zero_state(FLAGS.batch_size, tf.float32)
        fconv1, future_state = fcell1(conv1, future_state)
    with tf.variable_scope(
            'future_decoder_2',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        fcell2 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        fconv2, future_state = fcell2(conv2, future_state)
    with tf.variable_scope(
            'future_decoder_3',
            initializer=tf.contrib.layers.xavier_initializer(uniform=True)):
        fcell3 = BasicConvLSTMCell2d([1, 256], [1, 8], 4)
        fconv3, future_state = fcell3(conv3, future_state)
    # present output
    x_1 = ld.transpose_conv_layer(pconv3, (1, 4), (1, 2), 1, "present_output",
                                  True)
    # # future output
    y_1 = ld.transpose_conv_layer(fconv3, (1, 4), (1, 2), 1, "future_output",
                                  True)
    #x_1 = pconv3; y_1 = fconv3
    #import IPython; IPython.embed()
    return x_1, y_1, encoder_state, past_state, future_state
Example #11
    def define_graph(self):
        """ define a Bidirectional LSTM and concatenated feature for MLP
        """
        lstm_encode = []
        lstm_predict = []
        lstm_decode = []
        lstm_encode_state = []
        lstm_predict_state = []
        lstm_decode_state = []
        with tf.name_scope('input'):
            self.input_frames = tf.placeholder(
                tf.float32,
                shape=[None, self.length, self.height, self.width, 1])

            # use variable batch_size for more flexibility
            self.D_label = tf.placeholder(tf.float32,
                                          shape=[self.batch_size, 2])
            self.future_frames = tf.placeholder(tf.float32,
                                                shape=[
                                                    None,
                                                    self.future_seq_length,
                                                    self.height, self.width, 1
                                                ])

            self.input_frames_low_scale = tf.placeholder(
                tf.float32,
                shape=[None, self.length, self.height, self.width, 1])

            self.future_frames_low_scale = tf.placeholder(
                tf.float32,
                shape=[
                    None, self.future_seq_length, self.height, self.width, 1
                ])

        with tf.variable_scope("G_scale_{}".format(self.scale_index)):

            for layer_id_, kernel_, kernel_num_ in zip(
                    xrange(self.layer_num_lstm), self.kernel_size,
                    self.kernel_num):
                layer_name_encode = "conv_lstm_encode_{}".format(layer_id_)
                #with tf.variable_scope('conv_lstm_encode', initializer = tf.random_uniform_initializer(-.01, 0.1)):
                temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                    [self.height, self.width], [kernel_, kernel_], kernel_num_,
                    layer_name_encode)
                temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
                lstm_encode.append(temp_cell)
                lstm_encode_state.append(temp_state)

            for layer_id_, kernel_, kernel_num_ in zip(
                    xrange(self.layer_num_lstm), self.kernel_size,
                    self.kernel_num):
                layer_name_predict = "conv_lstm_predict_{}".format(layer_id_)
                #with tf.variable_scope('conv_lstm_predict', initializer = tf.random_uniform_initializer(-.01, 0.1)):
                temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                    [self.height, self.width], [kernel_, kernel_], kernel_num_,
                    layer_name_predict)
                temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
                lstm_predict.append(temp_cell)
                lstm_predict_state.append(temp_state)

            for layer_id_, kernel_, kernel_num_ in zip(
                    xrange(self.layer_num_lstm), self.kernel_size,
                    self.kernel_num):
                layer_name_predict = "conv_lstm_decode_{}".format(layer_id_)
                #with tf.variable_scope('conv_lstm_predict', initializer = tf.random_uniform_initializer(-.01, 0.1)):
                temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                    [self.height, self.width], [kernel_, kernel_], kernel_num_,
                    layer_name_predict)
                temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
                lstm_decode.append(temp_cell)
                lstm_decode_state.append(temp_state)

        input_ = tf.concat(3, [
            self.input_frames[:, 0, :, :, :],
            self.input_frames_low_scale[:, 0, :, :, :]
        ])
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_encode_state[lstm_layer_id] = lstm_encode[
                lstm_layer_id](input_, lstm_encode_state[lstm_layer_id])

        input_ = tf.concat(3, [
            self.input_frames[:, 0, :, :, :],
            self.input_frames_low_scale[:, 0, :, :, :]
        ])
        lstm_pyramid = []
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_predict_state[lstm_layer_id] = lstm_predict[
                lstm_layer_id](input_, lstm_predict_state[lstm_layer_id])
            lstm_pyramid.append(input_)

        input_ = tf.concat(3, [
            self.input_frames[:, 0, :, :, :],
            self.input_frames_low_scale[:, 0, :, :, :]
        ])
        lstm_pyramid_de = []
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_decode_state[lstm_layer_id] = lstm_decode[
                lstm_layer_id](input_, lstm_decode_state[lstm_layer_id])
            lstm_pyramid_de.append(input_)

        y_cat = tf.concat(3, lstm_pyramid)
        temp = ld.transpose_conv_layer(y_cat, 1, 1, 1, "predict")

        y_cat_de = tf.concat(3, lstm_pyramid_de)
        temp_de = ld.transpose_conv_layer(y_cat_de, 1, 1, 1, "decode")
        self.scope.reuse_variables()

        def forward():
            """Make forward pass """
            for frame_id in xrange(self.length):
                input_ = tf.concat(3, [
                    self.input_frames[:, frame_id, :, :, :],
                    self.input_frames_low_scale[:, frame_id, :, :, :]
                ])
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_encode_state[lstm_layer_id] = lstm_encode[
                        lstm_layer_id](input_,
                                       lstm_encode_state[lstm_layer_id])

            for i in xrange(self.layer_num_lstm):
                lstm_predict_state[i] = lstm_encode_state[i]
                lstm_decode_state[i] = lstm_decode_state[i]
            predicts = []
            for frame_id in xrange(self.future_seq_length):
                if frame_id == 0:
                    input_ = tf.concat(3, [
                        self.input_frames[:, -1, :, :, :],
                        self.future_frames_low_scale[:, frame_id, :, :, :]
                    ])
                else:
                    input_ = tf.concat(3, [
                        y_out, self.future_frames_low_scale[:,
                                                            frame_id, :, :, :]
                    ])
                    #input_ = tf.concat(3,[self.future_frames[:,frame_id-1,:,:,:],self.future_frames_low_scale[:,frame_id,:,:,:]])
                    # adding all layer predictions together
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_predict_state[lstm_layer_id] = lstm_predict[
                        lstm_layer_id](input_,
                                       lstm_predict_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(3, lstm_pyramid)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "predict")
                predicts.append(y_out)
            # swap axis
            x_unwrap_gen = tf.pack(predicts)
            predicts = tf.transpose(x_unwrap_gen, [1, 0, 2, 3, 4])

            decodes_temp = []
            for frame_id in range(self.length, 0, -1):
                if frame_id == self.length:
                    input_ = tf.concat(3, [
                        self.future_frames[:, 0, :, :, :],
                        self.input_frames_low_scale[:, frame_id - 1, :, :, :]
                    ])
                else:
                    input_ = tf.concat(3, [
                        self.input_frames[:, frame_id, :, :, :],
                        self.input_frames_low_scale[:, frame_id - 1, :, :, :]
                    ])
                # adding all layer predictions together
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_decode_state[lstm_layer_id] = lstm_decode[
                        lstm_layer_id](input_,
                                       lstm_decode_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(3, lstm_pyramid)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "decode")
                decodes_temp.append(y_out)

            decodes = []
            for i in range(self.length):
                decodes.append(decodes_temp.pop())
            # swap axis
            x_unwrap_de = tf.pack(decodes)
            decodes = tf.transpose(x_unwrap_de, [1, 0, 2, 3, 4])

            return predicts, decodes

        self.preds, self.decodes = forward()
        """ loss and training op """
        mean_loss = l2_loss(self.preds,
                            self.future_frames) / self.future_seq_length
        mean_loss_de = l2_loss(self.decodes, self.input_frames) / self.length
        GT_label = tf.concat(
            1, [tf.ones([self.batch_size, 1]),
                tf.zeros([self.batch_size, 1])])
        entropy_loss = adv_loss(self.D_label, GT_label)
        self.loss = mean_loss + entropy_loss + mean_loss_de

        #self.loss = combined_loss(self.preds,self.future_frames,self.D_label)
        temp_op = tf.train.AdamOptimizer(FLAGS.lr)
        variable_collection = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, self.scope_string)
        gvs = temp_op.compute_gradients(self.loss,
                                        var_list=variable_collection)
        capped_gvs = [(tf.clip_by_norm(grad, FLAGS.clip), var)
                      for grad, var in gvs]
        self.train_op = temp_op.apply_gradients(capped_gvs)

        mean_loss_summary = tf.summary.scalar('loss_mean_pre', mean_loss)
        mean_loss_de_summary = tf.summary.scalar('loss_mean_dec', mean_loss_de)
        entropy_loss_summary = tf.summary.scalar('loss_entropy', entropy_loss)
        loss_summary = tf.summary.scalar('loss_G', self.loss)
        self.summary = tf.summary.merge([
            loss_summary, mean_loss_summary, entropy_loss_summary,
            mean_loss_de_summary
        ])
Example #12
def train():
  """Train ring_net for a number of steps."""
  with tf.Graph().as_default():
    # make inputs
    x = tf.placeholder(tf.float32, [None, FLAGS.seq_length, 64, 64, 3])

    # possible dropout inside
    keep_prob = tf.placeholder("float")
    x_dropout = tf.nn.dropout(x, keep_prob)

    # create network
    x_unwrap = []
    with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
      cell = BasicConvLSTMCell.BasicConvLSTMCell([8,8], [3,3], 4)
      new_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # conv network
    for i in range(FLAGS.seq_length-1):

      # conv1
      if i < FLAGS.seq_start:
        conv1 = ld.conv_layer(x_dropout[:,i,:,:,:], 3, 2, 8, "encode_1")
      else:
        conv1 = ld.conv_layer(x_1, 3, 2, 8, "encode_1")
      # conv2
      conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
      # conv3
      conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
      # conv4
      conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
      y_0 = conv4
      # conv lstm cell
      y_1, new_state = cell(y_0, new_state)
      # conv5
      conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
      # conv6
      conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
      # conv7
      conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
      # x_1
      x_1 = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8", True) # set activation to linear
      if i >= FLAGS.seq_start:
        x_unwrap.append(x_1)
      # set reuse to true after first go
      if i == 0:
        tf.get_variable_scope().reuse_variables()

    # pack them all together
    x_unwrap = tf.pack(x_unwrap)
    x_unwrap = tf.transpose(x_unwrap, [1,0,2,3,4])

    # this part will be used for generating video
    x_unwrap_gen = []
    new_state_gen = cell.zero_state(FLAGS.batch_size, tf.float32)
    for i in range(50):
      # conv1
      if i < FLAGS.seq_start:
        conv1 = ld.conv_layer(x[:,i,:,:,:], 3, 2, 8, "encode_1")
      else:
        conv1 = ld.conv_layer(x_1_gen, 3, 2, 8, "encode_1")
      # conv2
      conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
      # conv3
      conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
      # conv4
      conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
      y_0 = conv4
      # conv lstm cell
      y_1, new_state_gen = cell(y_0, new_state_gen)
      # conv5
      conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
      # conv6
      conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
      # conv7
      conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
      # x_1_gen
      x_1_gen = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8", True) # set activation to linear
      if i >= FLAGS.seq_start:
        x_unwrap_gen.append(x_1_gen)

    # pack them generated ones
    x_unwrap_gen = tf.pack(x_unwrap_gen)
    x_unwrap_gen = tf.transpose(x_unwrap_gen, [1,0,2,3,4])

    # calc total loss (compare x_t to x_t+1)
    loss = tf.nn.l2_loss(x[:,FLAGS.seq_start+1:,:,:,:] - x_unwrap[:,:,:,:,:])
    tf.scalar_summary('loss', loss)

    # training
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    # List of all Variables
    variables = tf.all_variables()

    # Build a saver
    saver = tf.train.Saver(tf.all_variables())

    # Summary op
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session()

    # init if this is the very first time training
    print("init network from scratch")
    sess.run(init)

    # Summary op
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=graph_def)

    for step in range(FLAGS.max_step):
      #dat = generate_bouncing_ball_sample(FLAGS.batch_size, FLAGS.seq_length, 64, FLAGS.num_balls)
      dat = get_data(FLAGS.batch_size, FLAGS.seq_length,(64,64))
      t = time.time()
      _, loss_r = sess.run([train_op, loss],feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
      elapsed = time.time() - t

      if step%100 == 0 and step != 0:
        summary_str = sess.run(summary_op, feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_writer.add_summary(summary_str, step)
        print("time per batch is " + str(elapsed))
        print(step)
        print(loss_r)

      assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

      if step%1000 == 0:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        print("saved to " + FLAGS.train_dir)

        # make video
        print("now generating video!")
        video = cv2.VideoWriter()
        success = video.open("generated_conv_lstm_video.mov", fourcc, 4, (180, 180), True)
        dat_gif = dat

        render_original_video(dat_gif)



        ims = sess.run([x_unwrap_gen],feed_dict={x:dat_gif, keep_prob:FLAGS.keep_prob})
        ims = ims[0][0]
        print(ims.shape)
        for i in range(50 - FLAGS.seq_start):
          x_1_r = np.uint8(np.maximum(ims[i,:,:,:], 0) * 255)
          print(x_1_r.shape)
          new_im = cv2.resize(x_1_r, (180,180))
          video.write(new_im)
        video.release()
Example #13
def train():
    """Train ring_net for a number of steps."""
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3])

        # possible dropout inside
        keep_prob = tf.placeholder("float")

        # create network
        # encoding part first
        # conv1
        conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
        # conv2
        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
        # conv3
        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
        # conv4
        conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
        # fc5
        fc5 = ld.fc_layer(conv4, 128, "encode_5", True, False)
        # dropout maybe
        fc5_dropout = tf.nn.dropout(fc5, keep_prob)
        # y
        y = ld.fc_layer(fc5_dropout, (FLAGS.hidden_size) * 2, "encode_6",
                        False, True)
        mean, stddev = tf.split(1, 2, y)
        stddev = tf.sqrt(tf.exp(stddev))
        # now decoding part
        # sample distribution
        epsilon = tf.random_normal(mean.get_shape())
        y_sampled = mean + epsilon * stddev
        # fc7
        fc7 = ld.fc_layer(y_sampled, 128, "decode_7", False, False)
        # fc8
        fc8 = ld.fc_layer(fc7, 4 * 8 * 8, "decode_8", False, False)
        conv9 = tf.reshape(fc8, [-1, 8, 8, 4])
        # conv10
        conv10 = ld.transpose_conv_layer(conv9, 1, 1, 8, "decode_9")
        # conv11
        conv11 = ld.transpose_conv_layer(conv10, 3, 2, 8, "decode_10")
        # conv12
        conv12 = ld.transpose_conv_layer(conv11, 3, 1, 8, "decode_11")
        # conv13
        conv13 = ld.transpose_conv_layer(conv12, 3, 2, 3, "decode_12", True)
        # x_prime
        x_prime = conv13
        x_prime = tf.nn.sigmoid(x_prime)

        # now calc loss
        epsilon = 1e-8
        # calc loss from vae
        kl_loss = 0.5 * (tf.square(mean) + tf.square(stddev) -
                         2.0 * tf.log(stddev + epsilon) - 1.0)
        loss_vae = FLAGS.beta * tf.reduce_sum(kl_loss)
        # log loss for reconstruction
        loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + epsilon) -
                                            (1.0 - x) *
                                            tf.log(1.0 - x_prime + epsilon))
        # save for tensorboard
        tf.scalar_summary('loss_vae', loss_vae)
        tf.scalar_summary('loss_reconstruction', loss_reconstruction)
        # calc total loss
        loss = tf.reduce_sum(loss_vae + loss_reconstruction)

        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

        # List of all Variables
        variables = tf.all_variables()

        # Build a saver
        saver = tf.train.Saver(tf.all_variables())

        # Summary op
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session()

        # init if this is the very first time training
        print("init network from scratch")
        sess.run(init)

        # Summary op
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph_def=graph_def)

        for step in xrange(FLAGS.max_step):
            dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
            t = time.time()
            _, loss_r = sess.run([train_op, loss],
                                 feed_dict={
                                     x: dat,
                                     keep_prob: FLAGS.keep_prob
                                 })
            elapsed = time.time() - t
            #print(elapsed)

            if step % 500 == 0:
                _, loss_vae_r, loss_reconstruction_r, y_sampled_r, x_prime_r, kl_loss_dis, stddev_r = sess.run(
                    [
                        train_op, loss_vae, loss_reconstruction, y_sampled,
                        x_prime, kl_loss, stddev
                    ],
                    feed_dict={
                        x: dat,
                        keep_prob: FLAGS.keep_prob
                    })
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           x: dat,
                                           keep_prob: FLAGS.keep_prob
                                       })
                summary_writer.add_summary(summary_str, step)
                print("loss vae value at " + str(loss_vae_r))
                print("loss reconstruction value at " +
                      str(loss_reconstruction_r))
                print("min sampled vector " + str(np.min(y_sampled_r)))
                print("max sampled vector " + str(np.max(y_sampled_r)))
                print("time per batch is " + str(elapsed))
                cv2.imwrite("real_balls.jpg", np.uint8(dat[0, :, :, :] * 255))
                cv2.imwrite("generated_balls.jpg",
                            np.uint8(x_prime_r[0, :, :, :] * 255))
                kl_loss_dis = np.sort(np.sum(kl_loss_dis, axis=0))
                stddev_r = np.sort(np.sum(stddev_r, axis=0))
                #plt.plot(kl_loss_dis, label="step " + str(step))
                #plt.legend(loc = 'center left')
                #plt.savefig('kl_error_dis.png')
                plt.plot(stddev_r, label="step " + str(step))
                plt.legend(loc='center left')
                plt.savefig('stddev_r.png')

            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

            if step % 1000 == 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)
                print("step " + str(step))
Example #14
def train():
    """Train ring_net for a number of steps."""
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [None, FLAGS.seq_length, 32, 32, 3])

        # possible dropout inside
        keep_prob = tf.placeholder("float")
        x_dropout = tf.nn.dropout(x, keep_prob)

        # create network
        x_unwrap = []

        # conv piece
        for i in range(FLAGS.seq_length - 1):
            if i == 0:
                # conv1
                conv1 = ld.conv_layer(x_dropout[:, i, :, :, :], 3, 2, 8,
                                      "encode_1")
            else:
                # conv1
                conv1 = ld.conv_layer(x_1, 3, 2, 8, "encode_1")
            # conv2
            conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
            # conv3
            conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
            # conv4
            conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
            # conv5
            conv5 = ld.transpose_conv_layer(conv4, 1, 1, 8, "decode_5")
            # conv6
            conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
            # conv7
            conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
            # x_1
            x_1 = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8",
                                          True)  # set activation to linear
            x_unwrap.append(x_1)
            # set reuse to true after first go
            if i == 0:
                tf.get_variable_scope().reuse_variables()

        # pack them all together
        x_unwrap = tf.pack(x_unwrap)
        x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])

        # this part will be used for generating video
        x_unwrap_gen = []
        for i in range(50):
            if i == 0:
                # conv1
                conv1 = ld.conv_layer(x[:, 0, :, :, :], 3, 2, 8, "encode_1")
            else:
                # conv1
                conv1 = ld.conv_layer(x_1_gen, 3, 2, 8, "encode_1")
            # conv2
            conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
            # conv3
            conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
            # conv4
            conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
            # conv5
            conv5 = ld.transpose_conv_layer(conv4, 1, 1, 8, "decode_5")
            # conv6
            conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
            # conv7
            conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
            # x_1
            x_1_gen = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8",
                                              True)  # set activation to linear
            x_unwrap_gen.append(x_1_gen)

        # pack them generated ones
        x_unwrap_gen = tf.pack(x_unwrap_gen)
        x_unwrap_gen = tf.transpose(x_unwrap_gen, [1, 0, 2, 3, 4])

        # calc total loss (compare x_t to x_t+1)
        loss = tf.nn.l2_loss(x[:, 1:, :, :, :] - x_unwrap)
        tf.scalar_summary('loss', loss)

        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

        # List of all Variables
        variables = tf.all_variables()

        # Build a saver
        saver = tf.train.Saver(tf.all_variables())

        # Summary op
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session()

        # init if this is the very first time training
        print("init network from scratch")
        sess.run(init)

        # Summary op
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph_def=graph_def)

        for step in range(FLAGS.max_step):
            dat = generate_bouncing_ball_sample(FLAGS.batch_size,
                                                FLAGS.seq_length, 32,
                                                FLAGS.num_balls)
            t = time.time()
            _, loss_r = sess.run([train_op, loss],
                                 feed_dict={
                                     x: dat,
                                     keep_prob: FLAGS.keep_prob
                                 })
            elapsed = time.time() - t

            if step % 100 == 0 and step != 0:
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           x: dat,
                                           keep_prob: FLAGS.keep_prob
                                       })
                summary_writer.add_summary(summary_str, step)
                print("time per batch is " + str(elapsed))
                print(step)
                print(loss_r)

            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

            if step % 1000 == 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)

                # make video
                print("now generating video!")
                video = cv2.VideoWriter()
                success = video.open("generated_conv_video.mov", fourcc, 4,
                                     (180, 180), True)
                dat_gif = dat
                ims = sess.run([x_unwrap_gen],
                               feed_dict={
                                   x: dat_gif,
                                   keep_prob: FLAGS.keep_prob
                               })
                ims = ims[0][0]
                print(ims.shape)
                for i in range(50):
                    x_1_r = np.uint8(np.maximum(ims[i, :, :, :], 0) * 255)
                    new_im = cv2.resize(x_1_r, (180, 180))
                    video.write(new_im)
                video.release()
Example #15
def train(train_data, validate_data, test_data):
    """Train ring_net for a number of steps."""
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [
            None, FLAGS.seq_length, FLAGS.input_height, FLAGS.input_width,
            FLAGS.input_channel
        ])

        # possible dropout inside
        #keep_prob = tf.placeholder("float")
        #x_dropout = tf.nn.dropout(x, keep_prob)

        # create network
        x_unwrap = []

        with tf.variable_scope('conv_lstm',
                               initializer=tf.random_uniform_initializer(
                                   -.01, 0.1)):
            # BasicConvLSTMCell: (shape, filter_size, num_features, state_is_tuple)
            cell_1 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3],
                                                         64,
                                                         state_is_tuple=True)
            # new_state: [batch_size, 16, 16, 64] for c and h
            new_state_1 = cell_1.zero_state(FLAGS.batch_size)

            # cell_2
            cell_2 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3],
                                                         64,
                                                         state_is_tuple=True)
            new_state_2 = cell_2.zero_state(FLAGS.batch_size)

            # cell_3
            cell_3 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3],
                                                         64,
                                                         state_is_tuple=True)
            #new_state_3 = cell_3.zero_state(FLAGS.batch_size)

            # cell_3
            cell_4 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3],
                                                         64,
                                                         state_is_tuple=True)
            #new_state_4 = cell_3.zero_state(FLAGS.batch_size)

        # conv network
        # conv_layer: (input, kernel_size, stride, num_features, scope_name_idx, linear, reuseL)
        # BasicConvLSTMCell: __call__(inputs, state, scope, reuseL)
        reuseL = None
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for i in xrange(FLAGS.seq_length - 1):
                # -------------------------- encode ------------------------
                # conv1
                # [16, 64, 64, 1] -> [16, 32, 32, 8]
                if i < FLAGS.seq_start:
                    conv1 = ld.conv_layer(x[:, i, :, :, :],
                                          3,
                                          2,
                                          8,
                                          "encode_1",
                                          reuseL=reuseL)
                else:
                    conv1 = ld.conv_layer(x[:, i, :, :, :],
                                          3,
                                          2,
                                          8,
                                          "encode_1",
                                          reuseL=reuseL)
                # conv2
                # [16, 32, 32, 8] -> [16, 16, 16, 16]
                print(conv1.get_shape().as_list())
                conv2 = ld.conv_layer(conv1,
                                      3,
                                      2,
                                      16,
                                      "encode_2",
                                      reuseL=reuseL)

                # convLSTM 3
                # [16, 16, 16, 16] -> [16, 16, 16, 64]
                print(conv2.get_shape().as_list())
                y_0 = conv2
                print(y_0.get_shape().as_list())
                y_1, new_state_1 = cell_1(y_0,
                                          new_state_1,
                                          "encode_3_convLSTM_1",
                                          reuseL=reuseL)
                # convLSTM 4
                # [16, 16, 16, 64] -> [16, 16, 16, 64]
                print(y_1.get_shape().as_list())
                y_2, new_state_2 = cell_2(y_1,
                                          new_state_2,
                                          "encode_4_convLSTM_2",
                                          reuseL=reuseL)
                print(y_2.get_shape().as_list())

                # ------------------------ decode -------------------------
                # convLSTM 5
                # [16, 16, 16, 64] -> [16, 16, 16, 64]
                # copy the initial states and cell outputs from convLSTM 4
                if i == 0:
                    new_state_3 = new_state_2
                y_3, new_state_3 = cell_3(y_2,
                                          new_state_3,
                                          "encode_5_convLSTM_3",
                                          reuseL=reuseL)
                # convLSTM 6
                # [16, 16, 16, 64] -> [16, 16, 16, 64]
                print(y_3.get_shape().as_list())
                if i == 0:
                    new_state_4 = new_state_1
                y_4, new_state_4 = cell_4(y_3,
                                          new_state_4,
                                          "encode_6_convLSTM_4",
                                          reuseL=reuseL)
                print(y_4.get_shape().as_list())
                # conv7
                # [16, 16, 16, 64] -> [16, 32, 32, 8]
                conv6 = y_4
                conv7 = ld.transpose_conv_layer(conv6,
                                                3,
                                                2,
                                                8,
                                                "decode_7",
                                                reuseL=reuseL)
                # x_1
                # [16, 32, 32, 8] -> [16, 64, 64, 1]
                print(conv7.get_shape().as_list())
                x_1 = ld.transpose_conv_layer(
                    conv7, 3, 2, 1, "decode_8", linear=True,
                    reuseL=reuseL)  # set activation to linear
                if i >= FLAGS.seq_start - 1:
                    x_unwrap.append(x_1)
                # set reuse to true after first go
                if i == 0:
                    #tf.get_variable_scope().reuse_variables()
                    reuseL = True

        # stack them all together
        x_unwrap = tf.stack(x_unwrap)
        x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])

        # calc total loss (compare x_t to x_t+1)
        loss = tf.nn.l2_loss(x[:, FLAGS.seq_start:, :, :, :] -
                             x_unwrap[:, :, :, :, :])
        tf.summary.scalar('loss', loss)

        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

        # List of all Variables
        #variables = tf.global_variables()

        # Build a saver
        saver = tf.train.Saver(tf.global_variables())

        # Summary op
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # --------------------- Start running operations on the Graph ------------------------
        sess = tf.Session()

        # init if this is the very first time training
        print("init network from scratch")
        sess.run(init)

        # Summary op
        #graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        # ------------------------------- training for train_data -------------------------------------
        print("now training step:")

        for step in xrange(FLAGS.max_step):
            dat = get_batch_data(train_data)
            t = time.time()
            _, loss_r = sess.run([train_op, loss], feed_dict={x: dat})
            elapsed = time.time() - t

            if step % 100 == 0 and step != 0:
                summary_str = sess.run(summary_op, feed_dict={x: dat})
                summary_writer.add_summary(summary_str, step)
                #print("time per batch is " + str(elapsed))
                print("at step " + str(step) + ", loss is " + str(loss_r))
                #print(loss_r)

            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

            if step % 1000 == 0 and step != 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)

                # -------------------------- validation for each 1000 steps ---------------------
                print("at step " + str(step) + ", validate...")
                v_whole_loss = 0
                # [all_size, seq_length, in_height, in_width, in_channel]
                all_batch_v = np.zeros(
                    (validate_data.shape[0] - FLAGS.seq_length,
                     FLAGS.seq_length, FLAGS.input_height, FLAGS.input_width,
                     FLAGS.input_channel),
                    dtype=np.float32)
                #all_batch_v = validate_data[v_i, v_i:v_i+FLAGS.seq_length, :, :, :]
                #while v_i <validate_data.shape[0]:
                for v_i in xrange(validate_data.shape[0] - FLAGS.seq_length):
                    all_batch_v[v_i] = validate_data[v_i:v_i +
                                                     FLAGS.seq_length, :, :, :]
                v_i = 0
                while v_i < all_batch_v.shape[0] - FLAGS.batch_size:
                    # [batch_size, seq_length, in_height, in_width, channel]
                    dat_v = all_batch_v[v_i:v_i + FLAGS.batch_size, :, :, :, :]
                    v_loss = sess.run(loss, feed_dict={x: dat_v})
                    v_whole_loss += v_loss
                    # advance to the next validation batch
                    v_i += FLAGS.batch_size
                print("validation loss: " + str(v_whole_loss))

        # ---------------------------- for test data -------------------------------------
        print("now testing step:")
        t_whole_loss = 0
        all_batch_t = np.zeros(
            (test_data.shape[0] - FLAGS.seq_length, FLAGS.seq_length,
             FLAGS.input_height, FLAGS.input_width, FLAGS.input_channel),
            dtype=np.float32)
        for t_i in xrange(test_data.shape[0] - FLAGS.seq_length):
            all_batch_t[t_i] = test_data[t_i:t_i + FLAGS.seq_length, :, :, :]
        t_i = 0
        while t_i < all_batch_t.shape[0] - FLAGS.batch_size:
            dat_t = all_batch_t[t_i:t_i + FLAGS.batch_size, :, :, :, :]
            predict, t_loss = sess.run([x_unwrap, loss], feed_dict={x: dat_t})
            t_whole_loss += t_loss
            # advance to the next test batch
            t_i += FLAGS.batch_size
        print("test loss: " + str(t_whole_loss))
        # save the predictions from the last test batch
        np.save('prediction.npy', predict)