Example #1
0
    def build_model(self):
        """Build the autoencoder graph with k-means and real/fake discriminator losses.

        Two ``drnn.drnn_layer_final`` encoders run over the (optionally noised)
        input: one forward under variable scope ``fw`` and one over the
        time-reversed input under scope ``bw``. Their last-step states are
        concatenated into ``encoder_state``, which initialises a GRU decoder
        that reconstructs the clean input. The total loss is::

            loss_reconstruct + lamda / 2 * loss_k_means + discriminative_loss
            + 1e-4 * L2(all trainable variables)

        Returns:
            input_tensors: dict of placeholders to feed, with keys 'inputs',
                'noise', 'F_new_value' and 'real_fake_label'.
            loss_tensors: dict with 'loss_reconstruct', 'loss_k_means',
                'regularization_loss', 'discriminative_loss' and 'loss'.
            real_hidden_abstract: encoder representation of the first half
                of the batch.
            F_update: op assigning ``F_new_value`` to the non-trainable
                variable ``F`` used by the k-means term.
            output_tensor: dict with 'prediction', the [N, 2] logits of the
                real/fake discriminator.

        Raises:
            ValueError: if ``self.cell_type`` is not 'GRU' (LSTM is explicitly
                unsupported; any other value has no implementation).
        """
        inputs_ph = tf.placeholder(tf.float32, [None, self.num_steps], name='inputs')
        noise_ph = tf.placeholder(tf.float32, [None, self.num_steps], name='noise')

        real_fake_label = tf.placeholder(tf.float32, [None, 2], name='real_fake_label')

        F_new_value = tf.placeholder(tf.float32, [None, self.K], name='F_new_value')
        F = tf.get_variable('F', shape=[self.batch_size, self.K],
                            initializer=tf.orthogonal_initializer(gain=1.0, seed=None, dtype=tf.float32),
                            trainable=False)

        # NOTE(review): the placeholders are 2-D [batch, num_steps]; this reshape
        # keeps the batch dimension intact only when embedding_size == 1 -- confirm.
        inputs = tf.reshape(inputs_ph, [-1, self.num_steps, self.embedding_size])
        noises = tf.reshape(noise_ph, [-1, self.num_steps, self.embedding_size])

        # Denoising variant: encoder and decoder see inputs + noise, while the
        # reconstruction target (built from `inputs` below) stays clean.
        if self.denosing:
            print('Noise')
            noise_input = inputs + noises
        else:
            print('Non_noise')
            noise_input = inputs

        reverse_noise_input = tf.reverse(noise_input, axis=[1])
        decoder_inputs = utils._rnn_reformat(x=noise_input, input_dims=self.embedding_size, n_steps=self.num_steps)
        targets = utils._rnn_reformat(x=inputs, input_dims=self.embedding_size, n_steps=self.num_steps)

        # Only GRU is implemented. Previously any value other than 'LSTM'/'GRU'
        # fell through with `cell` unbound and failed with a NameError.
        if self.cell_type == 'LSTM':
            raise ValueError('LSTMs have not support yet!')
        if self.cell_type != 'GRU':
            raise ValueError('Unsupported cell_type: %r' % (self.cell_type,))

        # The decoder is initialised with the fw ++ bw encoder state, hence the
        # GRU width of twice sum(hidden_size).
        cell = tf.contrib.rnn.GRUCell(np.sum(self.hidden_size) * 2)
        cell = rnn_cell_extensions.LinearSpaceDecoderWrapper(cell, self.embedding_size)

        lf = None
        if self.sample_loss:
            print('Sample Loss')

            def lf(prev, i):
                # rnn_decoder loop_function: feed the previous output back in
                # as the next decoder input instead of the ground truth.
                return prev

        # encoder_output_* is a per-layer list of [batch_size, n_steps, hidden] tensors.
        with tf.variable_scope('fw'):
            _, encoder_output_fw = drnn.drnn_layer_final(noise_input, self.hidden_size, self.dilations, self.num_steps,
                                                         self.embedding_size, self.cell_type)

        with tf.variable_scope('bw'):
            _, encoder_output_bw = drnn.drnn_layer_final(reverse_noise_input, self.hidden_size, self.dilations,
                                                         self.num_steps, self.embedding_size, self.cell_type)

        # Take the last time step of every layer, concatenate per direction,
        # then fw ++ bw: encoder_state has shape [batch_size, sum(hidden_size)*2].
        num_layers = len(self.hidden_size)
        encoder_state_fw = tf.concat([encoder_output_fw[i][:, -1, :] for i in range(num_layers)], axis=1)
        encoder_state_bw = tf.concat([encoder_output_bw[i][:, -1, :] for i in range(num_layers)], axis=1)
        encoder_state = tf.concat([encoder_state_fw, encoder_state_bw], axis=1)

        decoder_outputs, _ = tf.contrib.legacy_seq2seq.rnn_decoder(decoder_inputs=decoder_inputs,
                                                                   initial_state=encoder_state, cell=cell,
                                                                   loop_function=lf)

        hidden_abstract = encoder_state

        F_update = tf.assign(F, F_new_value)

        # tf.split defaults to axis 0, so this keeps the first half of the batch.
        # Presumably that half holds the real samples (cf. real_fake_label) --
        # confirm against the code that feeds the batch.
        real_hidden_abstract = tf.split(hidden_abstract, 2)[0]

        # W has shape [sum(hidden_size)*2, batch_size]; WTW is the Gram matrix of
        # the first-half rows. F is [self.batch_size, K], so that half must have
        # exactly self.batch_size rows for F^T WTW F to be well-formed.
        W = tf.transpose(real_hidden_abstract)
        WTW = tf.matmul(real_hidden_abstract, W)
        FTWTWF = tf.matmul(tf.matmul(tf.transpose(F), WTW), F)

        with tf.name_scope("loss_reconstruct"):
            # targets / decoder_outputs are per-step lists; tf.split stacks them
            # to [n_steps, batch, ...] and axis=1 keeps the first half of the batch.
            loss_reconstruct = tf.losses.mean_squared_error(labels=tf.split(targets, 2, axis=1)[0],
                                                            predictions=tf.split(decoder_outputs, 2, axis=1)[0])

        with tf.name_scope("k-means_loss"):
            # Tr(WTW) - Tr(F^T WTW F) for the current (fixed) F.
            loss_k_means = tf.trace(WTW) - tf.trace(FTWTWF)

        with tf.name_scope("discriminative_loss"):
            # Two-layer MLP over the full batch, trained against real_fake_label.
            weight1 = weight_variable(shape=[hidden_abstract.get_shape().as_list()[1], 128])
            bias1 = bias_variable(shape=[128])

            weight2 = weight_variable(shape=[128, 2])
            bias2 = bias_variable(shape=[2])

            hidden = tf.nn.relu(tf.matmul(hidden_abstract, weight1) + bias1)
            output = tf.matmul(hidden, weight2) + bias2
            predict = tf.reshape(output, shape=[-1, 2])
            discriminative_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=real_fake_label))

        with tf.name_scope("loss_total"):
            loss = loss_reconstruct + self.lamda / 2 * loss_k_means + discriminative_loss

        # L2 weight decay over every trainable variable in the default graph at
        # this point (not only those created above). The collection is fetched
        # once rather than on every loop iteration.
        regularization_loss = 0.0
        for var in tf.trainable_variables():
            regularization_loss += tf.nn.l2_loss(var)
        loss = loss + 1e-4 * regularization_loss

        input_tensors = {
            'inputs': inputs_ph,
            'noise': noise_ph,
            'F_new_value': F_new_value,
            'real_fake_label': real_fake_label
        }
        loss_tensors = {
            'loss_reconstruct': loss_reconstruct,
            'loss_k_means': loss_k_means,
            'regularization_loss': regularization_loss,
            'discriminative_loss': discriminative_loss,
            'loss': loss
        }
        output_tensor = {'prediction': predict}
        return input_tensors, loss_tensors, real_hidden_abstract, F_update, output_tensor
    def __init__(
            self,
            architecture,
            source_seq_len,
            target_seq_len,
            rnn_size,  # hidden recurrent layer size
            num_layers,
            max_gradient_norm,
            batch_size,
            learning_rate,
            learning_rate_decay_factor,
            summaries_dir,
            loss_to_use,
            number_of_actions,
            one_hot=True,
            residual_velocities=False,
            dtype=tf.float32):
        """Create the model.

    Args:
      architecture: [basic, tied] whether to tie the decoder and decoder.
      source_seq_len: lenght of the input sequence.
      target_seq_len: lenght of the target sequence.
      rnn_size: number of units in the rnn.
      num_layers: number of rnns to stack.
      max_gradient_norm: gradients will be clipped to maximally this norm.
      batch_size: the size of the batches used during training;
        the model construction is independent of batch_size, so it can be
        changed after initialization if this is convenient, e.g., for decoding.
      learning_rate: learning rate to start with.
      learning_rate_decay_factor: decay learning rate by this much when needed.
      summaries_dir: where to log progress for tensorboard.
      loss_to_use: [supervised, sampling_based]. Whether to use ground truth in
        each timestep to compute the loss after decoding, or to feed back the
        prediction from the previous time-step.
      number_of_actions: number of classes we have.
      one_hot: whether to use one_hot encoding during train/test (sup models).
      residual_velocities: whether to use a residual connection that models velocities.
      dtype: the data type to use to store internal variables.
    """

        self.HUMAN_SIZE = 54
        self.input_size = self.HUMAN_SIZE + number_of_actions if one_hot else self.HUMAN_SIZE

        print("One hot is ", one_hot)
        print("Input size is %d" % self.input_size)

        # Summary writers for train and test runs
        self.train_writer = tf.summary.FileWriter(
            os.path.join(summaries_dir, 'train'))
        self.test_writer = tf.summary.FileWriter(
            os.path.join(summaries_dir, 'test'))

        self.source_seq_len = source_seq_len
        self.target_seq_len = target_seq_len
        self.rnn_size = rnn_size
        self.batch_size = batch_size
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=dtype)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # === Create the RNN that will keep the state ===
        print('rnn_size = {0}'.format(rnn_size))
        single_cell = tf.contrib.rnn.GRUCell(self.rnn_size)

        if num_layers > 1:
            # Create multiple rnn layers
            cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
        else:
            cell = single_cell

        # === Transform the inputs ===
        with tf.name_scope("inputs"):

            enc_in = tf.placeholder(
                dtype,
                shape=[None, source_seq_len - 1, self.input_size],
                name="enc_in")
            dec_in = tf.placeholder(
                dtype,
                shape=[None, target_seq_len, self.input_size],
                name="dec_in")
            dec_out = tf.placeholder(
                dtype,
                shape=[None, target_seq_len, self.input_size],
                name="dec_out")

            self.encoder_inputs = enc_in
            self.decoder_inputs = dec_in
            self.decoder_outputs = dec_out

            enc_in = tf.transpose(enc_in, [1, 0, 2])
            dec_in = tf.transpose(dec_in, [1, 0, 2])
            dec_out = tf.transpose(dec_out, [1, 0, 2])

            enc_in = tf.reshape(enc_in, [-1, self.input_size])
            dec_in = tf.reshape(dec_in, [-1, self.input_size])
            dec_out = tf.reshape(dec_out, [-1, self.input_size])

            enc_in = tf.split(enc_in, source_seq_len - 1, axis=0)
            dec_in = tf.split(dec_in, target_seq_len, axis=0)
            dec_out = tf.split(dec_out, target_seq_len, axis=0)

        # === Add space decoder ===
        cell = rnn_cell_extensions.LinearSpaceDecoderWrapper(
            cell, self.input_size)

        # Finally, wrap everything in a residual layer if we want to model velocities
        if residual_velocities:
            cell = rnn_cell_extensions.ResidualWrapper(cell)

        # Store the outputs here
        outputs = []

        # Define the loss function
        lf = None
        if loss_to_use == "sampling_based":

            def lf(prev, i):  # function for sampling_based loss
                return prev
        elif loss_to_use == "supervised":
            pass
        else:
            raise (ValueError, "unknown loss: %s" % loss_to_use)

        # Build the RNN
        if architecture == "basic":
            # Basic RNN does not have a loop function in its API, so copying here.
            with vs.variable_scope("basic_rnn_seq2seq"):
                _, enc_state = tf.contrib.rnn.static_rnn(
                    cell, enc_in, dtype=tf.float32)  # Encoder
                outputs, self.states = tf.contrib.legacy_seq2seq.rnn_decoder(
                    dec_in, enc_state, cell, loop_function=lf)  # Decoder
        elif architecture == "tied":
            outputs, self.states = tf.contrib.legacy_seq2seq.tied_rnn_seq2seq(
                enc_in, dec_in, cell, loop_function=lf)
        else:
            raise (ValueError, "Uknown architecture: %s" % architecture)

        self.outputs = outputs

        with tf.name_scope("loss_angles"):
            loss_angles = tf.reduce_mean(
                tf.square(tf.subtract(dec_out, outputs)))

        self.loss = loss_angles
        self.loss_summary = tf.summary.scalar('loss/loss', self.loss)

        # Gradients and SGD update operation for training the model.
        params = tf.trainable_variables()

        opt = tf.train.GradientDescentOptimizer(self.learning_rate)

        # Update all the trainable parameters
        gradients = tf.gradients(self.loss, params)

        clipped_gradients, norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.gradient_norms = norm
        self.updates = opt.apply_gradients(zip(clipped_gradients, params),
                                           global_step=self.global_step)

        # Keep track of the learning rate
        self.learning_rate_summary = tf.summary.scalar(
            'learning_rate/learning_rate', self.learning_rate)

        # === variables for loss in Euler Angles -- for each action
        with tf.name_scope("euler_error_walking"):
            self.walking_err80 = tf.placeholder(tf.float32,
                                                name="walking_srnn_seeds_0080")
            self.walking_err160 = tf.placeholder(
                tf.float32, name="walking_srnn_seeds_0160")
            self.walking_err320 = tf.placeholder(
                tf.float32, name="walking_srnn_seeds_0320")
            self.walking_err400 = tf.placeholder(
                tf.float32, name="walking_srnn_seeds_0400")
            self.walking_err560 = tf.placeholder(
                tf.float32, name="walking_srnn_seeds_0560")
            self.walking_err1000 = tf.placeholder(
                tf.float32, name="walking_srnn_seeds_1000")

            self.walking_err80_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_0080', self.walking_err80)
            self.walking_err160_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_0160', self.walking_err160)
            self.walking_err320_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_0320', self.walking_err320)
            self.walking_err400_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_0400', self.walking_err400)
            self.walking_err560_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_0560', self.walking_err560)
            self.walking_err1000_summary = tf.summary.scalar(
                'euler_error_walking/srnn_seeds_1000', self.walking_err1000)
        with tf.name_scope("euler_error_eating"):
            self.eating_err80 = tf.placeholder(tf.float32,
                                               name="eating_srnn_seeds_0080")
            self.eating_err160 = tf.placeholder(tf.float32,
                                                name="eating_srnn_seeds_0160")
            self.eating_err320 = tf.placeholder(tf.float32,
                                                name="eating_srnn_seeds_0320")
            self.eating_err400 = tf.placeholder(tf.float32,
                                                name="eating_srnn_seeds_0400")
            self.eating_err560 = tf.placeholder(tf.float32,
                                                name="eating_srnn_seeds_0560")
            self.eating_err1000 = tf.placeholder(tf.float32,
                                                 name="eating_srnn_seeds_1000")

            self.eating_err80_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_0080', self.eating_err80)
            self.eating_err160_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_0160', self.eating_err160)
            self.eating_err320_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_0320', self.eating_err320)
            self.eating_err400_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_0400', self.eating_err400)
            self.eating_err560_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_0560', self.eating_err560)
            self.eating_err1000_summary = tf.summary.scalar(
                'euler_error_eating/srnn_seeds_1000', self.eating_err1000)
        with tf.name_scope("euler_error_smoking"):
            self.smoking_err80 = tf.placeholder(tf.float32,
                                                name="smoking_srnn_seeds_0080")
            self.smoking_err160 = tf.placeholder(
                tf.float32, name="smoking_srnn_seeds_0160")
            self.smoking_err320 = tf.placeholder(
                tf.float32, name="smoking_srnn_seeds_0320")
            self.smoking_err400 = tf.placeholder(
                tf.float32, name="smoking_srnn_seeds_0400")
            self.smoking_err560 = tf.placeholder(
                tf.float32, name="smoking_srnn_seeds_0560")
            self.smoking_err1000 = tf.placeholder(
                tf.float32, name="smoking_srnn_seeds_1000")

            self.smoking_err80_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_0080', self.smoking_err80)
            self.smoking_err160_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_0160', self.smoking_err160)
            self.smoking_err320_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_0320', self.smoking_err320)
            self.smoking_err400_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_0400', self.smoking_err400)
            self.smoking_err560_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_0560', self.smoking_err560)
            self.smoking_err1000_summary = tf.summary.scalar(
                'euler_error_smoking/srnn_seeds_1000', self.smoking_err1000)
        with tf.name_scope("euler_error_discussion"):
            self.discussion_err80 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_0080")
            self.discussion_err160 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_0160")
            self.discussion_err320 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_0320")
            self.discussion_err400 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_0400")
            self.discussion_err560 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_0560")
            self.discussion_err1000 = tf.placeholder(
                tf.float32, name="discussion_srnn_seeds_1000")

            self.discussion_err80_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_0080',
                self.discussion_err80)
            self.discussion_err160_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_0160',
                self.discussion_err160)
            self.discussion_err320_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_0320',
                self.discussion_err320)
            self.discussion_err400_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_0400',
                self.discussion_err400)
            self.discussion_err560_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_0560',
                self.discussion_err560)
            self.discussion_err1000_summary = tf.summary.scalar(
                'euler_error_discussion/srnn_seeds_1000',
                self.discussion_err1000)
        with tf.name_scope("euler_error_directions"):
            self.directions_err80 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_0080")
            self.directions_err160 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_0160")
            self.directions_err320 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_0320")
            self.directions_err400 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_0400")
            self.directions_err560 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_0560")
            self.directions_err1000 = tf.placeholder(
                tf.float32, name="directions_srnn_seeds_1000")

            self.directions_err80_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_0080',
                self.directions_err80)
            self.directions_err160_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_0160',
                self.directions_err160)
            self.directions_err320_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_0320',
                self.directions_err320)
            self.directions_err400_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_0400',
                self.directions_err400)
            self.directions_err560_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_0560',
                self.directions_err560)
            self.directions_err1000_summary = tf.summary.scalar(
                'euler_error_directions/srnn_seeds_1000',
                self.directions_err1000)
        with tf.name_scope("euler_error_greeting"):
            self.greeting_err80 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_0080")
            self.greeting_err160 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_0160")
            self.greeting_err320 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_0320")
            self.greeting_err400 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_0400")
            self.greeting_err560 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_0560")
            self.greeting_err1000 = tf.placeholder(
                tf.float32, name="greeting_srnn_seeds_1000")

            self.greeting_err80_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_0080', self.greeting_err80)
            self.greeting_err160_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_0160', self.greeting_err160)
            self.greeting_err320_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_0320', self.greeting_err320)
            self.greeting_err400_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_0400', self.greeting_err400)
            self.greeting_err560_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_0560', self.greeting_err560)
            self.greeting_err1000_summary = tf.summary.scalar(
                'euler_error_greeting/srnn_seeds_1000', self.greeting_err1000)
        with tf.name_scope("euler_error_phoning"):
            self.phoning_err80 = tf.placeholder(tf.float32,
                                                name="phoning_srnn_seeds_0080")
            self.phoning_err160 = tf.placeholder(
                tf.float32, name="phoning_srnn_seeds_0160")
            self.phoning_err320 = tf.placeholder(
                tf.float32, name="phoning_srnn_seeds_0320")
            self.phoning_err400 = tf.placeholder(
                tf.float32, name="phoning_srnn_seeds_0400")
            self.phoning_err560 = tf.placeholder(
                tf.float32, name="phoning_srnn_seeds_0560")
            self.phoning_err1000 = tf.placeholder(
                tf.float32, name="phoning_srnn_seeds_1000")

            self.phoning_err80_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_0080', self.phoning_err80)
            self.phoning_err160_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_0160', self.phoning_err160)
            self.phoning_err320_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_0320', self.phoning_err320)
            self.phoning_err400_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_0400', self.phoning_err400)
            self.phoning_err560_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_0560', self.phoning_err560)
            self.phoning_err1000_summary = tf.summary.scalar(
                'euler_error_phoning/srnn_seeds_1000', self.phoning_err1000)
        with tf.name_scope("euler_error_posing"):
            self.posing_err80 = tf.placeholder(tf.float32,
                                               name="posing_srnn_seeds_0080")
            self.posing_err160 = tf.placeholder(tf.float32,
                                                name="posing_srnn_seeds_0160")
            self.posing_err320 = tf.placeholder(tf.float32,
                                                name="posing_srnn_seeds_0320")
            self.posing_err400 = tf.placeholder(tf.float32,
                                                name="posing_srnn_seeds_0400")
            self.posing_err560 = tf.placeholder(tf.float32,
                                                name="posing_srnn_seeds_0560")
            self.posing_err1000 = tf.placeholder(tf.float32,
                                                 name="posing_srnn_seeds_1000")

            self.posing_err80_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_0080', self.posing_err80)
            self.posing_err160_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_0160', self.posing_err160)
            self.posing_err320_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_0320', self.posing_err320)
            self.posing_err400_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_0400', self.posing_err400)
            self.posing_err560_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_0560', self.posing_err560)
            self.posing_err1000_summary = tf.summary.scalar(
                'euler_error_posing/srnn_seeds_1000', self.posing_err1000)
        with tf.name_scope("euler_error_purchases"):
            self.purchases_err80 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_0080")
            self.purchases_err160 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_0160")
            self.purchases_err320 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_0320")
            self.purchases_err400 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_0400")
            self.purchases_err560 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_0560")
            self.purchases_err1000 = tf.placeholder(
                tf.float32, name="purchases_srnn_seeds_1000")

            self.purchases_err80_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_0080', self.purchases_err80)
            self.purchases_err160_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_0160', self.purchases_err160)
            self.purchases_err320_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_0320', self.purchases_err320)
            self.purchases_err400_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_0400', self.purchases_err400)
            self.purchases_err560_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_0560', self.purchases_err560)
            self.purchases_err1000_summary = tf.summary.scalar(
                'euler_error_purchases/srnn_seeds_1000',
                self.purchases_err1000)
        with tf.name_scope("euler_error_sitting"):
            self.sitting_err80 = tf.placeholder(tf.float32,
                                                name="sitting_srnn_seeds_0080")
            self.sitting_err160 = tf.placeholder(
                tf.float32, name="sitting_srnn_seeds_0160")
            self.sitting_err320 = tf.placeholder(
                tf.float32, name="sitting_srnn_seeds_0320")
            self.sitting_err400 = tf.placeholder(
                tf.float32, name="sitting_srnn_seeds_0400")
            self.sitting_err560 = tf.placeholder(
                tf.float32, name="sitting_srnn_seeds_0560")
            self.sitting_err1000 = tf.placeholder(
                tf.float32, name="sitting_srnn_seeds_1000")

            self.sitting_err80_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_0080', self.sitting_err80)
            self.sitting_err160_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_0160', self.sitting_err160)
            self.sitting_err320_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_0320', self.sitting_err320)
            self.sitting_err400_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_0400', self.sitting_err400)
            self.sitting_err560_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_0560', self.sitting_err560)
            self.sitting_err1000_summary = tf.summary.scalar(
                'euler_error_sitting/srnn_seeds_1000', self.sitting_err1000)
        with tf.name_scope("euler_error_sittingdown"):
            self.sittingdown_err80 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_0080")
            self.sittingdown_err160 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_0160")
            self.sittingdown_err320 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_0320")
            self.sittingdown_err400 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_0400")
            self.sittingdown_err560 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_0560")
            self.sittingdown_err1000 = tf.placeholder(
                tf.float32, name="sittingdown_srnn_seeds_1000")

            self.sittingdown_err80_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_0080',
                self.sittingdown_err80)
            self.sittingdown_err160_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_0160',
                self.sittingdown_err160)
            self.sittingdown_err320_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_0320',
                self.sittingdown_err320)
            self.sittingdown_err400_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_0400',
                self.sittingdown_err400)
            self.sittingdown_err560_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_0560',
                self.sittingdown_err560)
            self.sittingdown_err1000_summary = tf.summary.scalar(
                'euler_error_sittingdown/srnn_seeds_1000',
                self.sittingdown_err1000)
        with tf.name_scope("euler_error_takingphoto"):
            self.takingphoto_err80 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_0080")
            self.takingphoto_err160 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_0160")
            self.takingphoto_err320 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_0320")
            self.takingphoto_err400 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_0400")
            self.takingphoto_err560 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_0560")
            self.takingphoto_err1000 = tf.placeholder(
                tf.float32, name="takingphoto_srnn_seeds_1000")

            self.takingphoto_err80_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_0080',
                self.takingphoto_err80)
            self.takingphoto_err160_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_0160',
                self.takingphoto_err160)
            self.takingphoto_err320_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_0320',
                self.takingphoto_err320)
            self.takingphoto_err400_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_0400',
                self.takingphoto_err400)
            self.takingphoto_err560_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_0560',
                self.takingphoto_err560)
            self.takingphoto_err1000_summary = tf.summary.scalar(
                'euler_error_takingphoto/srnn_seeds_1000',
                self.takingphoto_err1000)
        with tf.name_scope("euler_error_waiting"):
            self.waiting_err80 = tf.placeholder(tf.float32,
                                                name="waiting_srnn_seeds_0080")
            self.waiting_err160 = tf.placeholder(
                tf.float32, name="waiting_srnn_seeds_0160")
            self.waiting_err320 = tf.placeholder(
                tf.float32, name="waiting_srnn_seeds_0320")
            self.waiting_err400 = tf.placeholder(
                tf.float32, name="waiting_srnn_seeds_0400")
            self.waiting_err560 = tf.placeholder(
                tf.float32, name="waiting_srnn_seeds_0560")
            self.waiting_err1000 = tf.placeholder(
                tf.float32, name="waiting_srnn_seeds_1000")

            self.waiting_err80_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_0080', self.waiting_err80)
            self.waiting_err160_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_0160', self.waiting_err160)
            self.waiting_err320_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_0320', self.waiting_err320)
            self.waiting_err400_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_0400', self.waiting_err400)
            self.waiting_err560_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_0560', self.waiting_err560)
            self.waiting_err1000_summary = tf.summary.scalar(
                'euler_error_waiting/srnn_seeds_1000', self.waiting_err1000)
        with tf.name_scope("euler_error_walkingdog"):
            self.walkingdog_err80 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_0080")
            self.walkingdog_err160 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_0160")
            self.walkingdog_err320 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_0320")
            self.walkingdog_err400 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_0400")
            self.walkingdog_err560 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_0560")
            self.walkingdog_err1000 = tf.placeholder(
                tf.float32, name="walkingdog_srnn_seeds_1000")

            self.walkingdog_err80_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_0080',
                self.walkingdog_err80)
            self.walkingdog_err160_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_0160',
                self.walkingdog_err160)
            self.walkingdog_err320_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_0320',
                self.walkingdog_err320)
            self.walkingdog_err400_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_0400',
                self.walkingdog_err400)
            self.walkingdog_err560_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_0560',
                self.walkingdog_err560)
            self.walkingdog_err1000_summary = tf.summary.scalar(
                'euler_error_walkingdog/srnn_seeds_1000',
                self.walkingdog_err1000)
        with tf.name_scope("euler_error_walkingtogether"):
            self.walkingtogether_err80 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_0080")
            self.walkingtogether_err160 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_0160")
            self.walkingtogether_err320 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_0320")
            self.walkingtogether_err400 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_0400")
            self.walkingtogether_err560 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_0560")
            self.walkingtogether_err1000 = tf.placeholder(
                tf.float32, name="walkingtogether_srnn_seeds_1000")

            self.walkingtogether_err80_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_0080',
                self.walkingtogether_err80)
            self.walkingtogether_err160_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_0160',
                self.walkingtogether_err160)
            self.walkingtogether_err320_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_0320',
                self.walkingtogether_err320)
            self.walkingtogether_err400_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_0400',
                self.walkingtogether_err400)
            self.walkingtogether_err560_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_0560',
                self.walkingtogether_err560)
            self.walkingtogether_err1000_summary = tf.summary.scalar(
                'euler_error_walkingtogether/srnn_seeds_1000',
                self.walkingtogether_err1000)

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
    def __init__(
            self,
            architecture,
            max_seq_len,
            human_size,
            rnn_size,  # hidden recurrent layer size
            num_layers,
            max_gradient_norm,
            stddev,
            batch_size,
            learning_rate,
            learning_rate_decay_factor,
            summaries_dir,
            loss_to_use,
            number_of_actions,
            one_hot=True,
            residual_velocities=False,
            dtype=tf.float32):
        """Create the model.

        Builds a GRU-based autoregressive model with ``tf.nn.raw_rnn``: at each
        step the cell receives its previous prediction concatenated with the
        current input frame (with Gaussian noise added when the boolean
        placeholder ``self.is_training`` is fed True).

        Args:
          architecture: only "basic" is implemented; anything else raises.
          max_seq_len: number of predicted time steps. The ``inputs``
            placeholder holds ``max_seq_len + 1`` frames, ``gts`` holds
            ``max_seq_len`` frames.
          human_size: size of one frame without the action encoding.
          rnn_size: number of units in each GRU layer.
          num_layers: number of GRU layers to stack.
          max_gradient_norm: gradients are clipped to at most this global norm.
          stddev: standard deviation of the Gaussian noise added to the inputs
            in the training branch.
          batch_size: stored on the instance; the placeholders use a dynamic
            batch dimension, so graph construction does not depend on it.
          learning_rate: initial learning rate.
          learning_rate_decay_factor: ``learning_rate_decay_op`` multiplies
            the learning rate by this factor.
          summaries_dir: where to log progress for tensorboard.
          loss_to_use: accepted for interface compatibility; not used here.
          number_of_actions: added to ``human_size`` to form the input size
            when ``one_hot`` is True.
          one_hot: whether inputs carry an extra action encoding.
          residual_velocities: whether to wrap the cell in
            ``rnn_cell_extensions.ResidualWrapper``.
          dtype: dtype of the input placeholders and of the learning-rate
            variable (other variables and the noise are always float32).

        Raises:
          ValueError: if ``architecture`` is not "basic".
        """
        if architecture != "basic":
            raise ValueError(
                "Unsupported architecture %r; only 'basic' is implemented" %
                (architecture,))

        self.HUMAN_SIZE = human_size
        self.input_size = self.HUMAN_SIZE + number_of_actions if one_hot else self.HUMAN_SIZE

        print("One hot is ", one_hot)
        print("Input size is %d" % self.input_size)

        # Summary writers for train and test runs
        self.train_writer = tf.summary.FileWriter(
            os.path.normpath(os.path.join(summaries_dir, 'train')))
        self.test_writer = tf.summary.FileWriter(
            os.path.normpath(os.path.join(summaries_dir, 'test')))

        self.max_seq_len = max_seq_len
        self.rnn_size = rnn_size
        self.batch_size = batch_size
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=dtype)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # === Create the RNN that will keep the state ===
        print('rnn_size = {0}'.format(rnn_size))
        cell = tf.contrib.rnn.GRUCell(self.rnn_size)

        if num_layers > 1:
            cell = tf.contrib.rnn.MultiRNNCell([
                tf.contrib.rnn.GRUCell(self.rnn_size)
                for _ in range(num_layers)
            ])

        # === Transform the inputs ===
        with tf.name_scope("inputs_gts"):

            inputs = tf.placeholder(
                dtype,
                shape=[None, self.max_seq_len + 1, self.input_size],
                name="inputs")
            gts = tf.placeholder(
                dtype,
                shape=[None, self.max_seq_len, self.input_size],
                name="gts")
            seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len")

            self.inputs = inputs
            self.gts = gts
            self.seq_len = seq_len

            # Switch to time-major layout: [time, batch, input_size].
            inputs = _transpose_batch_time(inputs)
            gts = _transpose_batch_time(gts)

        # === Add space decoder ===
        cell = rnn_cell_extensions.LinearSpaceDecoderWrapper(
            cell, self.input_size)

        # Finally, wrap everything in a residual layer if we want to model velocities
        if residual_velocities:
            cell = rnn_cell_extensions.ResidualWrapper(cell)

        self.stddev = stddev

        def addGN(inputs):
            """Return ``inputs`` plus zero-mean Gaussian noise (std self.stddev)."""
            noise = tf.random_normal(shape=tf.shape(inputs),
                                     mean=0.0,
                                     stddev=self.stddev,
                                     dtype=tf.float32)
            return inputs + noise

        self.is_training = tf.placeholder(dtype=tf.bool)

        # === Build the RNN ===
        # Learned initial cell state, tiled over the batch at step 0.
        # NOTE(review): np.zeros([1, cell.state_size]) assumes an integer
        # state_size (single layer); a MultiRNNCell exposes a tuple state.
        cell_init_state = tf.Variable(np.zeros([1, cell.state_size]),
                                      trainable=True,
                                      dtype=tf.float32)
        # Learned stand-in for "previous cell output" at step 0. It must be as
        # wide as the cell output (self.input_size, the size passed to the
        # LinearSpaceDecoderWrapper) so the cell input width is the same at
        # every step; this used to be hard-coded to 63.
        init_input = tf.Variable(np.zeros([self.input_size]),
                                 trainable=True,
                                 dtype=tf.float32)
        output_ta = tf.TensorArray(size=self.max_seq_len, dtype=tf.float32)

        def loop_fn(time, cell_output, cell_state, loop_state):
            """One ``tf.nn.raw_rnn`` step.

            The next cell input is the previous prediction concatenated with
            ``inputs[time]`` (noised via ``addGN`` when ``self.is_training``).
            On the initial call (``cell_output is None``) the learned
            ``cell_init_state`` / ``init_input`` are tiled over the batch and
            the empty TensorArray becomes the loop state; on later calls the
            cell output is written to it at index ``time - 1``.
            """
            emit_output = cell_output
            if cell_output is None:
                batch = tf.shape(inputs[0])[0]
                next_cell_state = tf.tile(cell_init_state, [batch, 1])
                prev_output = tf.tile(tf.expand_dims(init_input, 0),
                                      [batch, 1])
                next_loop_state = output_ta
            else:
                next_cell_state = cell_state
                prev_output = cell_output
                next_loop_state = loop_state.write(time - 1, cell_output)

            next_input = tf.cond(
                self.is_training,
                lambda: tf.concat([prev_output, addGN(inputs[time])], axis=1),
                lambda: tf.concat([prev_output, inputs[time]], axis=1))

            # Stop once all max_seq_len outputs have been written.
            finished = (time > self.max_seq_len - 1)
            return (finished, next_input, next_cell_state, emit_output,
                    next_loop_state)

        # Basic RNN does not have a loop function in its API, so copying here.
        with vs.variable_scope("raw_rnn"):
            _, _, loop_state_ta = tf.nn.raw_rnn(cell, loop_fn)
            # Stacked per-step outputs, time-major like inputs/gts.
            outputs = loop_state_ta.stack()

        self.outputs = outputs

        # Validity masks from seq_len, transposed to time-major and broadcast
        # over the feature axis. mask2 covers the max_seq_len - 1
        # frame-to-frame differences used by the smoothness term.
        mask1 = tf.tile(
            tf.expand_dims(
                tf.transpose(
                    tf.sequence_mask(self.seq_len,
                                     dtype=tf.float32,
                                     maxlen=self.max_seq_len)), -1),
            [1, 1, self.input_size])
        mask2 = tf.tile(
            tf.expand_dims(
                tf.transpose(
                    tf.sequence_mask(self.seq_len - 1,
                                     dtype=tf.float32,
                                     maxlen=self.max_seq_len - 1)), -1),
            [1, 1, self.input_size])

        with tf.name_scope("loss_pos"):
            sq_err_pos = tf.square(
                tf.subtract(tf.multiply(outputs, mask1),
                            tf.multiply(gts, mask1)))
            loss_pos = tf.reduce_mean(sq_err_pos)
        with tf.name_scope("loss_smooth"):
            sq_err_smooth = tf.square(
                tf.multiply(tf.subtract(outputs[1:], outputs[:-1]), mask2))

        # Only the position error is optimized; the smoothness term is used
        # solely in the per-sequence diagnostic below.
        self.loss = loss_pos
        self.loss_summary = tf.summary.scalar('loss/loss', self.loss)

        # One value per batch element: position + smoothness error, each
        # averaged over time (axis 0) and features (axis 2).
        self.loss_each_data = (tf.reduce_mean(sq_err_pos, axis=[0, 2]) +
                               tf.reduce_mean(sq_err_smooth, axis=[0, 2]))

        # Gradients and SGD update operation for training the model.
        params = tf.trainable_variables()

        # Use the learning-rate Variable rather than the Python float so that
        # learning_rate_decay_op actually changes the step size (and the
        # learning_rate summary reports the value really in use).
        opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

        # Update all the trainable parameters
        gradients = tf.gradients(self.loss, params)

        clipped_gradients, norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.gradient_norms = norm
        self.updates = opt.apply_gradients(zip(clipped_gradients, params),
                                           global_step=self.global_step)

        self.learning_rate_summary = tf.summary.scalar(
            'learning_rate/learning_rate', self.learning_rate)

        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)