示例#1
0
 def single_cell_fn(unit_type, num_units, dropout, mode, forget_bias=1.0):
     """Build a single RNN cell, optionally wrapped with input dropout.

     Args:
         unit_type: "lstm" or "gru"; any other value raises ValueError.
         num_units: Hidden size handed to the cell constructor.
         dropout: Input-dropout rate, used only when ``mode is True``.
         mode: Training flag. It is compared by identity, so only the
             ``True`` singleton enables dropout (e.g. ``1`` does not).
         forget_bias: Forget-gate bias for the LSTM cell; unused for GRU.

     Returns:
         The cell, wrapped in ``rnn_cell.DropoutWrapper`` with
         ``input_keep_prob = 1.0 - dropout`` when the effective rate is > 0.

     Raises:
         ValueError: if ``unit_type`` is not recognised.
     """
     if unit_type == "lstm":
         cell = rnn_cell.LSTMCell(num_units,
                                  forget_bias=forget_bias,
                                  state_is_tuple=False)
     elif unit_type == "gru":
         cell = rnn_cell.GRUCell(num_units)
     else:
         raise ValueError("Unknown unit type %s!" % unit_type)

     # Outside training the dropout rate is forced to zero.
     rate = dropout if mode is True else 0.0
     if not rate > 0.0:
         return cell
     return rnn_cell.DropoutWrapper(cell=cell, input_keep_prob=(1.0 - rate))
示例#2
0
def inference(x,
              y,
              n_batch,
              is_training,
              input_digits=None,
              output_digits=None,
              n_hidden=None,
              n_out=None):
    """Build a bidirectional-GRU encoder / attention-GRU decoder graph.

    The encoder is a static bidirectional GRU over ``input_digits`` steps.
    The decoder is one GRU wrapped in Bahdanau attention over the encoder
    outputs, unrolled for ``output_digits - 1`` steps.

    Args:
        x: Encoder input. It is normalised, transposed with ``[1, 0, 2]``,
            reshaped to ``[-1, n_in]`` and split into ``input_digits`` pieces.
            Presumably ``[n_batch, input_digits, n_in]`` -- TODO confirm.
        y: Decoder targets. They are normalised and sliced as
            ``y[:, t - 1, :]`` for teacher forcing; only read when
            ``is_training is True``.
        n_batch: Static batch size used in reshapes and ``zero_state``.
        is_training: Exactly ``True`` builds the scheduled-sampling training
            graph; any other value builds the free-running graph.
        input_digits: Number of encoder time steps.
        output_digits: Number of decoder output positions.
        n_hidden: GRU hidden size.
        n_out: Output projection size.

    Returns:
        Training: ``[n_batch, output_digits, n_out]``, the decoder output
        history projected through ``V_out`` / ``c_out``.
        Inference: ``[-1, output_digits, n_out]``, the per-step projections.

    NOTE(review): ``n_in``, ``num_units``, ``attention_layer_size`` and
    ``tchr_frcng_thr`` are not parameters. They must be module-level globals
    defined elsewhere in the file.
    """
    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.01)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.zeros(shape, dtype=tf.float32)
        return tf.Variable(initial)

    def batch_normalization(shape, x):
        # Parameter-free standardisation over axes 0 and 1. ``shape`` is
        # unused; it was meant for learnable beta/gamma variables.
        with tf.name_scope('batch_normalization'):
            eps = 1e-8
            mean, var = tf.nn.moments(x, [0, 1])
            return (x - mean) / tf.sqrt(var + eps)

    encoder_forward = rnn_cell.GRUCell(n_hidden, reuse=tf.AUTO_REUSE)
    encoder_backward = rnn_cell.GRUCell(n_hidden, reuse=tf.AUTO_REUSE)

    # Time-major list of ``input_digits`` tensors, as static_bidirectional_rnn
    # expects.
    x = tf.transpose(batch_normalization(input_digits, x), [1, 0, 2])
    x = tf.reshape(x, [-1, n_in])
    x = tf.split(x, input_digits, 0)

    # Encode. ``encoder_outputs`` is a list with one entry per time step:
    # [batch, cell_fw.output_size + cell_bw.output_size].
    encoder_outputs, encoder_states_fw, encoder_states_bw = \
        tf.nn.static_bidirectional_rnn(
            encoder_forward, encoder_backward, x, dtype=tf.float32)

    # Attention memory must be batch-major [batch, time, depth].
    # Stacking along axis 1 keeps each example's time steps together. A plain
    # reshape of the time-major list would interleave different examples.
    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=tf.stack(encoder_outputs, axis=1))

    decoder_1 = seq2seq.AttentionWrapper(
        rnn_cell.GRUCell(n_hidden, reuse=tf.AUTO_REUSE),
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        output_attention=False)

    # The decoder starts from the forward encoder's final state.
    state_1 = decoder_1.zero_state(n_batch, tf.float32)\
        .clone(cell_state=encoder_states_fw)

    # Seed the time-major decoder output history, [1, n_batch, n_hidden],
    # with the forward half of one encoder step.
    # NOTE(review): this uses step ``input_digits - 2`` rather than the last
    # step ``input_digits - 1`` -- confirm this is intended.
    decoder_1_outputs = tf.slice(encoder_outputs, [input_digits - 2, 0, 0],
                                 [1, n_batch, n_hidden])

    # Output-layer weights and biases are defined once, before the loop.
    V_hid_1 = weight_variable([n_hidden, n_out])
    c_hid_1 = bias_variable([n_out])

    # V_hid_2 / c_hid_2 are unused. They are kept so the set and order of
    # created variables (and therefore checkpoints) stay unchanged.
    V_hid_2 = weight_variable([n_hidden, n_out])
    c_hid_2 = bias_variable([n_out])

    V_out = weight_variable([n_hidden, n_out])
    c_out = bias_variable([n_out])

    fc_outputs = []

    # Teacher-forcing coin: 1 with probability tchr_frcng_thr, else 0.
    # A single sample is shared by every decoder step of one run.
    elems = tf.convert_to_tensor([1, 0])
    samples = tf.multinomial(tf.log([[tchr_frcng_thr, 1 - tchr_frcng_thr]]),
                             1)  # note log-prob

    if is_training is True:
        teacher_forced = tf.equal(elems[tf.cast(samples[0][0], tf.int32)], 1)
        # Loop-invariant: normalise the targets once, not once per step.
        y_norm = batch_normalization(output_digits, y)

    with tf.variable_scope('Decoder'):
        for t in range(1, output_digits):
            if t > 1:
                tf.get_variable_scope().reuse_variables()

            if is_training is True:
                # Fallback input when not teacher forcing. At the first step
                # it is the projection of the encoder seed; afterwards it is
                # the previous raw decoder output.
                # ``t`` is a Python int, so this is decided at graph-build
                # time. Previously ``tf.const`` / ``tf.case`` referenced
                # ``output_1`` before its first assignment.
                # NOTE(review): output_1 has depth n_hidden while the other
                # inputs have depth n_out / y's depth -- confirm they match.
                if t == 1:
                    fed_back = tf.matmul(decoder_1_outputs[-1],
                                         V_hid_1) + c_hid_1
                else:
                    fed_back = output_1
                # tf.cond calls both branch functions immediately, so the
                # closures see the current ``t`` and ``fed_back``.
                cell_input = tf.cond(teacher_forced,
                                     lambda: y_norm[:, t - 1, :],
                                     lambda: fed_back)
                (output_1, state_1) = decoder_1(cell_input, state_1)
            else:
                # Free-running: feed back the projection of the previous output.
                out_1 = tf.matmul(decoder_1_outputs[-1], V_hid_1) + c_hid_1
                fc_outputs.append(out_1)
                (output_1, state_1) = decoder_1(out_1, state_1)

            decoder_1_outputs = tf.concat(
                [decoder_1_outputs,
                 tf.reshape(output_1, [1, n_batch, n_hidden])],
                axis=0)

    if is_training is True:
        # decoder_1_outputs is time-major [output_digits, n_batch, n_hidden].
        # Transpose (not just reshape) to batch-major so each row stays one
        # example.
        output = tf.reshape(tf.transpose(decoder_1_outputs, [1, 0, 2]),
                            [-1, output_digits, n_hidden])
        with tf.name_scope('check'):
            linear = tf.einsum('ijk,kl->ijl', output, V_out) + c_out
            return linear
    else:
        # Projection of the last decoder output.
        fc_out = tf.matmul(decoder_1_outputs[-1], V_out) + c_out
        fc_outputs.append(fc_out)

        output = tf.reshape(tf.concat(fc_outputs, axis=1),
                            [-1, output_digits, n_out])
        return output
示例#3
0
    def __init__(self, pd, generator=None):
        """Build the graph, loss and training op for the forecasting model.

        The input window is split in time into an encoder part and a
        prediction part.
        - With ``pd['fft']`` the RNN runs on windowed STFT coefficients, and
          its output is mapped back to time with an inverse STFT.
        - With ``pd['linear_reshape']`` the samples are grouped into chunks
          of ``pd['step_size']``.

        Args:
            pd: Parameter dictionary. Keys read: 'power_handler',
                'batch_size', 'input_samples', 'pred_samples', 'fft',
                'window_function', 'window_size', 'overlap',
                'fft_compression_rate', 'linear_reshape', 'step_size',
                'downsampling', 'cell_type', 'stiefel', 'num_units',
                'num_proj', 'use_residuals', 'sample_prob', 'freq_loss',
                'epsilon', 'init_learning_rate', 'decay_steps',
                'decay_rate'.
            generator: Optional callable returning already-normalised data
                (synthetic experiment). Without it, a float32 placeholder
                ``[batch_size, input_samples, 1]`` is created as
                ``self.data_nd``. It is standardised with the power handler's
                mean and std.

        Raises:
            ValueError: if ``pd['cell_type']`` is neither 'cgRNN' nor 'gru'.
            NotImplementedError: if the STFT input has more than one channel.
        """
        self.pd = pd
        self.graph = tf.Graph()
        with self.graph.as_default():
            global_step = tf.Variable(0, name='global_step', trainable=False)
            if generator:
                print('Running synthetic experiment')
                data_nd_norm = generator()
                data_nd = data_nd_norm
            else:
                data_mean = tf.constant(pd['power_handler'].mean,
                                        tf.float32,
                                        name='data_mean')
                data_std = tf.constant(pd['power_handler'].std,
                                       tf.float32,
                                       name='data_std')

                data_nd = tf.placeholder(
                    tf.float32, [pd['batch_size'], pd['input_samples'], 1])
                self.data_nd = data_nd
                data_nd_norm = (data_nd - data_mean) / data_std

            print('data_nd_shape', data_nd_norm.shape)
            dtype = tf.float32
            # The first input_samples - pred_samples steps feed the encoder.
            # The last pred_samples steps are the prediction targets.
            data_encoder_time, data_decoder_time = tf.split(
                data_nd_norm,
                [pd['input_samples'] - pd['pred_samples'], pd['pred_samples']],
                axis=1)

            if pd['fft']:
                dtype = tf.complex64
                if pd['window_function'] == 'learned_gaussian':
                    window = wl.gaussian_window(pd['window_size'])
                elif pd['window_function'] == 'learned_plank':
                    window = wl.plank_taper(pd['window_size'])
                elif pd['window_function'] == 'learned_tukey':
                    window = wl.tukey_window(pd['window_size'])
                elif pd['window_function'] == 'learned_gauss_plank':
                    window = wl.gauss_plank_window(pd['window_size'])
                else:
                    window = scisig.get_window(window=pd['window_function'],
                                               Nx=pd['window_size'])
                    window = tf.constant(window, tf.float32)

                def transpose_stft_squeeze(in_data, window, pd):
                    '''
                    Compute a windowed stft and do low pass filtering if
                    necessary.
                    '''
                    tmp_in_data = tf.transpose(in_data, [0, 2, 1])
                    in_data_fft = eagerSTFT.stft(tmp_in_data, window,
                                                 pd['window_size'],
                                                 pd['overlap'])
                    freqs = int(in_data_fft.shape[-1])
                    idft_shape = in_data_fft.shape.as_list()
                    if idft_shape[1] == 1:
                        # in the one dimensional case squeeze the dim away.
                        in_data_fft = tf.squeeze(in_data_fft, axis=1)
                        if pd['fft_compression_rate']:
                            compressed_freqs = int(freqs /
                                                   pd['fft_compression_rate'])
                            print('fft_compression_rate',
                                  pd['fft_compression_rate'], 'freqs', freqs,
                                  'compressed_freqs', compressed_freqs)
                            # remove frequencies from the last dimension.
                            return in_data_fft[
                                ..., :compressed_freqs], idft_shape, freqs
                        else:
                            return in_data_fft, idft_shape, freqs
                    else:
                        # arrange as batch time freq dim
                        in_data_fft = tf.transpose(in_data_fft, [0, 2, 3, 1])
                        raise NotImplementedError

                data_encoder_freq, _, enc_freqs = \
                    transpose_stft_squeeze(data_encoder_time, window, pd)
                data_decoder_freq, dec_shape, dec_freqs = \
                    transpose_stft_squeeze(data_decoder_time, window, pd)
                assert enc_freqs == dec_freqs, 'encoder-decoder frequencies must agree'
                fft_pred_samples = data_decoder_freq.shape[1].value

            elif pd['linear_reshape']:
                encoder_time_steps = data_encoder_time.shape[1].value // pd[
                    'step_size']
                data_encoder_time = tf.reshape(
                    data_encoder_time,
                    [pd['batch_size'], encoder_time_steps, pd['step_size']])
                if pd['downsampling'] > 1:
                    data_encoder_time = data_encoder_time[:, :, ::
                                                          pd['downsampling']]
                decoder_time_steps = data_decoder_time.shape[1].value // pd[
                    'step_size']

            if pd['cell_type'] == 'cgRNN':
                # mod_relu is used with Stiefel-constrained weights; hirose
                # is used otherwise.
                activation = ccell.mod_relu if pd['stiefel'] else ccell.hirose
                cell = ccell.StiefelGatedRecurrentUnit(
                    pd['num_units'],
                    num_proj=pd['num_proj'],
                    complex_input=pd['fft'],
                    complex_output=pd['fft'],
                    activation=activation,
                    stiefel=pd['stiefel'])
                cell = RnnInputWrapper(1.0, cell)
                if pd['use_residuals']:
                    cell = ResidualWrapper(cell=cell)
            elif pd['cell_type'] == 'gru':
                gru = rnn_cell.GRUCell(pd['num_units'])
                # Truthiness, consistent with every other pd['fft'] check.
                # The complex re-assembly below runs for any truthy value, so
                # the real/imag split here must too.
                if pd['fft']:
                    dtype = tf.float32
                    # concatenate real and imaginary parts.
                    data_encoder_freq = tf.concat([
                        tf.real(data_encoder_freq),
                        tf.imag(data_encoder_freq)
                    ],
                                                  axis=-1)
                    cell = LinearProjWrapper(pd['num_proj'] * 2,
                                             cell=gru,
                                             sample_prob=pd['sample_prob'])
                else:
                    cell = LinearProjWrapper(pd['num_proj'],
                                             cell=gru,
                                             sample_prob=pd['sample_prob'])
                cell = RnnInputWrapper(1.0, cell)
                if pd['use_residuals']:
                    cell = ResidualWrapper(cell=cell)
            else:
                # Previously this only printed and then failed later with an
                # UnboundLocalError on ``cell``. Fail fast with a clear error.
                raise ValueError('cell type not supported: %r'
                                 % (pd['cell_type'],))

            if pd['fft']:
                encoder_in = data_encoder_freq
            else:
                encoder_in = data_encoder_time

            with tf.variable_scope("encoder_decoder") as scope:

                # The first encoder input replaces the first element of the
                # zero state.
                zero_state = cell.zero_state(pd['batch_size'], dtype=dtype)
                zero_state = LSTMStateTuple(encoder_in[:, 0, :], zero_state[1])
                encoder_out, encoder_state = tf.nn.dynamic_rnn(
                    cell, encoder_in, initial_state=zero_state, dtype=dtype)
                if not pd['fft']:
                    if pd['linear_reshape']:
                        decoder_in = tf.zeros(
                            [pd['batch_size'], decoder_time_steps, 1])
                    else:
                        decoder_in = tf.zeros(
                            [pd['batch_size'], pd['pred_samples'], 1])
                    encoder_state = LSTMStateTuple(data_encoder_time[:, -1, :],
                                                   encoder_state[-1])
                else:
                    freqs = encoder_in.shape[-1].value
                    decoder_in = tf.zeros(
                        [pd['batch_size'], fft_pred_samples, freqs],
                        dtype=dtype)

                    encoder_state = LSTMStateTuple(data_encoder_freq[:, -1, :],
                                                   encoder_state[-1])
                cell.close()
                # The decoder run shares the encoder's cell variables.
                scope.reuse_variables()
                decoder_out, _ = tf.nn.dynamic_rnn(cell,
                                                   decoder_in,
                                                   initial_state=encoder_state,
                                                   dtype=dtype)

                if pd['fft'] and pd['cell_type'] == 'gru':
                    # assemble complex output.
                    decoder_freqs_t2 = decoder_out.shape[-1].value
                    decoder_out = tf.complex(
                        decoder_out[:, :, :int(decoder_freqs_t2 / 2)],
                        decoder_out[:, :, int(decoder_freqs_t2 / 2):])
                    encoder_out = tf.complex(
                        encoder_out[:, :, :int(decoder_freqs_t2 / 2)],
                        encoder_out[:, :, int(decoder_freqs_t2 / 2):])
                    encoder_in = tf.complex(
                        encoder_in[:, :, :int(decoder_freqs_t2 / 2)],
                        encoder_in[:, :, int(decoder_freqs_t2 / 2):])

            if pd['fft']:
                # NOTE(review): prd_loss is only computed for the complex_*
                # losses. 'ad_time', 'log_mse_time', 'mse_time',
                # 'log_mse_mse_time' and any other value reaching
                # ``loss = prd_loss`` below will raise NameError.
                if pd['freq_loss'] in ('complex_abs', 'complex_abs_time'):
                    diff = data_decoder_freq - decoder_out
                    prd_loss = tf.abs(tf.real(diff)) + tf.abs(tf.imag(diff))
                    prd_loss = tf.reduce_mean(prd_loss)
                    tf.summary.scalar('f_complex_abs', prd_loss)
                if pd['freq_loss'] in ('complex_square', 'complex_square_time'):
                    diff = data_decoder_freq - decoder_out
                    prd_loss = tf.real(diff) * tf.real(diff) + tf.imag(
                        diff) * tf.imag(diff)
                    prd_loss = tf.reduce_mean(prd_loss)
                    tf.summary.scalar('f_complex_square', prd_loss)

                def expand_dims_and_transpose(input_tensor, pd, freqs):
                    # Add the channel axis back. If frequencies were
                    # compressed away, zero-pad them to ``freqs`` again.
                    # (Despite the name, no transpose is performed.)
                    output = tf.expand_dims(input_tensor, 1)
                    if pd['fft_compression_rate']:
                        zero_coeffs = freqs - int(input_tensor.shape[-1])
                        zero_stack = tf.zeros(
                            output.shape[:-1].as_list() + [zero_coeffs],
                            tf.complex64)
                        output = tf.concat([output, zero_stack], -1)
                    return output

                decoder_out = expand_dims_and_transpose(
                    decoder_out, pd, dec_freqs)
                decoder_out = eagerSTFT.istft(decoder_out,
                                              window,
                                              nperseg=pd['window_size'],
                                              noverlap=pd['overlap'],
                                              epsilon=pd['epsilon'])
                decoder_out = tf.transpose(decoder_out, [0, 2, 1])
            elif pd['linear_reshape']:
                if pd['downsampling'] > 1:
                    decoder_out_t = tf.transpose(decoder_out, [0, 2, 1])
                    decoder_out_t = eagerSTFT.interpolate(
                        decoder_out_t, pd['step_size'])
                    decoder_out = tf.transpose(decoder_out_t, [0, 2, 1])
                decoder_out = tf.reshape(
                    decoder_out, [pd['batch_size'], pd['pred_samples'], 1])

            time_loss = tf.losses.mean_squared_error(
                tf.real(data_decoder_time),
                tf.real(decoder_out[:, :pd['pred_samples'], :]))
            if not pd['fft']:
                loss = time_loss
            else:
                if pd['freq_loss'] in ('ad_time', 'log_mse_time', 'mse_time',
                                       'log_mse_mse_time',
                                       'complex_square_time',
                                       'complex_abs_time'):
                    print('using freq and time based loss.')
                    lambda_t = 1
                    loss = prd_loss * lambda_t + time_loss
                    tf.summary.scalar('lambda_t', lambda_t)
                elif (pd['freq_loss'] is None):
                    print('time loss only')
                    loss = time_loss
                else:
                    loss = prd_loss

            learning_rate = tf.train.exponential_decay(
                pd['init_learning_rate'],
                global_step,
                pd['decay_steps'],
                pd['decay_rate'],
                staircase=True)
            tf.summary.scalar('learning_rate', learning_rate)

            if pd['cell_type'] in ('orthogonal', 'cgRNN') \
               and (pd['stiefel'] is True):
                optimizer = co.RMSpropNatGrad(learning_rate,
                                              global_step=global_step)
            else:
                optimizer = tf.train.RMSPropOptimizer(learning_rate)
            gvs = optimizer.compute_gradients(loss)

            # Element-wise gradient clipping to [-1, 1].
            with tf.variable_scope("clip_grads"):
                capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
                              for grad, var in gvs]

            self.training_op = optimizer.apply_gradients(
                capped_gvs, global_step=global_step)
            tf.summary.scalar('time_loss', time_loss)
            tf.summary.scalar('training_loss', loss)

            self.init_op = tf.global_variables_initializer()
            self.summary_sum = tf.summary.merge_all()
            self.total_parameters = compute_parameter_total(
                tf.trainable_variables())
            self.saver = tf.train.Saver()
            self.loss = loss
            self.global_step = global_step
            self.decoder_out = decoder_out
            self.data_nd = data_nd
            self.data_encoder_time = data_encoder_time
            self.data_decoder_time = data_decoder_time
            if pd['fft']:
                self.window = window
    def inference(x,
                  y,
                  n_batch,
                  is_training,
                  input_digits=None,
                  output_digits=None,
                  n_hidden=None,
                  n_out=None):
        """Build a GRU encoder / attention-GRU decoder graph.

        The encoder is a GRU unrolled over ``input_digits`` steps. The
        decoder is a GRU wrapped in Bahdanau attention (100 units, attention
        layer size 50) over the encoder outputs, unrolled for
        ``output_digits - 1`` steps.

        Args:
            x: Encoder input, normalised and read as ``x[:, t, :]``.
                Presumably ``[n_batch, input_digits, features]`` -- TODO
                confirm.
            y: Decoder targets, normalised and read as ``y[:, t - 1, :]``
                (teacher forcing). Only used when ``is_training is True``.
            n_batch: Static batch size used in reshapes and ``zero_state``.
            is_training: Exactly ``True`` builds the teacher-forced graph;
                anything else builds the free-running graph.
            input_digits: Number of encoder time steps.
            output_digits: Number of decoder output positions.
            n_hidden: GRU hidden size.
            n_out: Output projection size.

        Returns:
            Training: ``[n_batch, output_digits, n_out]``.
            Inference: ``[-1, output_digits, n_out]``.
        """
        def weight_variable(shape):
            initial = tf.truncated_normal(shape, stddev=0.01)
            return tf.Variable(initial)

        def bias_variable(shape):
            initial = tf.zeros(shape, dtype=tf.float32)
            return tf.Variable(initial)

        def batch_normalization(shape, x):
            # Parameter-free standardisation over axes 0 and 1. ``shape`` is
            # unused; it was meant for learnable beta/gamma variables.
            with tf.name_scope('batch_normalization'):
                eps = 1e-8
                mean, var = tf.nn.moments(x, [0, 1])
                return (x - mean) / tf.sqrt(var + eps)

        encoder = rnn_cell.GRUCell(n_hidden)
        encoder_outputs = []
        encoder_states = []

        state = encoder.zero_state(n_batch, tf.float32)

        # Loop-invariant: normalise the encoder input once, not per step.
        x_norm = batch_normalization(input_digits, x)

        # Encode step by step, collecting [n_batch, n_hidden] outputs/states.
        with tf.variable_scope('Encoder'):
            for t in range(input_digits):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()
                (output, state) = encoder(x_norm[:, t, :], state)
                encoder_outputs.append(output)
                encoder_states.append(state)

        # Attention memory must be batch-major [batch, time, depth].
        # Stacking along axis 1 keeps each example's time steps together. A
        # plain reshape of the time-major list would interleave examples.
        attention_mechanism = seq2seq.BahdanauAttention(
            num_units=100,
            memory=tf.stack(encoder_outputs, axis=1))

        decoder = seq2seq.AttentionWrapper(
            rnn_cell.GRUCell(n_hidden),
            attention_mechanism=attention_mechanism,
            attention_layer_size=50,
            output_attention=False)

        # The decoder starts from the encoder's final state.
        state = decoder.zero_state(n_batch, tf.float32)\
            .clone(cell_state=tf.reshape(encoder_states[-1], [n_batch, n_hidden]))

        # Time-major output history, seeded with the last encoder output.
        decoder_outputs = [encoder_outputs[-1]]
        # Output-layer weights and biases are defined once, before the loop.
        V = weight_variable([n_hidden, n_out])
        c = bias_variable([n_out])
        outputs = []

        if is_training is True:
            # Loop-invariant: normalise the targets once, not per step.
            y_norm = batch_normalization(output_digits, y)

        with tf.variable_scope('Decoder'):
            for t in range(1, output_digits):
                if t > 1:
                    tf.get_variable_scope().reuse_variables()

                if is_training is True:
                    # Teacher forcing: feed the previous ground-truth step.
                    (output, state) = decoder(y_norm[:, t - 1, :], state)
                else:
                    # Free-running: feed back the projected previous output.
                    out = tf.matmul(decoder_outputs[-1], V) + c
                    outputs.append(out)
                    (output, state) = decoder(out, state)

                decoder_outputs = tf.concat(
                    [decoder_outputs,
                     tf.reshape(output, [1, n_batch, n_hidden])],
                    axis=0)

        if is_training is True:
            # decoder_outputs is time-major [output_digits, n_batch, n_hidden].
            # Transpose (not just reshape) to batch-major so each row stays
            # one example.
            output = tf.reshape(tf.transpose(decoder_outputs, [1, 0, 2]),
                                [-1, output_digits, n_hidden])
            with tf.name_scope('check'):
                linear = tf.einsum('ijk,kl->ijl', output, V) + c
                return linear
        else:
            # Projection of the last decoder output.
            linear = tf.matmul(decoder_outputs[-1], V) + c
            outputs.append(linear)

            output = tf.reshape(tf.concat(outputs, axis=1),
                                [-1, output_digits, n_out])
            return output
    def __init__(self, args, infer=False):
        """
        Initialisation function for the class Model.
        Params:
        args: Contains arguments required for the Model creation
        """

        # If sampling new trajectories, then infer mode
        if infer:
            # Infer one position at a time
            args.batch_size = 1
            args.obs_length = 1
            args.pred_length = 1

        # Store the arguments
        self.args = args

        # placeholders for the input data and the target data
        # A sequence contains an ordered set of consecutive frames
        # Each frame can contain a maximum of 'args.maxNumPeds' number of peds
        # For each ped we have their (pedID, x, y) positions as input
        self.input_data = tf.placeholder(tf.float32,
                                         [args.obs_length, args.maxNumPeds, 3],
                                         name="input_data")
        # target data would be the same format as input_data except with one time-step ahead
        self.target_data = tf.placeholder(
            tf.float32, [args.obs_length, args.maxNumPeds, 3],
            name="target_data")
        # Learning rate
        self.lr = tf.placeholder(tf.float32, shape=None, name="learning_rate")
        self.final_lr = tf.placeholder(tf.float32,
                                       shape=None,
                                       name="final_learning_rate")
        self.training_epoch = tf.placeholder(tf.float32,
                                             shape=None,
                                             name="training_epoch")
        # keep prob
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        cells = []
        for _ in range(args.num_layers):
            # Initialize a BasicLSTMCell recurrent unit
            # args.rnn_size contains the dimension of the hidden state of the LSTM
            # cell = rnn_cell.BasicLSTMCell(args.rnn_size, name='basic_lstm_cell', state_is_tuple=False)

            # Construct the basicLSTMCell recurrent unit with a dimension given by args.rnn_size
            if args.model == "lstm":
                with tf.name_scope("LSTM_cell"):
                    cell = rnn_cell.LSTMCell(args.rnn_size,
                                             state_is_tuple=False)

            elif args.model == "gru":
                with tf.name_scope("GRU_cell"):
                    cell = rnn_cell.GRUCell(args.rnn_size,
                                            state_is_tuple=False)

            if not infer and args.keep_prob < 1:
                cell = rnn_cell.DropoutWrapper(cell,
                                               output_keep_prob=self.keep_prob)

            cells.append(cell)

        # Multi-layer RNN construction, if more than one layer
        # cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=False)
        cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=False)

        # Store the recurrent unit
        self.cell = cell

        # Output size is the set of parameters (mu, sigma, corr)
        self.output_size = 5  # 2 mu, 2 sigma and 1 corr

        with tf.name_scope("learning_rate"):
            self.final_lr = self.lr * (self.args.decay_rate**
                                       self.training_epoch)

        self.define_embedding_and_output_layers(args)

        # Define LSTM states for each pedestrian
        with tf.variable_scope("LSTM_states"):
            self.LSTM_states = tf.zeros(
                [args.maxNumPeds, self.cell.state_size], name="LSTM_states")
            self.initial_states = tf.split(self.LSTM_states, args.maxNumPeds,
                                           0)
            # https://stackoverflow.com/a/41384913/2049763

        # Define hidden output states for each pedestrian
        with tf.variable_scope("Hidden_states"):
            self.output_states = tf.split(
                tf.zeros([args.maxNumPeds, self.cell.output_size]),
                args.maxNumPeds, 0)

        # List of tensors each of shape args.maxNumPeds x 3 corresponding to each frame in the sequence
        with tf.name_scope("frame_data_tensors"):
            frame_data = [
                tf.squeeze(input_, [0])
                for input_ in tf.split(self.input_data, args.obs_length, 0)
            ]

        with tf.name_scope("frame_target_data_tensors"):
            frame_target_data = [
                tf.squeeze(target_, [0])
                for target_ in tf.split(self.target_data, args.obs_length, 0)
            ]

        # Cost
        with tf.name_scope("Cost_related_stuff"):
            self.cost = tf.constant(0.0, name="cost")
            self.counter = tf.constant(0.0, name="counter")
            self.increment = tf.constant(1.0, name="increment")

        # Containers to store output distribution parameters
        with tf.name_scope("Distribution_parameters_stuff"):
            self.initial_output = tf.split(
                tf.zeros([args.maxNumPeds, self.output_size]), args.maxNumPeds,
                0)

        # Tensor to represent non-existent ped
        with tf.name_scope("Non_existent_ped_stuff"):
            nonexistent_ped = tf.constant(0.0, name="zero_ped")

        self.final_result = []
        # Iterate over each frame in the sequence
        for seq, frame in enumerate(frame_data):
            # print("Frame number", seq)
            final_result_ped = []
            current_frame_data = frame  # MNP x 3 tensor
            for ped in range(args.maxNumPeds):
                # pedID of the current pedestrian
                pedID = current_frame_data[ped, 0]
                # print("Pedestrian Number", ped)

                with tf.name_scope("extract_input_ped"):
                    # Extract x and y positions of the current ped
                    self.spatial_input = tf.slice(
                        current_frame_data, [ped, 1],
                        [1, 2])  # Tensor of shape (1,2)

                with tf.name_scope("embeddings_operations"):
                    # Embed the spatial input
                    embedded_spatial_input = tf.nn.relu(
                        tf.nn.xw_plus_b(self.spatial_input, self.embedding_w,
                                        self.embedding_b))

                # One step of LSTM
                with tf.variable_scope("LSTM") as scope:
                    if seq > 0 or ped > 0:
                        scope.reuse_variables()
                    self.output_states[ped], self.initial_states[
                        ped] = self.cell(embedded_spatial_input,
                                         self.initial_states[ped])

                # Apply the linear layer. Output would be a tensor of shape 1 x output_size
                with tf.name_scope("output_linear_layer"):
                    self.initial_output[ped] = tf.nn.xw_plus_b(
                        self.output_states[ped], self.output_w, self.output_b)

                with tf.name_scope("extract_target_ped"):
                    # Extract x and y coordinates of the target data
                    # x_data and y_data would be tensors of shape 1 x 1
                    [x_data, y_data] = tf.split(
                        tf.slice(frame_target_data[seq], [ped, 1], [1, 2]), 2,
                        1)
                    target_pedID = frame_target_data[seq][ped, 0]

                with tf.name_scope("get_coef"):
                    # Extract coef from output of the linear output layer
                    [o_mux, o_muy, o_sx, o_sy,
                     o_corr] = self.get_coef(self.initial_output[ped])
                    final_result_ped.append([o_mux, o_muy, o_sx, o_sy, o_corr])

                # Calculate loss for the current ped
                with tf.name_scope("calculate_loss"):
                    lossfunc = self.get_lossfunc(o_mux, o_muy, o_sx, o_sy,
                                                 o_corr, x_data, y_data)

                # If it is a non-existent ped, it should not contribute to cost
                # If the ped doesn't exist in the next frame, he/she should not contribute to cost as well
                with tf.name_scope("increment_cost"):
                    self.cost = tf.where(
                        tf.logical_or(tf.equal(pedID, nonexistent_ped),
                                      tf.equal(target_pedID, nonexistent_ped)),
                        self.cost, tf.add(self.cost, lossfunc))

                    self.counter = tf.where(
                        tf.logical_or(tf.equal(pedID, nonexistent_ped),
                                      tf.equal(target_pedID, nonexistent_ped)),
                        self.counter, tf.add(self.counter, self.increment))

            self.final_result.append(tf.stack(final_result_ped))
        # Compute the cost
        with tf.name_scope("mean_cost"):
            # Mean of the cost
            self.cost = tf.div(self.cost, self.counter)

        # Get trainable_variables
        tvars = tf.trainable_variables()

        # L2 loss
        l2 = args.lambda_param * sum(tf.nn.l2_loss(tvar) for tvar in tvars)
        self.cost = self.cost + l2

        # Get the final LSTM states
        self.final_states = tf.concat(self.initial_states, 0)
        # Get the final distribution parameters
        self.final_output = self.initial_output

        # initialize the optimizer with the given learning rate
        if args.optimizer == "RMSprop":
            optimizer = tf.train.RMSPropOptimizer(learning_rate=self.final_lr,
                                                  momentum=0.9)
        elif args.optimizer == "AdamOpt":
            # NOTE: Using RMSprop as suggested by Social LSTM instead of Adam as Graves(2013) does
            optimizer = tf.train.AdamOptimizer(self.final_lr)

        # How to apply gradient clipping in TensorFlow? https://stackoverflow.com/a/43486487/2049763
        #         # https://stackoverflow.com/a/40540396/2049763
        # TODO: (resolve) We are clipping the gradients as is usually done in LSTM
        # implementations. Social LSTM paper doesn't mention about this at all
        # Calculate gradients of the cost w.r.t all the trainable variables
        self.gradients = tf.gradients(self.cost, tvars)
        # self.gradients = optimizer.compute_gradients(self.cost, var_list=tvars)
        # Clip the gradients if they are larger than the value given in args
        self.clipped_gradients, _ = tf.clip_by_global_norm(
            self.gradients, args.grad_clip)

        # Train operator
        self.train_op = optimizer.apply_gradients(
            zip(self.clipped_gradients, tvars))

        self.grad_placeholders = []
        for var in tvars:
            self.grad_placeholders.append(tf.placeholder(var.dtype, var.shape))
        # Train operator
        self.train_op_2 = optimizer.apply_gradients(
            zip(self.grad_placeholders, tvars))