class LSTM(NoiseAwareModel):
    """Long Short-Term Memory.

    Thin adapter that delegates training and inference to an ``LSTMModel``
    built from the training candidates and their marginals.
    """

    def __init__(self):
        # Both start unset; ``lstm`` is populated by train(). ``w`` is not
        # assigned anywhere else in this class.
        self.lstm, self.w = None, None

    def train(self, training_candidates, training_marginals, **hyperparams):
        """Build a fresh LSTMModel from the data and train it.

        All keyword arguments are forwarded unchanged to ``LSTMModel.train``.
        """
        model = LSTMModel(training_candidates, training_marginals)
        self.lstm = model
        model.train(**hyperparams)

    def marginals(self, test_candidates):
        """Return the trained model's ``test`` output for ``test_candidates``.

        Must be called after train(); otherwise ``self.lstm`` is still None
        and the attribute lookup fails.
        """
        lstm_model = self.lstm
        return lstm_model.test(test_candidates)

    def save(self, session, param_set_name):
        """Persisting this model is not supported."""
        raise NotImplementedError()
def _check_cell_model(purpose):
    """Raise ValueError unless FLAGS.model is 1 or 2 (LSTMBlockCell / LSTMCell).

    Explicit exceptions are used instead of ``assert`` so the check still
    runs under ``python -O``.
    """
    if FLAGS.model not in (1, 2):
        raise ValueError(
            "main(): {} model must be LSTMBlockCell or LSTMCell".format(purpose))


def _validate_flags():
    """Check that FLAGS.action / FLAGS.model form a supported combination.

    Raises:
        ValueError: if FLAGS.action is not one of 0-3, or FLAGS.model is not
            allowed for the requested action.
    """
    if FLAGS.action == 0:
        # Training is not permitted for model 1.
        if FLAGS.model == 1:
            raise ValueError("main(): model 1 cannot be used for training")
    elif FLAGS.action == 1:
        _check_cell_model("evaluated")
    elif FLAGS.action == 2:
        _check_cell_model("exported")
    elif FLAGS.action == 3:
        _check_cell_model("predicting")
    else:
        raise ValueError("main(): unknown action {}".format(FLAGS.action))


def _report_prediction_errors(labels, logits, num_groups=22):
    """Print sample values, per-group MAE, and overall MAE of a prediction run.

    Both ``logits`` and ``labels`` pass through
    ``data_processing.convert_logits_to_predicts`` before they are compared.
    The per-column mean absolute error is averaged over consecutive triples
    of columns (presumably one joint's x/y/z each; TODO confirm). The final
    figure excludes the first triple (columns 0-2).
    """
    predicts = data_processing.convert_logits_to_predicts(logits)
    org_labels = data_processing.convert_logits_to_predicts(labels)
    print(predicts[:, :3], org_labels[:, :3])

    mae = np.mean(np.abs(predicts - org_labels), axis=0)
    print(" ".join(
        "{0:.5f}".format(np.mean(mae[3 * i:3 * i + 3]))
        for i in range(num_groups)))
    print(np.mean(mae[3:]))


def main(_):
    """Build an ``LSTMModel`` from FLAGS and run the action FLAGS.action selects.

    Actions: 0 = train, 1 = evaluate, 2 = export, 3 = predict and print
    error statistics. A FLAGS.seed of -1 leaves the NumPy/TensorFlow RNGs
    unseeded.

    Args:
        _: unused positional argument (presumably argv from the app runner).

    Raises:
        ValueError: if FLAGS.action / FLAGS.model is not a supported
            combination. Checked before any model is built.
    """
    _validate_flags()

    if FLAGS.seed != -1:
        np.random.seed(FLAGS.seed)
        tf.set_random_seed(FLAGS.seed)

    print("Creating model...")
    model = LSTMModel(FLAGS.input_size, FLAGS.output_size, FLAGS.num_layers,
                      FLAGS.num_units, FLAGS.direction, FLAGS.learning_rate,
                      FLAGS.dropout, FLAGS.seed,
                      is_training=FLAGS.action == 0, model=FLAGS.model)

    if FLAGS.action == 0:  # TRAINING
        data_storage = data_processing.prepare_all_data(FLAGS.time_len)
        model.train(data_storage, FLAGS.batch_size, FLAGS.num_epochs)
    elif FLAGS.action == 1:  # EVALUATING on a fixed (set, augmentation) pair
        set_idx, aug_idx = 7, 4
        inputs, labels = data_processing.prepare_indexed_data(
            None, set_idx, aug_idx)
        model.predict(inputs, labels)
    elif FLAGS.action == 2:  # EXPORTING
        if FLAGS.model == 1:
            model.export_weights()
        else:  # FLAGS.model == 2, guaranteed by _validate_flags()
            model.export()
    else:  # FLAGS.action == 3: PREDICTING (guaranteed by _validate_flags())
        set_idx, aug_idx = 0, 0
        inputs, labels = data_processing.prepare_indexed_data(
            None, set_idx, aug_idx)
        _loss, logits = model.predict(inputs, labels)
        _report_prediction_errors(labels, logits)