def train(self,
              save_path,
              save_step=5000,
              var=0.00001,
              training_step=1000000,
              normalization=True,
              normalization_type='z_score',
              decay='False',
              load_path=None):
        """
        :param save_path: directory in which to save the weights and biases as well as the normalization parameters
        :param save_step: save the model every `save_step` (default 5000) steps
        :param var: the variance of the Bayesian NN output; should eventually be made trainable (TODO)
        :param training_step: maximum number of training steps
        :param normalization: whether to normalize the data before training
        :param normalization_type: choose 'min_max' or 'z_score' normalization
        :param decay: whether to decay the learning rate while training ('True' or 'False')
        :param load_path: optional directory from which to load pre-trained weights before training
        :return: the most recent held-out accuracy (mean squared error)
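
        Example (a minimal sketch; assumes `model` is an instance of this class with
        `x_data`/`y_data` already populated, and './bnn_model' is a placeholder path):
            model.train(save_path='./bnn_model', normalization_type='z_score', decay='True')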
        """
        if normalization:
            if normalization_type == 'min_max':
                x_min_arr = np.amin(self.x_data, axis=0)
                x_max_arr = np.amax(self.x_data, axis=0)
                y_min_arr = np.amin(self.y_data, axis=0)
                y_max_arr = np.amax(self.y_data, axis=0)
                self.x_data = min_max_normalize(self.x_data, x_min_arr,
                                                x_max_arr)
                self.y_data = min_max_normalize(self.y_data, y_min_arr,
                                                y_max_arr)
                with open(save_path + '/normalization_arr/normalization_arr',
                          'wb') as pickle_file:
                    pickle.dump(
                        ((x_min_arr, x_max_arr), (y_min_arr, y_max_arr)),
                        pickle_file)
                # with open(save_path+'/normalization_arr/y_normalization_arr', 'wb') as pickle_file:
                #     pickle.dump((y_min_arr, y_max_arr), pickle_file)
            elif normalization_type == 'z_score':
                x_mean_arr = np.mean(self.x_data, axis=0)
                x_std_arr = np.std(self.x_data, axis=0)
                y_mean_arr = np.mean(self.y_data, axis=0)
                y_std_arr = np.std(self.y_data, axis=0)
                self.x_data = z_score_normalize(self.x_data, x_mean_arr,
                                                x_std_arr)
                self.y_data = z_score_normalize(self.y_data, y_mean_arr,
                                                y_std_arr)
                with open(save_path + '/normalization_arr/normalization_arr',
                          'wb') as pickle_file:
                    pickle.dump(
                        ((x_mean_arr, x_std_arr), (y_mean_arr, y_std_arr)),
                        pickle_file)
                # with open(save_path+'/normalization_arr/y_normalization_arr', 'wb') as pickle_file:
                #     pickle.dump((y_mean_arr, y_std_arr), pickle_file)

        self.var = [var] * self.y_data.shape[1]  # per-output-dimension variance of the Bayesian NN output
        (xs, ys, handle, training_iterator,
         heldout_iterator) = self.build_input_pipeline()
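        # build_input_pipeline is expected to return the input/label tensors, a feedable
        # iterator handle placeholder, and training/held-out iterators whose string
        # handles are fetched inside the session below.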
        # Sandwich the neural net between the transfer net T and its inverse T_inv
        transformed_state = self.T(
            xs[:, :4])  # only transform the state, not the action
        state_and_action = tf.concat([transformed_state, xs[:, 4:]], axis=1)
        y_pre = self.neural_net(state_and_action)
        y_pre = self.T_inv(y_pre)  # map the prediction back to the original output space

        # y_pre = self.neural_net(xs)
        ys_distribution = tfp.distributions.Normal(loc=y_pre, scale=self.var)
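        # Note: tfp.distributions.Normal takes the standard deviation as `scale`;
        # here self.var (the configured per-dimension output variance) is passed directly as that scale.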
        neg_log_likelihood = -tf.reduce_mean(
            input_tensor=ys_distribution.log_prob(ys))
        kl = sum(self.neural_net.losses) / self.batch_size
        elbo_loss = neg_log_likelihood + kl
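        # ELBO-style objective: data negative log-likelihood plus the KL terms collected
        # in neural_net.losses (from the Bayesian layers), scaled to a per-example contribution.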
        predictions = ys_distribution.sample()

        accuracy, accuracy_update_op = tf.metrics.mean_squared_error(
            labels=ys, predictions=predictions)

        with tf.name_scope("train"):
            if decay == 'True':  # add learning rate decay (`decay` is passed as a string flag)
                global_step = tf.Variable(0, trainable=False)
                learning_rate = tf.train.exponential_decay(self.lr,
                                                           global_step,
                                                           100000,
                                                           0.965,
                                                           staircase=True)
                optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op = optimizer.minimize(elbo_loss,
                                              global_step=global_step)
                # train_op = elbo_loss
                # train_op = optimizer.minimize(elbo_loss, global_step=global_step, var_list = [])
            else:
                learning_rate = self.lr
                optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op = optimizer.minimize(elbo_loss)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
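        # local variables must also be initialized for the running counters maintained
        # by tf.metrics.mean_squared_error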

        with tf.Session() as sess:
            sess.run(init_op)
            if load_path:
                print("LOADING WEIGHTS")
                self.neural_net.load_weights(load_path +
                                             '/weights/BNN_weights')
            # sess.graph.finalize()
            # Run the training loop.
            train_handle = sess.run(training_iterator.string_handle())
            heldout_handle = sess.run(heldout_iterator.string_handle())

            for step in range(training_step):
                sess.run([train_op, accuracy_update_op],
                         feed_dict={handle: train_handle})
                if step % 100 == 0:
                    # measure loss and accuracy against the held-out data
                    loss_value, accuracy_value = sess.run(
                        [elbo_loss, accuracy],
                        feed_dict={handle: heldout_handle})
                    print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.5f}".format(
                        step, loss_value, accuracy_value))
                if step % save_step == 0 and step != 0:
                    print("Saving weights")
                    self.neural_net.save_weights(
                        save_path + '/weights/BNN_weights')  #Save weights

        return accuracy_value
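
# Script-level rollout code: sample per-step angle and velocity deltas from the
# angle and velocity BNN output distributions constructed above.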
y_ang_delta_pre = y_ang_distribution.sample()
y_vel_delta_pre = y_vel_distribution.sample()


if __name__ == "__main__":
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        neural_net_ang.load_weights(ang_model_path+"/weights/BNN_weights")  # load NN parameters
        neural_net_vel.load_weights(vel_model_path+"/weights/BNN_weights")
        angs = []  # prediction in angle space
        vels = []  # prediction in velocity space
        angs.append(validation_data[0][:2])
        vels.append(validation_data[0][2:4])
        state = np.array(validation_data[0])
        nor_state = z_score_normalize(np.asarray([state]), x_nor_arr[0], x_nor_arr[1])
        print(nor_state)
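        # Open-loop rollout: each predicted state is fed back in as the next model input;
        # only the action component is taken from the recorded validation data.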
        for i in range(len(validation_data)-1):
            (ang_delta, vel_delta) = sess.run((y_ang_delta_pre, y_vel_delta_pre), feed_dict={x: nor_state})
            ang_delta = z_score_denormalize(ang_delta, y_nor_arr_ang[0], y_nor_arr_ang[1])[0]  # denormalize
            vel_delta = z_score_denormalize(vel_delta, y_nor_arr_vel[0], y_nor_arr_vel[1])[0]
            next_ang = state[:2] + ang_delta
            next_vel = state[2:4] + vel_delta
            angs.append(next_ang)
            vels.append(next_vel)
            state = np.append(np.append(next_ang, next_vel), validation_data[i + 1][4:5])
            state = map_angle(state)
            nor_state = z_score_normalize(np.asarray([state]), x_nor_arr[0], x_nor_arr[1])

    angs = np.asarray(angs)
    vels = np.asarray(vels)