def train():
    """Train the convolutional VAE on the training videos, then freeze it.

    Builds a fresh default graph holding a ``ConVAE``, optionally resumes from
    the latest checkpoint in ``logs_path`` and runs ``epochs`` training epochs.
    Every processed video is one optimisation step, and its TensorBoard summary
    is written immediately. A checkpoint is saved on every epoch whose index is
    a multiple of 10. Afterwards the graph is frozen with
    ``utility.freeze_model`` and the reconstruction of one sample clip is
    displayed.

    Raises:
        ValueError: if ``loading`` is enabled but ``logs_path`` holds no
            checkpoint.
    """
    loading = False  # True: resume from the latest checkpoint in logs_path
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_VAE
    epochs = 12
    noise = 0.8  # value fed to vae.noise_var during training
    # NOTE(review): each epoch stops after this many usable videos, which
    # looks like a debugging limit; confirm before a full training run.
    max_videos_per_epoch = 1

    tf.reset_default_graph()
    vae = ConVAE()
    vae.build_model()
    merged_summary_op = write_summaries(vae)
    saver = tf.train.Saver(max_to_keep=10)

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if loading:
            latest_checkpoint_path = tf.train.latest_checkpoint(logs_path)
            if latest_checkpoint_path is None:
                raise ValueError(
                    "loading=True but no checkpoint found in {}".format(
                        logs_path))
            saver.restore(sess, latest_checkpoint_path)
            # Parse the step number from the file name only. Splitting the
            # full path on "." fails whenever a directory name contains a dot
            # (e.g. a relative "../logs/" base path).
            checkpoint_name = os.path.basename(latest_checkpoint_path)
            checkpoint_number = int(
                checkpoint_name.split(".")[0].split("_")[-1])
            print("loading checkpoint_number =", checkpoint_number)
        else:
            checkpoint_number = 0

        summary_writer = tf.summary.FileWriter(logs_path, graph=sess.graph)
        summary_writer.add_graph(sess.graph)

        # Counts processed videos across all epochs. checkpoint_number +
        # loop_counter is the global step used for summaries and file names.
        loop_counter = 1
        for e in tqdm(range(checkpoint_number, checkpoint_number + epochs)):
            print()
            path_generator = os_utils.iterate_data(
                cs.BASE_DATA_PATH + cs.DATA_TRAIN_VIDEOS, "mp4")
            batch_counter = 1  # 1-based number of the next video this epoch
            start_time = time.time()
            for video_path in path_generator:
                # ======================================
                # get batches to feed into the network
                # ======================================
                batch_x = get_batch(video_path)
                if batch_x is None:
                    continue
                print("video_path", video_path)
                print("video number =", batch_counter, "..... batch_x.shape",
                      batch_x.shape, " loop_counter =",
                      checkpoint_number + loop_counter)
                g_loss, l_loss, _, summary = sess.run(
                    [
                        vae.generation_loss, vae.latent_loss, vae.opt,
                        merged_summary_op
                    ],
                    feed_dict={
                        vae.inputs_: batch_x,
                        vae.targets_: batch_x.copy(),
                        vae.noise_var: noise
                    })
                # ==============================
                # Write logs at every iteration
                # ==============================
                summary_writer.add_summary(summary,
                                           checkpoint_number + loop_counter)
                print(
                    "Epoch: {}/{}...".format(e + 1 - checkpoint_number, epochs),
                    "Generation loss: {:.4f}".format(np.mean(g_loss)),
                    "Latent loss: {:.4f}".format(np.mean(l_loss)),
                    "Total loss: {:.4f}".format(
                        np.mean(l_loss) + np.mean(g_loss)))
                batch_counter += 1
                loop_counter += 1
                # batch_counter started at 1, so it is now one past the
                # number of videos actually processed in this epoch.
                videos_done = batch_counter - 1
                if videos_done == max_videos_per_epoch:
                    end_time = time.time()
                    print(
                        "=============================================================================================="
                    )
                    print("Epoch Number", e, "has ended in",
                          end_time - start_time, "seconds for", videos_done,
                          "videos")
                    print(
                        "=============================================================================================="
                    )
                    break
            if e % 10 == 0:
                # The number printed and used in the file name is the global
                # step (checkpoint_number + loop_counter), not the epoch.
                print("################################################")
                print("saving the model at epoch",
                      checkpoint_number + loop_counter)
                print("################################################")
                saver.save(
                    sess,
                    os.path.join(
                        logs_path, 'encoder_epoch_number_{}.ckpt'.format(
                            checkpoint_number + loop_counter)))
        # Flush any pending summary events to disk.
        summary_writer.close()

        # =========================
        # Freeze the session graph
        # =========================
        utility.freeze_model(sess, logs_path,
                             tf.train.latest_checkpoint(logs_path), vae,
                             "encoder_train.pb", cs.VAE_FREEZED_PB_NAME)
        print(
            "Run the command line:\n--> tensorboard --logdir={}".format(
                logs_path),
            "\nThen open http://0.0.0.0:6006/ into your web browser")

        # Visual check: reconstruct items 20..39 along the first axis of the
        # first usable video (presumably frames).
        # NOTE(review): this scans cs.BASE_DATA_PATH itself rather than a
        # dedicated test directory, and it does not feed vae.noise_var.
        # Confirm both are intended.
        path_generator = os_utils.iterate_data(cs.BASE_DATA_PATH, "mp4")
        for video_path in path_generator:
            test_frame = get_batch(video_path)
            if test_frame is not None:
                test_frame = test_frame[20:40, :, :, :]
                reconstructed = sess.run(vae.decoded,
                                         feed_dict={vae.inputs_: test_frame})
                display_reconstruction_results(test_frame, reconstructed)
                break
def train():
    """Train the LSTM classifier on encoder embeddings of the training videos.

    Each video is passed through the encoder tensors returned by
    ``get_encoded_embeddings``, and the resulting sequence is fed to a
    ``RecurrentNetwork``. In every epoch ``sampling_number`` randomly chosen
    video positions are held out for validation; the remaining videos are
    trained on. The per-epoch mean training loss and validation accuracy are
    written to TensorBoard, a checkpoint is saved every 30th epoch, and the
    trained graph is frozen with ``utility.freeze_model`` at the end.
    """
    epochs = 50
    sampling_number = 70  # videos held out for validation in each epoch
    videos_per_epoch = 420  # an epoch stops after this many usable videos
    encoder_logs_path = cs.BASE_LOG_PATH + cs.MODEL_CONV_AE_1
    path_generator = os_utils.iterate_data(
        cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_LSTM

    graph = tf.Graph()
    # Add nodes to the graph
    with graph.as_default():
        # Epoch-level metrics, only fed through feed_dict to produce
        # summaries. The second positional parameter of tf.Variable is
        # `trainable`, not the dtype, so both are passed by keyword here;
        # these metrics must not be touched by the optimizer.
        val_acc = tf.Variable(0.0, dtype=tf.float32, trainable=False)
        tot_loss = tf.Variable(0.0, dtype=tf.float32, trainable=False)
        rnn = RecurrentNetwork(lstm_size=128,
                               batch_len=BATCH_SIZE,
                               output_nodes=14,
                               learning_rate=0.001)
        rnn.build_model()
        stage_1_ip, stage_2_ip = get_encoded_embeddings(encoder_logs_path)
        prediction = tf.argmax(rnn.predictions, 1)
        label_encoder, _ = get_label_enocder(path_generator)
        merged_summary_op = write_summaries(val_acc, tot_loss)
        summary_writer = tf.summary.FileWriter(logs_path, graph=graph)
        summary_writer.add_graph(graph)
        saver = tf.train.Saver(max_to_keep=4)

    # Counts processed videos across all epochs; used in checkpoint names.
    loop_counter = 1
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        # The session context makes `graph` the default graph. Finalizing it
        # guards against accidentally adding ops inside the training loop.
        graph.finalize()
        for e in range(epochs):
            # 0-based positions (within this epoch) of the validation videos.
            # range(videos_per_epoch) covers every position, including the last.
            validation_positions = set(
                random.sample(range(videos_per_epoch), sampling_number))
            start_time = time.time()
            total_loss = 0
            validation_accuracy = 0  # correct validation predictions
            trained = 0  # videos used for an optimisation step
            validated = 0  # videos used for validation
            state = sess.run(rnn.initial_state)
            path_generator = os_utils.iterate_data(
                cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
            batch_counter = 0
            for video_path in path_generator:
                batch_x = get_batch(video_path)
                batch_y = get_target_name(video_path)
                if batch_x is None:
                    continue
                encoded_batch = sess.run(stage_2_ip,
                                         feed_dict={stage_1_ip: batch_x})
                # Prepend a batch axis of size 1 for the LSTM input.
                encoded_batch = encoded_batch.reshape(
                    (1, encoded_batch.shape[0], encoded_batch.shape[1]))
                feed = {
                    rnn.inputs_: encoded_batch,
                    rnn.targets_: label_encoder.transform([batch_y]),
                    rnn.keep_prob: 0.80,
                    rnn.initial_state: state
                }
                if batch_counter in validation_positions:
                    network_prediction = sess.run([prediction], feed_dict=feed)
                    print(
                        "validation =======> network_prediction: {}".format(
                            network_prediction[0][0]),
                        "and ground truth: {}".format(batch_y - 1))
                    validated += 1
                    if network_prediction[0][0] == batch_y - 1:
                        validation_accuracy += 1
                else:
                    batch_loss, state, _ = sess.run(
                        [rnn.loss, rnn.final_state, rnn.optimizer],
                        feed_dict=feed)
                    total_loss += batch_loss
                    trained += 1
                    print("Epoch: {}/{}".format(e, epochs),
                          "Video Number: {}".format(batch_counter),
                          "Batch Loss: {:.3f}".format(batch_loss))
                batch_counter += 1
                loop_counter += 1
                if batch_counter == videos_per_epoch:
                    break

            # Average over the videos that actually contributed to each
            # metric. Training loss is only summed on non-validation videos.
            # The summary is also written when fewer than videos_per_epoch
            # usable videos exist.
            total_loss = total_loss / max(trained, 1)
            accuracy = validation_accuracy / max(validated, 1)
            end_time = time.time()
            print(
                "==========================================================================================="
            )
            print("Epoch Number", e, "has ended in", end_time - start_time,
                  "seconds for", batch_counter, "videos",
                  "total loss is = {:.3f}".format(total_loss),
                  "validation accuracy is = {}".format(100 * accuracy))
            print(
                "==========================================================================================="
            )
            summary = sess.run(merged_summary_op,
                               feed_dict={
                                   val_acc: accuracy,
                                   tot_loss: total_loss
                               })
            summary_writer.add_summary(summary, e)

            if e % 30 == 0:
                # The number printed and used in the file name is the count
                # of processed videos, not the epoch.
                print("################################################")
                print("saving the model at epoch", loop_counter)
                print("################################################")
                saver.save(
                    sess,
                    os.path.join(logs_path,
                                 'lstm_loop_count_{}.ckpt'.format(loop_counter)))
        # Flush any pending summary events to disk.
        summary_writer.close()

        print(
            "Run the command line:\n--> tensorboard --logdir={}".format(
                logs_path),
            "\nThen open http://0.0.0.0:6006/ into your web browser")
        rnn.process_node_names()
        utility.freeze_model(sess, logs_path,
                             tf.train.latest_checkpoint(logs_path), rnn,
                             "lstm_train.pb", cs.LSTM_FREEZED_PB_NAME)
def train():
    """Build the convolutional auto-encoder graph and train it.

    Trains a ``ConVAE`` (logged under ``cs.MODEL_CONV_AE_1``) on the
    background training videos for ``epochs`` epochs, optionally resuming
    from the latest checkpoint in ``logs_path``. A TensorBoard summary is
    written for every processed video, and a checkpoint is saved on every
    epoch whose index is a multiple of 10. At the end the graph is frozen with
    ``utility.freeze_model`` and one reconstruction is displayed.

    Raises:
        ValueError: if ``loading`` is enabled but ``logs_path`` holds no
            checkpoint.
    """
    epochs = 150  # number of epochs to train for
    loading = False  # True: resume from the latest checkpoint in logs_path
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_CONV_AE_1  # checkpoints + summaries
    report_after_videos = 420  # print epoch timing once this many are done

    tf.reset_default_graph()
    cae = ConVAE()
    cae.build_model()
    merged_summary_op = write_summaries(cae)
    saver = tf.train.Saver(max_to_keep=10)

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # =======================================================================
        # If loading flag is true then load the latest model from the logs_path
        # =======================================================================
        if loading:
            latest_checkpoint_path = tf.train.latest_checkpoint(logs_path)
            if latest_checkpoint_path is None:
                raise ValueError(
                    "loading=True but no checkpoint found in {}".format(
                        logs_path))
            saver.restore(sess, latest_checkpoint_path)
            # Parse the step number from the file name only. Splitting the
            # full path on "." fails whenever a directory name contains a dot.
            checkpoint_name = os.path.basename(latest_checkpoint_path)
            checkpoint_number = int(
                checkpoint_name.split(".")[0].split("_")[-1])
            print("loading checkpoint_number =", checkpoint_number)
        else:
            checkpoint_number = 0

        summary_writer = tf.summary.FileWriter(logs_path, graph=sess.graph)
        summary_writer.add_graph(sess.graph)

        # Counts processed videos across all epochs. checkpoint_number +
        # loop_counter is the global step used for summaries and file names.
        loop_counter = 1
        for e in tqdm(range(checkpoint_number, checkpoint_number + epochs)):
            print()
            path_generator = os_utils.iterate_data(
                cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
            batch_counter = 0  # videos processed so far in this epoch
            start_time = time.time()
            for video_path in path_generator:
                # ======================================
                # get batches to feed into the network
                # ======================================
                batch_x = get_batch(video_path)
                if batch_x is None:
                    continue
                print("video_path", video_path)
                print("video number =", batch_counter, "..... batch_x.shape",
                      batch_x.shape, " loop_counter =",
                      checkpoint_number + loop_counter)
                batch_loss, _, summary = sess.run(
                    [cae.loss, cae.opt, merged_summary_op],
                    feed_dict={
                        cae.inputs_: batch_x,
                        cae.targets_: batch_x.copy()
                    })
                # ==============================
                # Write logs at every iteration
                # ==============================
                summary_writer.add_summary(summary,
                                           checkpoint_number + loop_counter)
                print("Epoch: {}/{}...".format(e + 1 - checkpoint_number, epochs),
                      "Training loss: {:.4f}".format(batch_loss))
                batch_counter += 1
                loop_counter += 1
                # Only a timing report: the epoch deliberately keeps going
                # over the remaining videos (the loop is not broken here).
                if batch_counter == report_after_videos:
                    end_time = time.time()
                    print(
                        "=============================================================================================="
                    )
                    print("Epoch Number", e, "has ended in",
                          end_time - start_time, "seconds for", batch_counter,
                          "videos")
                    print(
                        "=============================================================================================="
                    )
            if e % 10 == 0:
                # The number printed and used in the file name is the global
                # step (checkpoint_number + loop_counter), not the epoch.
                print("################################################")
                print("saving the model at epoch",
                      checkpoint_number + loop_counter)
                print("################################################")
                saver.save(
                    sess,
                    os.path.join(
                        logs_path, 'encoder_epoch_number_{}.ckpt'.format(
                            checkpoint_number + loop_counter)))
        # Flush any pending summary events to disk.
        summary_writer.close()

        # =========================
        # Freeze the session graph
        # =========================
        cae.process_node_names()
        utility.freeze_model(sess, logs_path,
                             tf.train.latest_checkpoint(logs_path), cae,
                             "encoder_train.pb", cs.ENCODER1_FREEZED_PB_NAME)
        print(
            "Run the command line:\n--> tensorboard --logdir={}".format(
                logs_path),
            "\nThen open http://0.0.0.0:6006/ into your web browser")

        # ==============================================================
        # Visual check of the model's reconstruction on one clip
        # ==============================================================
        # NOTE(review): the clip comes from iterate_test_data over the
        # training-video directory. Confirm that it really yields data the
        # model has not seen.
        path_generator = os_utils.iterate_test_data(
            cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
        for video_path in path_generator:
            test_frame = get_batch(video_path)
            if test_frame is not None:
                # Items 20..39 along the first axis (presumably frames).
                test_frame = test_frame[20:40, :, :, :]
                reconstructed = sess.run(cae.decoded,
                                         feed_dict={cae.inputs_: test_frame})
                display_reconstruction_results(test_frame, reconstructed)
                break