def train():
    loading = False
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_VAE
    tf.reset_default_graph()
    vae = ConVAE()
    vae.build_model()
    epochs = 12
    noise = 0.8
    merged_summary_op = write_summaries(vae)
    sess = tf.Session()
    saver = tf.train.Saver(max_to_keep=10)

    if loading:
        sess.run(tf.global_variables_initializer())
        latest_checkpoint_path = tf.train.latest_checkpoint(logs_path)
        saver.restore(sess, latest_checkpoint_path)
        checkpoint_number = int(
            latest_checkpoint_path.split(".")[0].split("_")[-1])
        print("loading checkpoint_number =", checkpoint_number)

    else:
        sess.run(tf.global_variables_initializer())
        checkpoint_number = 0

    summary_writer = tf.summary.FileWriter(logs_path, graph=sess.graph)
    summary_writer.add_graph(sess.graph)

    loop_counter = 1

    for e in tqdm(range(checkpoint_number, checkpoint_number + epochs)):
        print()
        path_generator = os_utils.iterate_data(
            cs.BASE_DATA_PATH + cs.DATA_TRAIN_VIDEOS, "mp4")
        batch_counter = 1
        start_time = time.time()

        for video_path in path_generator:
            # ======================================
            # get batches to feed into the network
            # ======================================
            batch_x = get_batch(video_path)

            if batch_x is None:
                # print("video_path", video_path)
                continue

            else:
                print("video_path", video_path)
                print("video number =", batch_counter, "..... batch_x.shape",
                      batch_x.shape, " loop_counter =",
                      checkpoint_number + loop_counter)

                g_loss, l_loss, _, summary = sess.run(
                    [
                        vae.generation_loss, vae.latent_loss, vae.opt,
                        merged_summary_op
                    ],
                    feed_dict={
                        vae.inputs_: batch_x,
                        vae.targets_: batch_x.copy(),
                        vae.noise_var: noise
                    })

            # ==============================
            # Write logs at every iteration
            # ==============================
            summary_writer.add_summary(summary,
                                       checkpoint_number + loop_counter)

            print(
                "Epoch: {}/{}...".format(e + 1 - checkpoint_number, epochs),
                "Generation loss: {:.4f}".format(np.mean(g_loss)),
                "Latent loss: {:.4f}".format(np.mean(l_loss)),
                "Total loss: {:.4f}".format(np.mean(l_loss) + np.mean(g_loss)))

            # if batch_counter % 2 == 0:
            #     print("saving the model at epoch", checkpoint_number + loop_counter)
            #     saver.save(sess, os.path.join(logs_path, 'encoder_epoch_number_{}.ckpt'
            #                                   .format(checkpoint_number + loop_counter)))

            batch_counter += 1
            loop_counter += 1
            if batch_counter == 2:
                end_time = time.time()
                print(
                    "=============================================================================================="
                )
                print("Epoch Number", e, "has ended in", end_time - start_time,
                      "seconds for", batch_counter, "videos")
                print(
                    "=============================================================================================="
                )

                break

        if e % 10 == 0:
            print("################################################")
            print("saving the model at epoch",
                  checkpoint_number + loop_counter)
            print("################################################")

            saver.save(
                sess,
                os.path.join(
                    logs_path,
                    'encoder_epoch_number_{}.ckpt'.format(checkpoint_number +
                                                          loop_counter)))
    # =========================
    # Freeze the session graph
    # =========================
    # freeze_model(sess, logs_path, tf.train.latest_checkpoint(logs_path), cae)
    utility.freeze_model(sess, logs_path,
                         tf.train.latest_checkpoint(logs_path), vae,
                         "encoder_train.pb", cs.VAE_FREEZED_PB_NAME)

    print(
        "Run the command line:\n--> tensorboard --logdir={}".format(logs_path),
        "\nThen open http://0.0.0.0:6006/ into your web browser")

    path_generator = os_utils.iterate_data(cs.BASE_DATA_PATH, "mp4")

    for video_path in path_generator:
        test_frame = get_batch(video_path)

        if test_frame is not None:

            test_frame = test_frame[20:40, :, :, :]
            reconstructed = sess.run(vae.decoded,
                                     feed_dict={vae.inputs_: test_frame})
            display_reconstruction_results(test_frame, reconstructed)

            break

    sess.close()
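
These examples rely on project helpers (get_batch, write_summaries, display_reconstruction_results, os_utils, cs) that are not shown on this page. As a rough illustration only, here is a minimal sketch of what get_batch is assumed to do: decode an mp4 with OpenCV, resize and normalize the frames, and return them as a float32 batch, or None when the clip cannot be read. The frame size and scaling are placeholder assumptions, not the project's actual values.

import cv2
import numpy as np


def get_batch(video_path, frame_size=(112, 112)):
    """Hypothetical helper: return frames as (num_frames, H, W, 3) floats in [0, 1]."""
    capture = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ok, frame = capture.read()
        if not ok:
            break
        frames.append(cv2.resize(frame, frame_size))
    capture.release()
    if not frames:
        # Mirrors the `if batch_x is None: continue` guard in the training loops.
        return None
    return np.asarray(frames, dtype=np.float32) / 255.0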
Example #2
def train():
    epochs = 50
    sampling_number = 70
    encoder_logs_path = cs.BASE_LOG_PATH + cs.MODEL_CONV_AE_1
    path_generator = os_utils.iterate_data(
        cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_LSTM
    checkpoint_number = 0
    loop_counter = 1

    graph = tf.Graph()
    # Add nodes to the graph
    with graph.as_default():
        val_acc = tf.Variable(0.0, dtype=tf.float32)
        tot_loss = tf.Variable(0.0, dtype=tf.float32)

        rnn = RecurrentNetwork(lstm_size=128,
                               batch_len=BATCH_SIZE,
                               output_nodes=14,
                               learning_rate=0.001)
        rnn.build_model()
        stage_1_ip, stage_2_ip = get_encoded_embeddings(encoder_logs_path)
        prediction = tf.argmax(rnn.predictions, 1)

    label_encoder, num_classes = get_label_enocder(path_generator)

    with graph.as_default():
        merged_summary_op = write_summaries(val_acc, tot_loss)
        summary_writer = tf.summary.FileWriter(logs_path, graph=graph)
        summary_writer.add_graph(graph)
        saver = tf.train.Saver(max_to_keep=4)

    loop_counter = 1

    with tf.Session(graph=graph) as sess:

        sess.run(tf.global_variables_initializer())
        iteration = 1
        tf.get_default_graph().finalize()
        for e in range(epochs):
            sampling_list = random.sample(range(0, 419), sampling_number)
            start_time = time.time()
            total_loss = 0
            validation_accuracy = 0
            state = sess.run(rnn.initial_state)

            path_generator = os_utils.iterate_data(
                cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")

            batch_counter = 0
            for video_path in path_generator:

                batch_x = get_batch(video_path)
                batch_y = get_target_name(video_path)

                if batch_x is None:
                    continue

                encoded_batch = sess.run(stage_2_ip,
                                         feed_dict={stage_1_ip: batch_x})
                encoded_batch = encoded_batch.reshape(
                    (1, encoded_batch.shape[0], encoded_batch.shape[1]))

                # print(encoded_batch.shape)
                feed = {
                    rnn.inputs_: encoded_batch,
                    rnn.targets_: label_encoder.transform([batch_y]),
                    rnn.keep_prob: 0.80,
                    rnn.initial_state: state
                }

                if batch_counter in sampling_list:
                    network_prediction = sess.run([prediction], feed_dict=feed)
                    print(
                        "validation =======> network_prediction: {}".format(
                            network_prediction[0][0]),
                        "and ground truth: {}".format(batch_y - 1))
                    # print(network_prediction[0])
                    # print(batch_y-1)
                    if network_prediction[0][0] == batch_y - 1:
                        validation_accuracy += 1

                else:
                    batch_loss, state, _ = sess.run(
                        [rnn.loss, rnn.final_state, rnn.optimizer],
                        feed_dict=feed)

                    total_loss += batch_loss

                    print("Epoch: {}/{}".format(e, epochs),
                          "Video Number: {}".format(batch_counter),
                          "Batch Loss: {:.3f}".format(batch_loss))
                    iteration += 1

                batch_counter += 1
                loop_counter += 1

                if batch_counter == 420:
                    total_loss = total_loss / 420
                    end_time = time.time()
                    print(
                        "==========================================================================================="
                    )
                    print(
                        "Epoch Number", e, "has ended in",
                        end_time - start_time, "seconds for", batch_counter,
                        "videos", "total loss is = {:.3f}".format(total_loss),
                        "validation accuracy is = {}".format(
                            100 * (validation_accuracy / len(sampling_list))))
                    print(
                        "==========================================================================================="
                    )
                    feed = {
                        val_acc: validation_accuracy / len(sampling_list),
                        tot_loss: total_loss
                    }
                    summary = sess.run(merged_summary_op, feed_dict=feed)
                    summary_writer.add_summary(summary, e)

                    break

            if e % 30 == 0:
                print("################################################")
                print("saving the model at epoch",
                      checkpoint_number + loop_counter)
                print("################################################")

                saver.save(
                    sess,
                    os.path.join(
                        logs_path,
                        'lstm_loop_count_{}.ckpt'.format(checkpoint_number +
                                                         loop_counter)))

        rnn.process_node_names()
        utility.freeze_model(sess, logs_path,
                             tf.train.latest_checkpoint(logs_path), rnn,
                             "lstm_train.pb", cs.LSTM_FREEZED_PB_NAME)

    print(
        "Run the command line:\n--> tensorboard --logdir={}".format(logs_path),
        "\nThen open http://0.0.0.0:6006/ in your web browser")
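
The LSTM example above pushes each video batch through a previously frozen convolutional encoder via get_encoded_embeddings. That helper is not shown here; below is a minimal sketch of the usual TF1 pattern it presumably follows, importing a frozen .pb file and returning the encoder's input and output tensors. The file name and tensor names are placeholders, not the project's real ones.

import os

import tensorflow as tf


def get_encoded_embeddings(encoder_logs_path,
                           pb_name="encoder_freezed.pb",
                           input_name="inputs:0",
                           output_name="encoded:0"):
    """Hypothetical helper: import a frozen encoder graph and return its
    input placeholder and encoded-feature tensor."""
    with tf.gfile.GFile(os.path.join(encoder_logs_path, pb_name), "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Calling this inside the caller's `with graph.as_default():` block attaches
    # the frozen encoder to the same graph as the RNN nodes.
    stage_1_ip, stage_2_ip = tf.import_graph_def(
        graph_def, return_elements=[input_name, output_name])
    return stage_1_ip, stage_2_ip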
Example #3
def train():
    """ This function builds the graph and performs the training """

    epochs = 150  # epochs: number of training epochs to run
    loading = False  # loading : flag for loading an already trained model
    logs_path = cs.BASE_LOG_PATH + cs.MODEL_CONV_AE_1  # logs_path : path to store checkpoint and summary events
    tf.reset_default_graph()
    cae = ConVAE()
    cae.build_model()
    merged_summary_op = write_summaries(cae)
    sess = tf.Session()
    saver = tf.train.Saver(max_to_keep=10)

    # =======================================================================
    # If the loading flag is true, load the latest model from logs_path
    # =======================================================================
    if loading:
        sess.run(tf.global_variables_initializer())
        latest_checkpoint_path = tf.train.latest_checkpoint(logs_path)
        saver.restore(sess, latest_checkpoint_path)
        checkpoint_number = int(
            latest_checkpoint_path.split(".")[0].split("_")[-1])
        print("loading checkpoint_number =", checkpoint_number)

    else:
        sess.run(tf.global_variables_initializer())
        checkpoint_number = 0

    summary_writer = tf.summary.FileWriter(logs_path, graph=sess.graph)
    summary_writer.add_graph(sess.graph)

    loop_counter = 1

    for e in tqdm(range(checkpoint_number, checkpoint_number + epochs)):
        print()
        path_generator = os_utils.iterate_data(
            cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")
        batch_counter = 0
        start_time = time.time()

        for video_path in path_generator:
            # ======================================
            # get batches to feed into the network
            # ======================================
            batch_x = get_batch(video_path)

            if batch_x is None:
                continue

            else:
                print("video_path", video_path)
                print("video number =", batch_counter, "..... batch_x.shape",
                      batch_x.shape, " loop_counter =",
                      checkpoint_number + loop_counter)

                batch_loss, _, summary = sess.run(
                    [cae.loss, cae.opt, merged_summary_op],
                    feed_dict={
                        cae.inputs_: batch_x,
                        cae.targets_: batch_x.copy()
                    })

            # ==============================
            # Write logs at every iteration
            # ==============================
            summary_writer.add_summary(summary,
                                       checkpoint_number + loop_counter)

            print("Epoch: {}/{}...".format(e + 1 - checkpoint_number, epochs),
                  "Training loss: {:.4f}".format(batch_loss))

            # if batch_counter % 2 == 0:
            #     print("saving the model at epoch", checkpoint_number + loop_counter)
            #     saver.save(sess, os.path.join(logs_path, 'encoder_epoch_number_{}.ckpt'
            #                                   .format(checkpoint_number + loop_counter)))

            batch_counter += 1
            loop_counter += 1
            if batch_counter == 420:
                end_time = time.time()
                print(
                    "=============================================================================================="
                )
                print("Epoch Number", e, "has ended in", end_time - start_time,
                      "seconds for", batch_counter, "videos")
                print(
                    "=============================================================================================="
                )

                # break

        if e % 10 == 0:
            print("################################################")
            print("saving the model at epoch",
                  checkpoint_number + loop_counter)
            print("################################################")

            saver.save(
                sess,
                os.path.join(
                    logs_path,
                    'encoder_epoch_number_{}.ckpt'.format(checkpoint_number +
                                                          loop_counter)))
    # =========================
    # Freeze the session graph
    # =========================
    cae.process_node_names()
    utility.freeze_model(sess, logs_path,
                         tf.train.latest_checkpoint(logs_path), cae,
                         "encoder_train.pb", cs.ENCODER1_FREEZED_PB_NAME)

    print(
        "Run the command line:\n--> tensorboard --logdir={}".format(logs_path),
        "\nThen open http://0.0.0.0:6006/ into your web browser")

    path_generator = os_utils.iterate_test_data(
        cs.BASE_DATA_PATH + cs.DATA_BG_TRAIN_VIDEO, "mp4")

    # ==============================================================
    # Now test the model's performance on unseen data
    # ==============================================================
    for video_path in path_generator:
        test_frame = get_batch(video_path)

        if test_frame is not None:

            test_frame = test_frame[20:40, :, :, :]
            reconstructed = sess.run(cae.decoded,
                                     feed_dict={cae.inputs_: test_frame})
            display_reconstruction_results(test_frame, reconstructed)

            break

    sess.close()
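
Finally, display_reconstruction_results is also project-specific. A plausible minimal version, assuming it simply plots a few input frames above the autoencoder's reconstructions with matplotlib (the grid size and clipping are assumptions):

import matplotlib.pyplot as plt
import numpy as np


def display_reconstruction_results(test_frame, reconstructed, num_images=5):
    """Hypothetical helper: show inputs (top row) against reconstructions (bottom row)."""
    num_images = min(num_images, len(test_frame))
    fig, axes = plt.subplots(2, num_images,
                             figsize=(2 * num_images, 4), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(np.clip(test_frame[i], 0.0, 1.0))
        axes[0, i].axis("off")
        axes[1, i].imshow(np.clip(reconstructed[i], 0.0, 1.0))
        axes[1, i].axis("off")
    axes[0, 0].set_title("input")
    axes[1, 0].set_title("reconstruction")
    plt.tight_layout()
    plt.show()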