def train(model, optimizer, train_instances, validation_instances, num_epochs,
          batch_size, serialization_dir):
    print("\nGenerating train batches")
    train_batches = generate_batches(train_instances, batch_size)
    print("\nGenerating val batches")
    val_batches = generate_batches(validation_instances, batch_size)

    train_batch_labels = [batch_inputs.pop("labels") for batch_inputs in train_batches]
    val_batch_labels = [batch_inputs.pop("labels") for batch_inputs in val_batches]

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch}")
        epoch_loss = 0
        generator_tqdm = tqdm(list(zip(train_batches, train_batch_labels)))
        for batch_inputs, batch_labels in generator_tqdm:
            with tf.GradientTape() as tape:
                logits = model(**batch_inputs, training=True)['logits']
                loss_val = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                                   labels=batch_labels)
                ### TODO(student) START
                # Calculate L2 regularization over trainable variables with lambda=0.00001
                regularization = 0.00001 * tf.add_n(
                    [tf.nn.l2_loss(var) for var in model.trainable_variables])
                ### TODO(student) END
                loss_val += regularization
            grads = tape.gradient(loss_val, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            loss_val = tf.reduce_mean(loss_val)
            epoch_loss += loss_val
        epoch_loss = epoch_loss / len(train_batches)
        print(f"Train loss for epoch: {epoch_loss}")

        val_loss = 0
        total_preds = []
        total_labels = []
        generator_tqdm = tqdm(list(zip(val_batches, val_batch_labels)))
        for batch_inputs, batch_labels in generator_tqdm:
            logits = model(**batch_inputs, training=False)['logits']
            loss_value = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                                 labels=batch_labels)
            batch_preds = np.argmax(tf.nn.softmax(logits, axis=-1).numpy(), axis=-1)
            batch_labels = np.argmax(batch_labels, axis=-1)
            total_preds.extend(batch_preds)
            total_labels.extend(batch_labels)
            val_loss += tf.reduce_mean(loss_value)

        # Remove the "Other" class (id = 0) because we don't score it in evaluation.
        non_zero_preds = np.array(list(set(total_preds) - {0}))
        f1 = f1_score(total_labels, total_preds, labels=non_zero_preds, average='macro')
        val_loss = val_loss / len(val_batches)
        print(f"Val loss for epoch: {round(float(val_loss), 4)}")
        print(f"Val F1 score: {round(float(f1), 4)}")

    model.save_weights(os.path.join(serialization_dir, 'model.ckpt'))
    return {'model': model}
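# A minimal, self-contained sketch of the L2-regularization term used in the
# TODO block of train() above, assuming plain TensorFlow 2.x. The toy Dense
# layer here is hypothetical; it only exists to provide trainable variables.
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build(input_shape=(None, 3))  # create kernel and bias variables

# tf.nn.l2_loss(v) computes sum(v ** 2) / 2, so the penalty below is
# lambda * 0.5 * (sum of squared weights over all trainable variables).
l2_lambda = 0.00001
regularization = l2_lambda * tf.add_n(
    [tf.nn.l2_loss(var) for var in layer.trainable_variables])
print(float(regularization))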
def predict(model: models.Model, instances: List[Dict], batch_size: int,
            save_to_file: str = None) -> List[int]:
    """
    Makes predictions using model on instances and saves them in save_to_file.
    """
    batches = generate_batches(instances, batch_size)
    all_predicted_labels = []
    print("Making predictions")
    for batch_inputs in tqdm(batches):
        batch_inputs.pop("labels")
        logits = model(**batch_inputs, training=False)["logits"]
        predicted_labels = list(tf.argmax(logits, axis=-1).numpy())
        all_predicted_labels += predicted_labels
    if save_to_file:
        print(f"Saving predictions to filepath: {save_to_file}")
        with open(save_to_file, "w", encoding="utf-8") as file:
            for predicted_label in all_predicted_labels:
                file.write(str(predicted_label) + "\n")
    else:
        for predicted_label in all_predicted_labels:
            print(str(predicted_label) + "\n")
    return all_predicted_labels
def predict(model: models.Model, instances: List[Dict], batch_size: int,
            save_to_file: str = None):
    batches = generate_batches(instances, batch_size)
    all_predicted_labels = []
    print("\nStarting predictions")
    for batch_inputs in tqdm(batches):
        batch_inputs.pop("labels")
        logits = model(**batch_inputs, training=False)['logits']
        predicted_labels = list(tf.argmax(logits, axis=-1).numpy())
        all_predicted_labels += predicted_labels
    if save_to_file:
        with open(save_to_file, 'w') as file:
            for predicted_label, instance in zip(all_predicted_labels, instances):
                file.write(f"{instance['sentence_id']}\t{ID_TO_CLASS[predicted_label]}\n")
    else:
        for predicted_label, instance in zip(all_predicted_labels, instances):
            print(f"{instance['sentence_id']}\t{ID_TO_CLASS[predicted_label]}")
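# Hypothetical usage of the predict() variant above. `trained_model` and
# `test_instances` are assumed to come from the rest of this codebase, and
# each instance must carry a "sentence_id" key, since each output line is
# "<sentence_id>\t<class name>":
#
#     predict(trained_model, test_instances, batch_size=32,
#             save_to_file="predictions.tsv")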
def train():
    with tf.Session() as sess:
        model = create_model(config)
        sess.run(tf.global_variables_initializer())
        # NOTE: "saver = tf.train.Saver()" must come after
        # "model = create_model(config)" so the saver sees the model's variables.
        saver = tf.train.Saver()
        step = 0
        print('start training')
        for no_epoch in range(1, config.epoches):
            batches = generate_batches(config.dialogs, config.batch_size)
            for no_batch in range(1, len(batches) + 1):
                _, loss = model.step(sess, batches[no_batch - 1],
                                     forward_only=False, mode='train')
                if step % 20 == 0:
                    print('step {} batch loss: {}'.format(step, loss))
                step = step + 1
            if no_epoch % config.save_epoch == 0:
                saver.save(sess, config.save_path + config.save_name,
                           global_step=step)
                print('model saved at step = {}'.format(step))
        print('finish training')
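# A minimal TF1 restore counterpart to the saver.save(...) call above,
# assuming the graph is rebuilt the same way before restoring (here a single
# toy variable stands in for the model; "./checkpoints" is an assumed path).
import tensorflow as tf

v = tf.get_variable("v", shape=[2])
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint("./checkpoints")
    if ckpt:
        saver.restore(sess, ckpt)   # load saved weights
    else:
        sess.run(tf.global_variables_initializer())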
def get_candidates_and_references(self, pairs, arr_dep, k_beams=3):
    input_batches, _ = generate_batches(self.input_lang, self.output_lang, 1, pairs,
                                        return_dep_tree=True, arr_dep=arr_dep,
                                        max_degree=10, USE_CUDA=self.USE_CUDA)
    candidates = [self.evaluate(input_batch, k_beams)[0]
                  for input_batch in tqdm(input_batches)]
    candidates = [' '.join(candidate[:-1]) for candidate in candidates]
    references = pairs[:, 1]
    references = [self.ref_to_string(reference) for reference in references]
    return candidates, references
def predict(model: models.Model, instances: List[Dict], batch_size: int, is_bert,
            save_to_file: str = None) -> List[int]:
    """
    Makes predictions using model on instances and saves them in save_to_file.
    """
    # For BERT, use the finetuned BERT model to make predictions.
    if is_bert:
        test_ids, test_labels = ids_labels_from_instances(instances)
        test_preds = model.predict(test_ids, batch_size=batch_size)
        all_predicted_labels = tf.cast(test_preds >= 0.5, tf.int32)
        all_predicted_labels = tf.reshape(all_predicted_labels, -1)
        all_predicted_labels = list(all_predicted_labels.numpy())
    # For GloVe-embedding-based models, use the trained model to make predictions.
    else:
        batches = generate_batches(instances, batch_size)
        all_predicted_labels = []
        print("Making predictions")
        for batch_inputs in tqdm(batches):
            batch_inputs.pop("labels")
            logits = model(**batch_inputs, training=False)["logits"]
            predicted_labels = list(tf.argmax(logits, axis=-1).numpy())
            all_predicted_labels += predicted_labels
    # Save the predictions at the given file path.
    if save_to_file:
        print(f"Saving predictions to filepath: {save_to_file}")
        with open(save_to_file, "w") as file:
            for predicted_label in all_predicted_labels:
                file.write(str(predicted_label) + "\n")
    else:
        for predicted_label in all_predicted_labels:
            print(str(predicted_label) + "\n")
    return all_predicted_labels
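# A small sketch of the thresholding used in the BERT branch above:
# model.predict() is assumed to return one sigmoid probability per example,
# which is binarized at 0.5. The probabilities below are made up.
import tensorflow as tf

test_preds = tf.constant([[0.91], [0.12], [0.55]])  # shape (3, 1)
labels = tf.reshape(tf.cast(test_preds >= 0.5, tf.int32), -1)
print(list(labels.numpy()))  # [1, 0, 1]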
def train(model: models.Model,
          optimizer: optimizers.Optimizer,
          train_instances: List[Dict[str, np.ndarray]],
          validation_instances: List[Dict[str, np.ndarray]],
          num_epochs: int,
          batch_size: int,
          serialization_dir: str = None) -> tf.keras.Model:
    """
    Trains a model on the given training instances as configured and stores
    the relevant files in serialization_dir. Returns the model and some
    important metrics.
    """
    print("\nGenerating Training batches:")
    train_batches = generate_batches(train_instances, batch_size)
    print("Generating Validation batches:")
    validation_batches = generate_batches(validation_instances, batch_size)

    train_batch_labels = [batch_inputs.pop("labels")
                          for batch_inputs in train_batches]
    validation_batch_labels = [batch_inputs.pop("labels")
                               for batch_inputs in validation_batches]

    tensorboard_logs_path = os.path.join(serialization_dir, 'tensorboard_logs')
    tensorboard_writer = tf.summary.create_file_writer(tensorboard_logs_path)

    best_epoch_validation_accuracy = float("-inf")
    best_epoch_validation_loss = float("inf")
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch}")

        total_training_loss = 0
        total_correct_predictions, total_predictions = 0, 0
        generator_tqdm = tqdm(list(zip(train_batches, train_batch_labels)))
        for index, (batch_inputs, batch_labels) in enumerate(generator_tqdm):
            with tf.GradientTape() as tape:
                logits = model(**batch_inputs, training=True)["logits"]
                loss_value = cross_entropy_loss(logits, batch_labels)
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            total_training_loss += loss_value
            batch_predictions = np.argmax(tf.nn.softmax(logits, axis=-1).numpy(),
                                          axis=-1)
            total_correct_predictions += (batch_predictions == batch_labels).sum()
            total_predictions += batch_labels.shape[0]
            description = ("Average training loss: %.2f Accuracy: %.2f "
                           % (total_training_loss / (index + 1),
                              total_correct_predictions / total_predictions))
            generator_tqdm.set_description(description, refresh=False)
        average_training_loss = total_training_loss / len(train_batches)
        training_accuracy = total_correct_predictions / total_predictions

        total_validation_loss = 0
        total_correct_predictions, total_predictions = 0, 0
        generator_tqdm = tqdm(list(zip(validation_batches, validation_batch_labels)))
        for index, (batch_inputs, batch_labels) in enumerate(generator_tqdm):
            logits = model(**batch_inputs, training=False)["logits"]
            loss_value = cross_entropy_loss(logits, batch_labels)
            total_validation_loss += loss_value
            batch_predictions = np.argmax(tf.nn.softmax(logits, axis=-1).numpy(),
                                          axis=-1)
            total_correct_predictions += (batch_predictions == batch_labels).sum()
            total_predictions += batch_labels.shape[0]
            description = ("Average validation loss: %.2f Accuracy: %.2f "
                           % (total_validation_loss / (index + 1),
                              total_correct_predictions / total_predictions))
            generator_tqdm.set_description(description, refresh=False)
        average_validation_loss = total_validation_loss / len(validation_batches)
        validation_accuracy = total_correct_predictions / total_predictions

        if validation_accuracy > best_epoch_validation_accuracy:
            print("Model with best validation accuracy so far: %.2f. Saving the model."
                  % validation_accuracy)
            model.save_weights(os.path.join(serialization_dir, 'model.ckpt'))
            best_epoch_validation_loss = average_validation_loss
            best_epoch_validation_accuracy = validation_accuracy

        with tensorboard_writer.as_default():
            tf.summary.scalar("loss/training", average_training_loss, step=epoch)
            tf.summary.scalar("loss/validation", average_validation_loss, step=epoch)
            tf.summary.scalar("accuracy/training", training_accuracy, step=epoch)
            tf.summary.scalar("accuracy/validation", validation_accuracy, step=epoch)
        tensorboard_writer.flush()

    metrics = {
        "training_loss": float(average_training_loss),
        "validation_loss": float(average_validation_loss),
        "training_accuracy": float(training_accuracy),
        "best_epoch_validation_accuracy": float(best_epoch_validation_accuracy),
        "best_epoch_validation_loss": float(best_epoch_validation_loss)
    }
    print("Best epoch validation accuracy: %.4f, validation loss: %.4f"
          % (best_epoch_validation_accuracy, best_epoch_validation_loss))
    return {"model": model, "metrics": metrics}
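# cross_entropy_loss() is referenced above but not defined in this snippet.
# A plausible TF2 implementation, assuming integer class labels (consistent
# with the argmax comparison above), would be:
import tensorflow as tf

def cross_entropy_loss(logits, labels):
    # Mean sparse softmax cross-entropy over the batch.
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))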
train_X, train_Y, vocab, reverse_vocab = process_data(train_data,
                                                      label_to_id,
                                                      vocab=vocab,
                                                      vocab_size=args.vocab_size,
                                                      max_tokens=args.sequence_length)
print('Training data loaded.')

print('\nLoading validation data...')
validation_X, validation_Y, _, _ = process_data(validation_data,
                                                label_to_id,
                                                vocab=vocab,
                                                max_tokens=args.sequence_length)
print('Validation data loaded.')

print('\nGenerating batches...')
train_batches = generate_batches(train_X, train_Y, args.batch_size)
validation_batches = generate_batches(validation_X, validation_Y, args.batch_size)
print('Batches finished generating.')

optimizer = optimizers.Adam()
if model is None:
    model_config = {
        'vocab_size': args.vocab_size,
        'embedding_dim': args.embed_dim,
        'output_dim': len(label_to_id),
        'num_layers': args.num_layers,
        'dropout': 0.2,
        'trainable_embeddings': True,
        'transform_sequences': args.transform_sequences
val_bar = tqdm(desc='split=val',
               total=dataset.get_num_batches(args.batch_size),
               position=1,
               leave=True)
try:
    for epoch_index in range(args.num_epochs):
        train_state['epoch_index'] = epoch_index

        # Iterate over the training dataset.
        # Setup: batch generator, set loss and acc to 0, set train mode on.
        dataset.set_split('train')
        batch_generator = generate_batches(dataset,
                                           batch_size=args.batch_size,  # 128
                                           device=args.device)
        running_loss = 0.0
        running_acc = 0.0
        classifier.train()
        al_loss = []
        al_accu = []
        for batch_index, batch_dict in enumerate(batch_generator):
            # The training routine is these 5 steps:
            # --------------------------------------
            # step 1. zero the gradients
            optimizer.zero_grad()

            # step 2. compute the output
            y_pred = classifier(batch_dict['x_data'])
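# The loop above cuts off after step 2. A minimal standalone sketch of the
# full five-step PyTorch routine (all names and shapes here are made up):
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

optimizer.zero_grad()        # step 1. zero the gradients
y_pred = model(x)            # step 2. compute the output
loss = loss_func(y_pred, y)  # step 3. compute the loss
loss.backward()              # step 4. use the loss to produce gradients
optimizer.step()             # step 5. use the optimizer to take a gradient step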
original_instance = {"text_tokens": "the film performances were awesome".split()} updates = ["worst", "okay", "cool"] updated_instances = [] for update in updates: updated_instance = copy.deepcopy(original_instance) updated_instance["text_tokens"][4] = update updated_instances.append(updated_instance) all_instances = [original_instance]+updated_instances layer_representations = {} for seq2vec_name in choices.keys(): model = models[seq2vec_name] vocab = vocabs[seq2vec_name] all_indexed_instances = index_instances(copy.deepcopy(all_instances), vocab) batches = generate_batches(all_indexed_instances, 4) layer_representations[seq2vec_name] = model(**batches[0], training=False)["layer_representations"] for seq2vec_name, representations in layer_representations.items(): representations = np.asarray(representations) differences_across_layers = {"worst": [], "okay": [], "cool": []} for layer_num in choices[seq2vec_name]: original_representation = representations[0, layer_num-1, :] updated_representations = representations[1:, layer_num-1,:] differences = [sum(np.abs(original_representation-updated_representation)) for updated_representation in updated_representations] differences_across_layers["worst"].append(float(differences[0])) differences_across_layers["okay"].append(float(differences[1])) differences_across_layers["cool"].append(float(differences[2]))
                           (index + 1),
                           total_correct_predictions / total_predictions))
        generator_tqdm.set_description(description, refresh=False)
    average_loss = total_eval_loss / len(eval_batches)
    eval_accuracy = total_correct_predictions / total_predictions
    print('Final evaluation accuracy: %.4f loss: %.4f'
          % (eval_accuracy, average_loss))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Script to evaluate a trained model on data.""")
    parser.add_argument('model', help='Path to trained model directory')
    parser.add_argument('--test', help='Path to evaluation data.',
                        default=r'./data/test.csv')
    parser.add_argument('--labels', help='Path to label dictionary.',
                        default=r'./data/answers.json')
    args = parser.parse_args()

    data, label_to_id = load_eval_data(args.test, args.labels)
    print('\nLoading test data...')
    model, model_config, vocab, reverse_vocab = load_model(args.model)
    test_X, test_Y, vocab, reverse_vocab = process_data(
        data, label_to_id, vocab=vocab, vocab_size=model_config['vocab_size'])
    print('Test data loaded.')

    batch_size = 32
    batches = generate_batches(test_X, test_Y, batch_size)
    print('Batches finished generating.')
    train_result = eval(model, batches)
def train(model: models.Model,
          optimizer: optimizers.Optimizer,
          train_instances: List[Dict[str, np.ndarray]],
          validation_instances: List[Dict[str, np.ndarray]],
          num_epochs: int,
          max_length_train: int,
          max_length_validation: int,
          embedding_dim: int,
          batch_size: int,
          number_of_clusters: int,
          hidden_units_in_autoencoder_layers: List[int],
          serialization_dir: str = None) -> tf.keras.Model:
    """
    Trains a model on the given training instances as configured and stores
    the relevant files in serialization_dir. Returns the model and some
    important metrics.
    """
    tensorboard_logs_path = os.path.join(serialization_dir, 'tensorboard_logs')
    tensorboard_writer = tf.summary.create_file_writer(tensorboard_logs_path)

    # best_epoch_validation_accuracy = float("-inf")
    best_epoch_loss = float("inf")
    best_epoch_validation_silhouette_score = float("-inf")
    KMeans_1 = Clustering.KMeansClustering

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch}")
        total_training_loss = 0
        # total_abstract_train_part_hidden_value = []
        total_abstract_train_part_hidden_value = 0
        total_abstract_validation_part_hidden_value = []
        k = 0
        list_of_trainable_variables = []

        print("\nGenerating Training batches:")
        train_batches, embed_dim_train = generate_batches(
            train_instances, batch_size, max_length_train, embedding_dim)
        print("Generating Validation batches:")
        validation_batches, embed_dim_validation = generate_batches(
            validation_instances, batch_size, max_length_validation, embedding_dim)

        noise_data_train = np.random.normal(
            loc=0, scale=0.1,
            size=[len(train_instances) * 2, max_length_train, embed_dim_train])
        noise_data_validation = np.random.normal(
            loc=0, scale=0.1,
            size=[len(validation_instances) * 2, max_length_validation,
                  embed_dim_validation])

        train_batch_tickers = [batch_inputs.pop("Ticker")
                               for batch_inputs in train_batches]
        validation_batch_tickers = [batch_inputs.pop("Ticker")
                                    for batch_inputs in validation_batches]
        train_batch_dates = [batch_inputs.pop("Date")
                             for batch_inputs in train_batches]
        validation_batch_dates = [batch_inputs.pop("Date")
                                  for batch_inputs in validation_batches]
        train_batch_asset_returns = [batch_inputs.pop('Asset_Returns')
                                     for batch_inputs in train_batches]
        validation_batch_asset_returns = [batch_inputs.pop('Asset_Returns')
                                          for batch_inputs in validation_batches]

        generator_tqdm = tqdm(list(zip(train_batches, train_batch_tickers)))
        for index, (batch_inputs, batch_tickers) in enumerate(generator_tqdm):
            with tf.GradientTape() as tape:
                # if epoch == 0:
                noise = noise_data_train[k * batch_size * 2:(k + 1) * batch_size * 2, :]
                fake_input, train_real_fake_labels = get_fake_sample(
                    batch_inputs['inputs'])
                real_batch_inputs = batch_inputs['inputs']
                batch_inputs['inputs'] = np.concatenate(
                    (batch_inputs['inputs'], fake_input), axis=0)
                batch_inputs['real_fake_label'] = train_real_fake_labels
                batch_inputs['noise'] = noise
                batch_inputs['Validation_Inputs'] = validation_batch_asset_returns[index]
                batch_inputs['Training_Inputs'] = train_batch_asset_returns[index]
                batch_inputs['batch_tickers'] = batch_tickers

                loss, f_new = model(batch_inputs, epoch, training=True)
                if epoch % 2 == 0:
                    # batch_inputs['F_Updated_Value'] = f_new
                    """
                    if epoch == 0:
                        part_hidden_val = np.array(hidden_state).reshape(
                            -1, 2 * sum(hidden_units_in_autoencoder_layers))  # np.sum(config.hidden_size) * 2
                        W = part_hidden_val.T
                        U, sigma, VT = np.linalg.svd(W)
                        sorted_indices = np.argsort(sigma)
                        topk_evecs = VT[sorted_indices[:-number_of_clusters - 1:-1], :]
                        F_new = topk_evecs.T
                        batch_inputs['F_Updated_Value'] = F_new
                        # total_abstract_train_part_hidden_value.append(part_hidden_val)
                        total_abstract_train_part_hidden_value = part_hidden_val
                        km = KMeans(n_clusters=number_of_clusters)
                        cluster_labels_training = km.fit_predict(
                            X=total_abstract_train_part_hidden_value)
                        (estimated_return_vector, estimated_covariance_matrix,
                         cluster_weights, cluster_variance, trainable_variables_loss,
                         list_of_trainable_variables) = \
                            Calculate_Cluster_Weights.calculate_train_data_return_vector_for_all_stocks(
                                cluster_labels_training, real_batch_inputs)
                        asset_weights = [value for key, value in cluster_weights.items()]
                        # asset_weights = np.perm(asset_weights, cluster_labels_training)
                        estimated_return_vector = tf.concat(estimated_return_vector, axis=0)
                        validation_period_out_of_sample_returns = validation_batches[index]['inputs']
                        mean_of_validation_period_out_of_sample_returns = tf.math.reduce_mean(
                            validation_period_out_of_sample_returns, axis=1)
                        # difference_vector = estimated_return_vector - mean_of_validation_period_out_of_sample_returns
                        difference_vector = tf.math.pow(0, 2)
                        # difference_vector = tf.tensordot(difference_vector, asset_weights, axis=0)
                        # mean_squared_loss_of_in_sample_and_out_of_sample_returns = tf.keras.losses.mean_squared_error(
                        #     mean_of_validation_period_out_of_sample_returns, estimated_return_vector)
                        # loss += difference_vector
                        silhouette_score_training = silhouette_score(
                            total_abstract_train_part_hidden_value, cluster_labels_training)
                        loss += 1e-4 * trainable_variables_loss
                        # calculate_train_data_return_vector_for_all_stocks(cluster_labels_training,)
                    else:
                        batch_inputs.pop('F_Updated_Value', None)
                    """
                regularization_loss = 0
                # list_of_trainable_variables =
                for var in model.trainable_variables:
                    # print(var)
                    list_of_trainable_variables.append(var)
                    # if (var.name.find("Returns Vector") == -1) and (var.name.find("Estimated_Covariance") == -1):
                    regularization_loss += tf.math.add_n([tf.nn.l2_loss(var)])
                loss += 1e-4 * regularization_loss
            k += 1
            grads = tape.gradient(loss, list_of_trainable_variables)
            optimizer.apply_gradients(zip(grads, list_of_trainable_variables))
            total_training_loss += loss
            description = ("Average training loss: %.2f "
                           % (total_training_loss / len(train_instances)))
            generator_tqdm.set_description(description, refresh=False)
        average_training_loss = total_training_loss / len(train_instances)

        if average_training_loss < best_epoch_loss:
            print("Model with best training loss so far: %.2f. Saving the model."
                  % average_training_loss)
            best_epoch_loss = average_training_loss
            model.save_weights(os.path.join(serialization_dir, 'model.ckpt'))

        """
        if epoch % 10 == 0 and epoch != 0:
            print(total_abstract_train_part_hidden_value[0],
                  total_abstract_train_hidden_state[0])
            concatenated_total_abstract_train_part_hidden_value = tf.concat(
                total_abstract_train_part_hidden_value, axis=1)
            concatenated_total_abstract_train_part_hidden_value = \
                concatenated_total_abstract_train_part_hidden_value.numpy()
            silhouette_score_for_k_means_clustering, cluster_labels, kmeans = \
                KMeans.cluster_hidden_states(concatenated_total_abstract_train_part_hidden_value)
        """

        """
        total_validation_loss = 0
        generator_tqdm = tqdm(list(zip(validation_batches, validation_batch_tickers)))
        k = 0
        for index, (batch_inputs, batch_tickers) in enumerate(generator_tqdm):
            # if epoch == 0:
            noise = noise_data_validation[k * batch_size * 2:(k + 1) * batch_size * 2, :]
            fake_input, train_real_fake_labels = get_fake_sample(batch_inputs['inputs'])
            batch_inputs['inputs'] = np.concatenate((batch_inputs['inputs'], fake_input), axis=0)
            batch_inputs['real_fake_label'] = train_real_fake_labels
            batch_inputs['noise'] = noise
            loss, hidden_state = model(batch_inputs, training=False)
            # if epoch % 10 == 0 and epoch != 0:
            # if epoch == 0:
            part_validation_hidden_val = np.array(hidden_state).reshape(
                -1, 2 * sum(hidden_units_in_autoencoder_layers))
            total_abstract_validation_part_hidden_value.append(part_validation_hidden_val)
            # print(total_abstract_validation_part_hidden_value)
            # kmeans.predict()
            k += 1
            # grads = tape.gradient(loss, model.trainable_variables)
            # optimizer.apply_gradients(zip(grads, model.trainable_variables))
            total_validation_loss += loss
            description = ("Average validation loss: %.2f "
                           % (total_validation_loss / (index + 1)))
            generator_tqdm.set_description(description, refresh=False)
        average_validation_loss = total_validation_loss / len(validation_batches)
        """

        # if epoch % 10 == 0 and epoch != 0:
        #     print(total_abstract_validation_part_hidden_value[0])
        """
        concatenated_total_abstract_validation_part_hidden_value = tf.concat(
            total_abstract_validation_part_hidden_value, axis=0)
        concatenated_total_abstract_validation_part_hidden_value = \
            concatenated_total_abstract_validation_part_hidden_value.numpy()
        km = KMeans(n_clusters=number_of_clusters)
        cluster_labels_validation = km.fit_predict(
            X=concatenated_total_abstract_validation_part_hidden_value)
        silhouette_score_validation = silhouette_score(
            concatenated_total_abstract_validation_part_hidden_value,
            cluster_labels_validation)
        print(silhouette_score_validation, average_validation_loss)
        if (silhouette_score_validation > best_epoch_validation_silhouette_score
                and average_validation_loss < best_epoch_validation_loss):
            print("Model with best validation silhouette score so far: %.2f. Saving the model."
                  % silhouette_score_validation)
            print("Model with best validation loss so far: %.2f. Saving the model."
                  % average_validation_loss)
            model.save_weights(os.path.join(serialization_dir, 'model.ckpt'))
            best_epoch_validation_silhouette_score = silhouette_score_validation
            best_epoch_validation_loss = average_validation_loss
            # best_epoch_validation_accuracy = validation_accuracy
        """

        with tensorboard_writer.as_default():
            tf.summary.scalar("loss/training", average_training_loss, step=epoch)
            # tf.summary.scalar("loss/validation", average_validation_loss, step=epoch)
            # tf.summary.scalar("accuracy/training", training_accuracy, step=epoch)
            # tf.summary.scalar("accuracy/validation", best_epoch_validation_silhouette_score, step=epoch)
        tensorboard_writer.flush()

    metrics = {
        "training_loss": float(average_training_loss),
        # "validation_loss": float(average_validation_loss),
        # "training_accuracy": float(training_accuracy),
        # "best_epoch_validation_accuracy": float(best_epoch_validation_silhouette_score),
        # "best_epoch_validation_loss": float(best_epoch_validation_loss)
    }
    # print("Best epoch validation loss: %.4f" % best_epoch_validation_loss)
    return {"model": model, "metrics": metrics}
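# The (currently commented-out) validation logic above scores clusterings of
# the autoencoder's hidden states with scikit-learn. A minimal standalone
# sketch of that KMeans + silhouette pattern, on random stand-in data:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

hidden_states = np.random.rand(100, 16)  # stand-in for encoder outputs
km = KMeans(n_clusters=4, n_init=10)
labels = km.fit_predict(hidden_states)
print(silhouette_score(hidden_states, labels))  # in [-1, 1]; higher is better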
def run_experiment(n, m, batch_size, num_train_samples, num_test_samples,
                   num_eval_steps, seed, numpy_seed, dynamics, x_goal, u_goal,
                   qmat, rmat, lr_cost, lr_constraints, rtol, atol, dt,
                   state_sampler, num_train_iterations):
    rng = random.PRNGKey(seed)
    onp.random.seed(numpy_seed)

    amat, bmat = dynamical.linearize_dynamics(dynamics, x_goal, u_goal, 0)
    amat = amat * dt
    bmat = bmat * dt
    qmat = qmat * dt
    rmat = rmat * dt
    true_lqr = module.LQR(A=amat, B=bmat, Q=qmat, R=rmat)
    opt_pmat = olinalg.solve_discrete_are(true_lqr.A, true_lqr.B, true_lqr.Q,
                                          true_lqr.R)

    # Generate data.
    rng, key = random.split(rng)
    train_xs = state_sampler(rng, num_train_samples)
    train_demos = generate_lqr_demos(train_xs, x_goal, true_lqr)
    rng, key = random.split(rng)
    test_xs = state_sampler(rng, num_test_samples)
    test_demos = generate_lqr_demos(test_xs, x_goal, true_lqr)

    # Create a dummy batch to get dimensions.
    batch_gen = data_util.generate_batches(train_demos, batch_size,
                                           drop_remainder=True, shuffle=True)
    placeholder_batch = next(batch_gen)
    # Reset the batch generator.
    batch_gen = data_util.generate_batches(train_demos, batch_size,
                                           drop_remainder=True, shuffle=True)

    # Set up the Lagrangian for the constrained optimization.
    params_init, get_lqr = module.lqr()

    def batch_loss(params, data):
        pmat, lqr = get_lqr(params)
        kmat = discrete.gain_matrix(pmat, lqr)
        policy = vectorize.vectorize("(i),()->(j)")(util.policy(kmat, x_goal))
        us = policy(data.xs, np.zeros((), dtype=np.int32))
        return loss(data.us, us)

    def constraints(params, data):
        del data
        pmat, lqr = get_lqr(params)
        return discrete.riccati_operator(pmat, lqr) - pmat

    mult_init, lagr_func, get_params = lagrangian.make_lagrangian(
        batch_loss, constraints)

    # Set up training functions.
    opt_init, opt_update, get_lagr_params = cga.cga_lagrange_min(
        lr_cost, lagr_func, lr_multipliers=lr_constraints)

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol, atol)

    @jax.jit
    def step(i, opt_state, data):
        params = get_lagr_params(opt_state)
        val, grads = jax.value_and_grad(lagr_func, (0, 1))(*params, data=data)
        logs = {
            "lagrangian": val,
            "train loss": batch_loss(get_params(params), data),
        }
        return opt_update(i, grads, opt_state, data=data), logs

    # Initialize all parameters.
    rng, params_key = random.split(rng)
    params = params_init(params_key, (n, m))
    lagr_params = mult_init(params, data=placeholder_batch)
    opt_state = opt_init(lagr_params)

    # Run the first step but ignore updates, to force the jit to compile.
    step(0, opt_state, data=placeholder_batch)

    all_params = []
    all_times = []
    all_train_loss = []
    all_lagrangian = []
    for i in range(num_train_iterations):
        old_params = get_lagr_params(opt_state)
        all_params.append(old_params)
        # Wait for the async dispatch to finish.
        tree_util.tree_map(safe_block_until_ready, all_params)
        all_times.append(time.perf_counter())
        opt_state, logs = step(i, opt_state, data=next(batch_gen))
        # print(logs)
        all_train_loss.append(logs["train loss"])
        all_lagrangian.append(logs["lagrangian"])
        if convergence_test(get_lagr_params(opt_state), old_params):
            # print("CONVERGED!! Step:", i)
            break

    all_params.append(get_lagr_params(opt_state))
    # Wait for the async dispatch to finish.
    tree_util.tree_map(safe_block_until_ready, all_params)
    all_times.append(time.perf_counter())

    opt_costs = np.einsum("ij,ki,kj->k", opt_pmat, test_demos.xs - x_goal,
                          test_demos.xs - x_goal)

    @jax.jit
    def evaluate_params(params):
        pmat, learned_lqr = get_lqr(get_params(params))
        kmat = discrete.gain_matrix(pmat, learned_lqr)
        _, _, cs = evaluate_lqr_policy(test_demos.xs, x_goal, u_goal, kmat,
                                       true_lqr, num_eval_steps)
        costs = np.sum(cs, axis=-1)
        diff = costs - opt_costs
        test_loss = batch_loss(get_params(params), test_demos)
        return np.mean(diff), test_loss

    avg_diff, test_loss = zip(*[evaluate_params(p) for p in all_params])
    return all_times, avg_diff, test_loss, all_train_loss, all_lagrangian