Ejemplo n.º 1
0
def parse_function(serialized):
    """Decode one serialized triplet record into image and label tensors.

    The record is parsed with ``tf.io.parse_single_example``. The three
    ``image_*`` byte strings are decoded as ``tf.uint8`` and given a static
    shape of ``height * width``, both read from ``dataset_characteristics``.

    Args:
        serialized: the serialized example passed on to
            ``tf.io.parse_single_example``.

    Returns:
        The 6-tuple ``(image_anchor, image_neighbor, image_distant,
        label_anchor, label_neighbor, label_distant)``.
    """
    n_pixels = (dataset_characteristics.get_image_height() *
                dataset_characteristics.get_image_width())

    # The three members of a triplet, in the order they are returned.
    roles = ('anchor', 'neighbor', 'distant')

    # Images are stored as scalar byte strings, labels as scalar int64.
    features = {}
    for role in roles:
        features['image_' + role] = tf.io.FixedLenFeature([], tf.string)
    for role in roles:
        features['label_' + role] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(serialized=serialized,
                                        features=features)

    images = [
        tf.compat.v1.decode_raw(parsed['image_' + role], tf.uint8)
        for role in roles
    ]
    labels = [parsed['label_' + role] for role in roles]

    # Attach the static length to each decoded image tensor.
    for pixels in images:
        pixels.set_shape(n_pixels)

    return tuple(images + labels)
 def __init__(self, checkpoint_dir, model_dir_, deep_model, batch_size,
              feature_space_dimension):
     """Store checkpoint locations, batch settings and image dimensions.

     Args:
         checkpoint_dir: stored unchanged as ``self.checkpoint_dir``.
         model_dir_: stored unchanged as ``self.model_dir_``.
         deep_model: accepted for interface compatibility; this constructor
             neither stores nor uses it.
         batch_size: stored as ``self.batch_size``.
         feature_space_dimension: stored as
             ``self.feature_space_dimension``.
     """
     self.checkpoint_dir = checkpoint_dir
     self.model_dir_ = model_dir_
     self.batch_size = batch_size
     self.feature_space_dimension = feature_space_dimension
     # Not set by the constructor; start out as None.
     self.n_samples = None
     self.n_batches = None
     # Image geometry comes from the shared dataset description.
     self.image_height = dataset_characteristics.get_image_height()
     self.image_width = dataset_characteristics.get_image_width()
     self.image_n_channels = dataset_characteristics.get_image_n_channels()
def evaluate_embedding_space(path_save_network_model, model_dir_, deep_model,
                             feature_space_dimension, latent_space_dimension,
                             n_res_blocks, margin_in_loss, loss_type):
    """Embed the MNIST test split with a saved Siamese model and run 1-NN.

    Builds the Siamese graph selected by ``deep_model``, creates an
    ``Evaluate_embedding_space`` helper pointing at the checkpoint of a
    hard-coded epoch (45), embeds the MNIST test images and passes the
    embeddings to ``classify_with_1NN``.

    Args:
        path_save_network_model: prefix of the checkpoint directories; the
            epoch number and ``"/"`` are appended.
        model_dir_: forwarded to ``Evaluate_embedding_space``.
        deep_model: ``"CNN"`` or ``"ResNet"``; also used in the result path.
        feature_space_dimension: forwarded to the Siamese constructor.
        latent_space_dimension: forwarded to ``ResNet_Siamese`` only.
        n_res_blocks: forwarded to ``ResNet_Siamese`` only.
        margin_in_loss: forwarded to the Siamese constructor.
        loss_type: forwarded to the Siamese constructor.

    Raises:
        ValueError: if ``deep_model`` is neither ``"CNN"`` nor ``"ResNet"``.
    """
    # Fail fast: without this check an unsupported model name only surfaced
    # as an unbound ``siamese`` after the dataset had already been loaded.
    if deep_model not in ("CNN", "ResNet"):
        raise ValueError(
            "deep_model must be 'CNN' or 'ResNet', got %r" % (deep_model, ))

    which_epoch_to_load_NN_model = 45
    path_save_embeddings_of_test_data = ".\\results\\" + deep_model + "\\embedding_test_set\\"
    image_height = dataset_characteristics.get_image_height()
    image_width = dataset_characteristics.get_image_width()
    image_n_channels = dataset_characteristics.get_image_n_channels()

    if deep_model == "CNN":
        siamese = CNN_Siamese.CNN_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            margin_in_loss=margin_in_loss)
    else:  # "ResNet"
        # NOTE(review): the graph is built with is_train=True even though
        # this is evaluation -- confirm that this is intended.
        siamese = ResNet_Siamese.ResNet_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            latent_space_dimension=latent_space_dimension,
            n_res_blocks=n_res_blocks,
            margin_in_loss=margin_in_loss,
            is_train=True)

    evaluate_ = Evaluate_embedding_space(
        checkpoint_dir=path_save_network_model +
        str(which_epoch_to_load_NN_model) + "/",
        model_dir_=model_dir_)

    # Only the test split is evaluated; the train split is discarded.
    (_, _), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    X_test = X_test.reshape(
        (X_test.shape[0], image_height, image_width, image_n_channels))

    embedding, labels = evaluate_.embed_the_data(
        X=X_test,
        labels=y_test,
        siamese=siamese,
        path_save_embeddings_of_test_data=path_save_embeddings_of_test_data)
    evaluate_.classify_with_1NN(
        embedding,
        labels,
        path_to_save=path_save_embeddings_of_test_data + "KNN/")
def read_batches_data(loaded_batch_names, batch_size, path_base_data_numpy):
    """Load the ``.npy`` files of one batch together with their class indices.

    Every ``.npy`` file one directory level below ``path_base_data_numpy`` is
    listed once. Each entry of ``loaded_batch_names`` is then matched (as a
    substring) against those paths, the file is loaded into the batch array,
    and the name of its parent directory is looked up in the fixed tissue
    list to obtain the class index.

    Args:
        loaded_batch_names: file names (or unique fragments of them) that
            make up the batch.
        batch_size: number of rows allocated for the batch; must be the same
            as the batch size used by the code that generated the batches.
        path_base_data_numpy: prefix prepended verbatim to the glob pattern,
            so it is expected to end with a path separator.

    Returns:
        ``(loaded_batch, loaded_labels)``: an array of shape
        ``(batch_size, height, width, channels)`` and an array of shape
        ``(batch_size,)`` holding the class indices.

    Raises:
        AssertionError: if a name matches zero or more than one file.
        ValueError: if a file's parent directory is not a known tissue type.
    """
    import os  # local import keeps this block self-contained

    tissue_type_list = [
        "00_TUMOR", "01_STROMA", "02_MUCUS", "03_LYMPHO", "04_DEBRIS",
        "05_SMOOTH_MUSCLE", "06_ADIPOSE", "07_BACKGROUND", "08_NORMAL"
    ]
    image_height = dataset_characteristics.get_image_height()
    image_width = dataset_characteristics.get_image_width()
    image_n_channels = dataset_characteristics.get_image_n_channels()

    # os.sep instead of a literal backslash: same pattern as before on
    # Windows, and a working one on POSIX systems.
    paths_data_files = glob.glob(path_base_data_numpy + "**" + os.sep +
                                 "*.npy")

    loaded_batch = np.zeros(
        (batch_size, image_height, image_width, image_n_channels))
    loaded_labels = np.zeros((batch_size, ))
    for index_in_batch, file_name in enumerate(loaded_batch_names):
        path_file_in_batch = [i for i in paths_data_files if file_name in i]
        assert len(path_file_in_batch) == 1, (
            "expected exactly one file matching %r, found %d" %
            (file_name, len(path_file_in_batch)))
        path_ = path_file_in_batch[0]
        # The class label is the name of the directory holding the file.
        # os.path handles both separators on Windows, unlike split("\\").
        class_label = os.path.basename(os.path.dirname(path_))
        class_index = tissue_type_list.index(class_label)
        loaded_batch[index_in_batch, :, :, :] = np.load(path_)
        loaded_labels[index_in_batch] = class_index
    return loaded_batch, loaded_labels
def train_embedding_space(deep_model, n_res_blocks, batch_size, learning_rate,
                          path_save_network_model, model_dir_,
                          feature_space_dimension, latent_space_dimension,
                          margin_in_loss, loss_type):
    """Train a Siamese network on triplets streamed from a TFRecord file.

    Each epoch runs ``STEPS_PER_EPOCH_TRAIN`` optimisation steps and collects
    the last-layer and second-to-last-layer embeddings and the labels of
    every triplet member. After each epoch the list of average losses is
    written to ``loss.npy``; at the configured intervals the embeddings,
    their plots and a model checkpoint are saved as well.

    Args:
        deep_model: ``"CNN"`` or ``"ResNet"``; selects the Siamese class and
            is used in result paths and as the checkpoint model name.
        n_res_blocks: forwarded to ``ResNet_Siamese`` only.
        batch_size: number of triplet records per step.
        learning_rate: learning rate of the Adam optimizer.
        path_save_network_model: prefix of the per-epoch checkpoint
            directories; the epoch number and ``"/"`` are appended.
        model_dir_: forwarded to ``save_network_model`` and
            ``load_network_model``.
        feature_space_dimension: width of the per-epoch embedding buffer;
            also forwarded to the Siamese constructor.
        latent_space_dimension: forwarded to ``ResNet_Siamese`` only.
        margin_in_loss: forwarded to the Siamese constructor.
        loss_type: forwarded to the Siamese constructor.

    Raises:
        ValueError: if ``deep_model`` is neither ``"CNN"`` nor ``"ResNet"``.
    """
    # Fail fast, before any graph is built: an unsupported name previously
    # surfaced only later as an unbound ``siamese`` variable.
    if deep_model not in ("CNN", "ResNet"):
        raise ValueError(
            "deep_model must be 'CNN' or 'ResNet', got %r" % (deep_model, ))

    #================================ settings:
    save_plot_embedding_space = True
    save_points_in_embedding_space = True
    load_saved_network_model = False
    which_epoch_to_load_NN_model = 0
    num_epoch = 1000
    save_network_model_every_how_many_epochs = 1
    save_embedding_every_how_many_epochs = 1
    STEPS_PER_EPOCH_TRAIN = 704
    n_samples_plot = 2000  #--> if None, plot all
    # Width of the second-to-last-layer buffer (see the file
    # ResNet_siamese.py).
    # NOTE(review): this used to overwrite the ``latent_space_dimension``
    # argument inside the epoch loop; the value 300 is kept, but confirm
    # whether the argument should be used here instead.
    second_to_last_dimension = 300
    image_height = dataset_characteristics.get_image_height()
    image_width = dataset_characteristics.get_image_width()
    image_n_channels = dataset_characteristics.get_image_n_channels()
    path_tfrecords_train = 'C:\\Users\\bghojogh\\Desktop\\My_PhD\\PhD_projects\\Fisher_loss\\codes\\4_make_triplets\\2_MNIST\\triplets\\MNIST_500_triplets\\tfrecord\\triplets.tfrecords'
    path_save_embedding_space = ".\\results\\" + deep_model + "\\embedding_train_set\\"
    path_save_loss = ".\\loss_saved\\"
    #================================

    # Input pipeline: parse -> normalize -> repeat -> shuffle -> batch.
    train_dataset = tf.data.TFRecordDataset([path_tfrecords_train])
    train_dataset = train_dataset.map(Utils.parse_function)
    train_dataset = train_dataset.map(Utils.normalize_triplets)

    num_repeat = None
    train_dataset = train_dataset.repeat(num_repeat)
    train_dataset = train_dataset.shuffle(buffer_size=1024)
    train_dataset = train_dataset.batch(batch_size)

    # Feedable iterator: the string handle chooses the concrete iterator.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   train_dataset.output_types,
                                                   train_dataset.output_shapes)
    next_element = iterator.get_next()
    training_iterator = tf.data.make_initializable_iterator(train_dataset)

    # Siamese:
    if deep_model == "CNN":
        siamese = CNN_Siamese.CNN_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            margin_in_loss=margin_in_loss)
    else:  # "ResNet"
        siamese = ResNet_Siamese.ResNet_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            latent_space_dimension=latent_space_dimension,
            n_res_blocks=n_res_blocks,
            margin_in_loss=margin_in_loss,
            is_train=True)
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        siamese.loss)

    # max_to_keep=None: see
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver
    saver_ = tf.train.Saver(max_to_keep=None)

    # Shape every image batch is reshaped to before it is fed to the network.
    batch_shape = (batch_size, image_height, image_width, image_n_channels)
    # Rows per epoch: three triplet members per record.
    n_rows_in_epoch = STEPS_PER_EPOCH_TRAIN * batch_size * 3

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        training_handle = sess.run(training_iterator.string_handle())
        sess.run(training_iterator.initializer)

        if load_saved_network_model:
            succesful_load, latest_epoch = load_network_model(
                saver_=saver_,
                session_=sess,
                checkpoint_dir=path_save_network_model +
                str(which_epoch_to_load_NN_model) + "/",
                model_dir_=model_dir_,
                model_name=deep_model)
            assert (succesful_load == True)
            # Resume the loss history up to and including the loaded epoch.
            loss_average_of_epochs = np.load(path_save_loss + "loss.npy")
            loss_average_of_epochs = loss_average_of_epochs[:latest_epoch + 1]
            loss_average_of_epochs = list(loss_average_of_epochs)
        else:
            latest_epoch = -1
            loss_average_of_epochs = []

        for epoch in range(latest_epoch + 1, num_epoch):
            losses_in_epoch = []
            print("============= epoch: " + str(epoch) + "/" +
                  str(num_epoch - 1))
            embeddings_in_epoch = np.zeros(
                (n_rows_in_epoch, feature_space_dimension))
            labels_in_epoch = np.zeros((n_rows_in_epoch, ))
            embeddings_in_epoch_secondToLast = np.zeros(
                (n_rows_in_epoch, second_to_last_dimension))

            for i in range(STEPS_PER_EPOCH_TRAIN):
                (image_anchor, image_neighbor, image_distant, label_anchor,
                 label_neighbor, label_distant) = sess.run(
                     next_element, feed_dict={handle: training_handle})

                image_anchor = image_anchor.reshape(batch_shape)
                image_neighbor = image_neighbor.reshape(batch_shape)
                image_distant = image_distant.reshape(batch_shape)

                (_, loss_v, embedding1, embedding2, embedding3,
                 embedding1_secondToLast, embedding2_secondToLast,
                 embedding3_secondToLast) = sess.run(
                     [
                         train_step, siamese.loss, siamese.o1, siamese.o2,
                         siamese.o3, siamese.o1_secondToLast,
                         siamese.o2_secondToLast, siamese.o3_secondToLast
                     ],
                     feed_dict={
                         siamese.x1: image_anchor,
                         siamese.x2: image_neighbor,
                         siamese.x3: image_distant
                     })

                # Each step owns 3 * batch_size consecutive rows, laid out
                # as [anchors | neighbors | distants].
                step_offset = i * 3 * batch_size
                triplet_members = (
                    (embedding1, label_anchor, embedding1_secondToLast),
                    (embedding2, label_neighbor, embedding2_secondToLast),
                    (embedding3, label_distant, embedding3_secondToLast),
                )
                for member_index, (member_embedding, member_labels,
                                   member_secondToLast) in enumerate(
                                       triplet_members):
                    rows = slice(
                        step_offset + member_index * batch_size,
                        step_offset + (member_index + 1) * batch_size)
                    embeddings_in_epoch[rows, :] = member_embedding
                    labels_in_epoch[rows] = member_labels
                    embeddings_in_epoch_secondToLast[
                        rows, :] = member_secondToLast

                losses_in_epoch.extend([loss_v])

            # report average loss of epoch:
            loss_average_of_epochs.append(
                np.average(np.asarray(losses_in_epoch)))
            print("Average loss of epoch " + str(epoch) + ": " +
                  str(loss_average_of_epochs[-1]))
            if not os.path.exists(path_save_loss):
                os.makedirs(path_save_loss)
            np.save(path_save_loss + "loss.npy",
                    np.asarray(loss_average_of_epochs))

            # plot the embedding space:
            if (epoch % save_embedding_every_how_many_epochs == 0):
                if save_points_in_embedding_space:
                    if not os.path.exists(path_save_embedding_space +
                                          "numpy\\"):
                        os.makedirs(path_save_embedding_space + "numpy\\")
                    np.save(
                        path_save_embedding_space +
                        "numpy\\embeddings_in_epoch_" + str(epoch) + ".npy",
                        embeddings_in_epoch)
                    np.save(
                        path_save_embedding_space + "numpy\\labels_in_epoch_" +
                        str(epoch) + ".npy", labels_in_epoch)
                    np.save(
                        path_save_embedding_space +
                        "numpy\\embeddings_in_epoch_secondToLast_" +
                        str(epoch) + ".npy", embeddings_in_epoch_secondToLast)
                if save_plot_embedding_space:
                    print("saving the plot of embedding space....")
                    plt.figure(200)
                    _, indices_to_plot = plot_embedding_of_points(
                        embeddings_in_epoch, labels_in_epoch, n_samples_plot)
                    if not os.path.exists(path_save_embedding_space +
                                          "plots\\"):
                        os.makedirs(path_save_embedding_space + "plots\\")
                    # ``i`` is the index of the last step of this epoch.
                    plt.savefig(path_save_embedding_space + "plots\\" +
                                'epoch' + str(epoch) + '_step' + str(i) +
                                '.png')
                    plt.clf()
                    plt.close()
                    if not os.path.exists(path_save_embedding_space +
                                          "plots_secondToLast\\"):
                        os.makedirs(path_save_embedding_space +
                                    "plots_secondToLast\\")
                    # Reuse the same sampled indices for the second plot.
                    plot_embedding_of_points_secondToLast(
                        embeddings_in_epoch_secondToLast, labels_in_epoch,
                        indices_to_plot)
                    plt.savefig(path_save_embedding_space +
                                "plots_secondToLast\\" + 'epoch' + str(epoch) +
                                '_step' + str(i) + '.png')
                    plt.clf()
                    plt.close()

            # save the network model:
            if (epoch % save_network_model_every_how_many_epochs == 0):
                save_network_model(saver_=saver_,
                                   session_=sess,
                                   checkpoint_dir=path_save_network_model +
                                   str(epoch) + "/",
                                   step=epoch,
                                   model_name=deep_model,
                                   model_dir_=model_dir_)
                print("Model saved in path: %s" % path_save_network_model)
    def __init__(self,
                 loss_type,
                 feature_space_dimension,
                 n_triplets_per_batch,
                 n_classes,
                 n_samples_per_class_in_batch,
                 n_res_blocks=18,
                 margin_in_loss=0.25,
                 is_train=True,
                 batch_size=32):
        """Build the single-tower embedding graph and, when training, its loss.

        Creates the image and label placeholders and runs ``self.network``
        inside the ``"siamese"`` variable scope. When ``is_train`` is true,
        it also sets ``self.loss`` from the loss method named by
        ``loss_type``. An unrecognised ``loss_type`` leaves ``self.loss``
        unset.

        Args:
            loss_type: name of the batch loss to build (e.g.
                ``"batch_hard_triplet"``, ``"NCA_triplet"``).
            feature_space_dimension: stored; also passed to the proxy
                computation for ``"Proxy_NCA_triplet"``.
            n_triplets_per_batch: stored; used with ``n_classes`` to derive
                ``n_triplets_per_batch_per_class``.
            n_classes: stored; also passed to the proxy computation.
            n_samples_per_class_in_batch: stored.
            n_res_blocks: stored as ``self.res_n``.
            margin_in_loss: margin handed to the margin-based losses.
            is_train: build the training graph and loss when true, otherwise
                only the inference graph.
            batch_size: stored.
        """
        # --- plain configuration -------------------------------------------
        self.img_size_height = dataset_characteristics.get_image_height()
        self.img_size_width = dataset_characteristics.get_image_width()
        self.img_n_channels = dataset_characteristics.get_image_n_channels()
        self.c_dim = 3
        self.res_n = n_res_blocks
        self.feature_space_dimension = feature_space_dimension
        self.margin_in_loss = margin_in_loss
        self.batch_size = batch_size
        self.n_triplets_per_batch = n_triplets_per_batch
        self.n_classes = n_classes
        triplets_per_class = np.floor(self.n_triplets_per_batch /
                                      self.n_classes)
        self.n_triplets_per_batch_per_class = int(triplets_per_class)
        self.n_samples_per_class_in_batch = n_samples_per_class_in_batch

        # --- graph inputs --------------------------------------------------
        image_batch_shape = [
            None, self.img_size_height, self.img_size_width,
            self.img_n_channels
        ]
        self.x1 = tf.placeholder(tf.float32, image_batch_shape)
        self.x1Image = self.x1
        self.labels1 = tf.placeholder(tf.int32, [None])

        self.loss_type = loss_type

        # --- inference graph: no loss is created ----------------------------
        if not is_train:
            with tf.variable_scope("siamese"):
                self.o1 = self.network(self.x1Image,
                                       is_training=False,
                                       reuse=tf.AUTO_REUSE)
            return

        # --- training graph ------------------------------------------------
        with tf.variable_scope("siamese"):
            self.o1 = self.network(self.x1Image,
                                   is_training=True,
                                   reuse=False)

        # Keyword arguments shared by every loss method.
        shared_kwargs = dict(labels=self.labels1,
                             embeddings=self.o1,
                             squared=True)

        # Losses that take the shared arguments plus ``margin``.
        margin_loss_methods = {
            "batch_hard_triplet": "batch_hard_triplet_loss",
            "batch_semi_hard_triplet": "batch_semi_hard_triplet_loss",
            "batch_all_triplet": "batch_all_triplet_loss",
            "Nearest_Nearest_batch_triplet":
            "Nearest_Nearest_batch_triplet_loss",
            "Nearest_Furthest_batch_triplet":
            "Nearest_Furthest_batch_triplet_loss",
            "Furthest_Furthest_batch_triplet":
            "Furthest_Furthest_batch_triplet_loss",
            "Different_distances_batch_triplet":
            "Different_distances_batch_triplet_loss",
        }
        # Losses that take only the shared arguments.
        plain_loss_methods = {
            "NCA_triplet": "NCA_triplet_loss",
            "Proxy_NCA_triplet_CentersAsProxies":
            "Proxy_NCA_triplet_loss_CentersAsProxies",
            "easy_positive_triplet": "easy_positive_triplet_loss",
            "easy_positive_triplet_withInnerProduct":
            "easy_positive_triplet_loss_withInnerProduct",
        }

        chosen = self.loss_type
        if chosen in margin_loss_methods:
            build_loss = getattr(self, margin_loss_methods[chosen])
            self.loss = build_loss(margin=self.margin_in_loss,
                                   **shared_kwargs)
        elif chosen == "Negative_sampling_batch_triplet":
            self.loss = self.Negative_sampling_batch_triplet_loss(
                margin=self.margin_in_loss,
                cutoff=0.5,
                nonzero_loss_cutoff=1.4,
                **shared_kwargs)
        elif chosen == "Proxy_NCA_triplet":
            proxies = Utils_losses.calculate_proxies(
                n_classes=self.n_classes,
                feature_space_dimension=self.feature_space_dimension)
            self.loss = self.Proxy_NCA_triplet_loss(proxies=proxies,
                                                    **shared_kwargs)
        elif chosen in plain_loss_methods:
            build_loss = getattr(self, plain_loss_methods[chosen])
            self.loss = build_loss(**shared_kwargs)
    def __init__(self,
                 loss_type,
                 feature_space_dimension,
                 latent_space_dimension,
                 n_res_blocks=18,
                 margin_in_loss=0.25,
                 is_train=True,
                 batch_size=32):
        """Build the three-tower (anchor/neighbor/distant) Siamese graph.

        Creates one image placeholder per triplet member. In training mode,
        for the loss types ``"triplet"``, ``"FDA"``, ``"contrastive"`` and
        ``"FDA_contrastive"``, the three towers are built inside the
        ``"siamese"`` variable scope with the second and third tower reusing
        the first one's variables, and ``self.loss`` is set. In inference
        mode only the first tower is built, and only for ``"triplet"``.
        Any other combination builds no network.

        Args:
            loss_type: one of the loss names listed above.
            feature_space_dimension: stored.
            latent_space_dimension: stored.
            n_res_blocks: stored as ``self.res_n``.
            margin_in_loss: stored.
            is_train: select the training or the inference graph.
            batch_size: stored.
        """
        # --- plain configuration -------------------------------------------
        self.img_size_height = dataset_characteristics.get_image_height()
        self.img_size_width = dataset_characteristics.get_image_width()
        self.img_n_channels = dataset_characteristics.get_image_n_channels()
        self.c_dim = 3
        self.res_n = n_res_blocks
        self.feature_space_dimension = feature_space_dimension
        self.latent_space_dimension = latent_space_dimension
        self.margin_in_loss = margin_in_loss
        self.batch_size = batch_size

        # --- graph inputs, one placeholder per triplet member --------------
        def new_image_placeholder():
            return tf.placeholder(tf.float32, [
                None, self.img_size_height, self.img_size_width,
                self.img_n_channels
            ])

        self.x1 = new_image_placeholder()
        self.x1Image = self.x1
        self.x2 = new_image_placeholder()
        self.x2Image = self.x2
        self.x3 = new_image_placeholder()
        self.x3Image = self.x3

        self.loss_type = loss_type

        if is_train:
            # (loss name, needs last-layer weights first, loss method name)
            recipes = (
                ("triplet", False, "loss_triplet"),
                ("FDA", True, "loss_FDA"),
                ("contrastive", False, "loss_contrastive"),
                ("FDA_contrastive", True, "loss_FDA_contrastive"),
            )
            for name, needs_last_layer_weights, loss_method in recipes:
                if not (self.loss_type == name):
                    continue
                # The first tower creates the variables, the others reuse
                # them.
                with tf.variable_scope("siamese"):
                    self.o1 = self.network(self.x1Image,
                                           index_in_triplet=1,
                                           is_training=True,
                                           reuse=False)
                    self.o2 = self.network(self.x2Image,
                                           index_in_triplet=2,
                                           is_training=True,
                                           reuse=True)
                    self.o3 = self.network(self.x3Image,
                                           index_in_triplet=3,
                                           is_training=True,
                                           reuse=True)
                if needs_last_layer_weights:
                    self.get_last_layer_weights()
                self.loss = getattr(self, loss_method)()
                break
        elif self.loss_type == "triplet":
            # Inference: only the anchor tower is built.
            with tf.variable_scope("siamese"):
                self.o1 = self.network(self.x1Image,
                                       index_in_triplet=1,
                                       is_training=False,
                                       reuse=False)
        # Inference with "FDA" (or any other loss type) builds nothing.
def evaluate_embedding_space(path_save_network_model, model_dir_, deep_model,
                             feature_space_dimension, n_res_blocks,
                             margin_in_loss, loss_type, n_triplets_per_batch,
                             n_samples_per_class_in_batch, n_classes,
                             batch_size):
    """Run one evaluation task on the ``test2`` split with a saved model.

    The task is chosen by the hard-coded ``task_to_do`` setting:

    * ``"read_into_batches"`` -- list the test images via
      ``read_batches_paths``;
    * ``"embed_test_data"`` -- load the pickled image paths, read them into
      batches and embed them with the Siamese network;
    * ``"classify"`` -- load previously saved embeddings and labels and run
      the classification over several data proportions (10-fold CV).

    Args:
        path_save_network_model: prefix of the checkpoint directories; the
            hard-coded epoch number (50) and ``"/"`` are appended.
        model_dir_: forwarded to ``Evaluate_embedding_space``.
        deep_model: ``"CNN"`` or ``"ResNet"``; selects the Siamese class and
            is used in result paths.
        feature_space_dimension: forwarded to the Siamese constructor and to
            ``Evaluate_embedding_space``.
        n_res_blocks: forwarded to ``ResNet_Siamese`` only.
        margin_in_loss: forwarded to the Siamese constructor.
        loss_type: forwarded to the Siamese constructor.
        n_triplets_per_batch: forwarded to ``ResNet_Siamese`` only.
        n_samples_per_class_in_batch: forwarded to ``ResNet_Siamese`` only.
        n_classes: forwarded to ``ResNet_Siamese`` only.
        batch_size: forwarded to ``ResNet_Siamese`` only; the evaluator
            itself uses the hard-coded ``batch_size_test``.

    Raises:
        ValueError: if the ``"embed_test_data"`` task is selected and
            ``deep_model`` is neither ``"CNN"`` nor ``"ResNet"``.
    """
    which_epoch_to_load_NN_model = 50
    batch_size_test = 100
    task_to_do = "classify"  #--> read_into_batches, embed_test_data, classify
    proportions = [0.05, 0.1, 0.25, 0.5, 1]

    path_dataset_test = "D:\\Datasets\\CRC_new_large\\CRC_100K_train_test_numpy\\test2"
    path_save_test_patches = ".\\results\\" + deep_model + "\\batches_test2_set\\"
    path_save_embeddings_of_test_data = ".\\results\\" + deep_model + "\\embedding_test2_set\\"
    path_save_accuracy_of_test_data = ".\\results\\" + deep_model + "\\accuracy_test2_set\\"

    # Stays None for an unsupported model name; only the embedding task
    # needs the network, so the other tasks keep working in that case.
    siamese = None
    if deep_model == "CNN":
        siamese = CNN_Siamese.CNN_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            margin_in_loss=margin_in_loss)
    elif deep_model == "ResNet":
        # NOTE(review): the graph is built with is_train=True even though
        # this is evaluation -- confirm that this is intended.
        siamese = ResNet_Siamese.ResNet_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            n_triplets_per_batch=n_triplets_per_batch,
            n_classes=n_classes,
            n_samples_per_class_in_batch=n_samples_per_class_in_batch,
            n_res_blocks=n_res_blocks,
            margin_in_loss=margin_in_loss,
            is_train=True,
            batch_size=batch_size)
    evaluate_ = Evaluate_embedding_space(
        checkpoint_dir=path_save_network_model +
        str(which_epoch_to_load_NN_model) + "/",
        model_dir_=model_dir_,
        deep_model=deep_model,
        batch_size=batch_size_test,
        feature_space_dimension=feature_space_dimension)

    if task_to_do == "read_into_batches":
        paths_of_images = evaluate_.read_batches_paths(
            path_dataset=path_dataset_test,
            path_save_test_patches=path_save_test_patches)
    elif task_to_do == "embed_test_data":
        if siamese is None:
            raise ValueError(
                "deep_model must be 'CNN' or 'ResNet' to embed data, got %r"
                % (deep_model, ))
        # The context manager closes the file even if unpickling fails.
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # produced by this project.
        with open(path_save_test_patches + 'paths_of_images.pickle',
                  'rb') as pickle_file:
            paths_of_images = pickle.load(pickle_file)
        batches, batches_subtypes = evaluate_.read_data_into_batches(
            paths_of_images=paths_of_images)
        embedding, labels = evaluate_.embed_data_in_the_source_domain(
            batches=batches,
            batches_subtypes=batches_subtypes,
            siamese=siamese,
            path_save_embeddings_of_test_data=path_save_embeddings_of_test_data
        )
    elif task_to_do == "classify":
        # NOTE(review): these paths hard-code "ResNet" instead of using
        # path_save_embeddings_of_test_data, so a "CNN" run reads the
        # ResNet embeddings -- confirm whether that is intended.
        embedding = np.load(
            ".\\results\\ResNet\\embedding_test2_set\\numpy\\embedding.npy")
        labels = np.load(
            ".\\results\\ResNet\\embedding_test2_set\\numpy\\subtypes.npy")
        evaluate_.classification_in_target_domain_different_data_portions(
            X=embedding,
            y=labels,
            path_save_accuracy_of_test_data=path_save_accuracy_of_test_data,
            proportions=proportions,
            cv=10)
def train_embedding_space(deep_model, n_res_blocks, batch_size, learning_rate,
                          path_save_network_model, model_dir_,
                          feature_space_dimension, margin_in_loss, loss_type,
                          n_triplets_per_batch, n_samples_per_class_in_batch,
                          n_classes):
    """Train a Siamese network on pre-built batches and periodically save results.

    The list of batches is unpickled from ``path_batches + 'batches.pickle'``;
    its length is the number of training steps per epoch. For every step the
    batch is loaded with ``read_batches_data``, reshaped to
    ``(batch_size, image_height, image_width, image_n_channels)`` and fed to
    the network together with its labels. After every epoch the average loss
    is appended to ``loss.npy``; every few epochs (see the settings below) the
    embeddings/labels of the epoch, a plot of the embedding space and a
    checkpoint of the network are written to disk.

    Args:
        deep_model: ``"CNN"`` or ``"ResNet"``; selects the Siamese class and
            is also used in the results path and as the checkpoint model name.
        n_res_blocks: Forwarded to ``ResNet_Siamese`` (ResNet only).
        batch_size: Number of samples per batch; used to reshape each batch
            and to size the per-epoch embedding/label buffers.
        learning_rate: Learning rate passed to the Adam optimizer.
        path_save_network_model: Base directory of checkpoints; the checkpoint
            of epoch ``e`` goes to ``path_save_network_model + str(e) + "/"``.
        model_dir_: Forwarded to ``save_network_model``/``load_network_model``.
        feature_space_dimension: Width of the embedding; forwarded to the
            network and used to size the per-epoch embedding buffer.
        margin_in_loss: Forwarded to the network constructor.
        loss_type: Forwarded to the network constructor.
        n_triplets_per_batch: Forwarded to ``ResNet_Siamese`` (ResNet only).
        n_samples_per_class_in_batch: Forwarded to ``ResNet_Siamese``
            (ResNet only).
        n_classes: Forwarded to ``ResNet_Siamese`` (ResNet only).

    Returns:
        None. All results are written to disk.

    Raises:
        ValueError: If ``deep_model`` is not ``"CNN"`` or ``"ResNet"``, or if
            the unpickled list of batches is empty.
        AssertionError: If resuming is enabled and the checkpoint cannot be
            loaded.
    """
    # Fail fast: previously an unknown model left `siamese` unbound and only
    # crashed later, after the batch list had already been read from disk.
    if deep_model not in ("CNN", "ResNet"):
        raise ValueError(
            'Unsupported deep_model: %r (expected "CNN" or "ResNet")' %
            (deep_model, ))

    #================================ settings:
    # NOTE(review): Triplet_type is not read anywhere in this function.
    Triplet_type = "Different_Distances"  # "Nearest_Nearest", "Nearest_Furthest", "Furthest_Nearest", "Furthest_Furthest", "Different_Distances", "Regular"
    save_plot_embedding_space = True
    save_points_in_embedding_space = True
    load_saved_network_model = False
    which_epoch_to_load_NN_model = 5
    num_epoch = 51
    save_network_model_every_how_many_epochs = 5
    save_embedding_every_how_many_epochs = 5
    n_samples_plot = 2000  #--> if None, plot all
    image_height = dataset_characteristics.get_image_height()
    image_width = dataset_characteristics.get_image_width()
    image_n_channels = dataset_characteristics.get_image_n_channels()
    path_save_embedding_space = ".\\results\\" + deep_model + "\\embedding_train_set\\"
    path_save_loss = ".\\loss_saved\\"
    # path_batches = "D:\\siamese_considering_distance\\codes\\9_create_batches_for_batchBasedMethods\\code\\batches\\"
    # path_base_data_numpy = "D:\\Datasets\\CRC_new_large\\CRC_100K_train_test_numpy\\test1\\"
    path_batches = "C:\\Users\\bghojogh\\Desktop\\code_pathology\\9_create_batches_for_batchBasedMethods\\code\\batches\\"
    path_base_data_numpy = "C:\\Users\\bghojogh\\Desktop\\code_pathology\\data\\"
    #================================

    # NOTE(review): pickle must only be used on trusted files; this path is a
    # local, hard-coded file produced by the batch-creation step.
    with open(path_batches + 'batches.pickle', 'rb') as handle:
        loaded_batches_names = pickle.load(handle)
    STEPS_PER_EPOCH_TRAIN = len(
        loaded_batches_names)  #--> must be the number of batches
    if STEPS_PER_EPOCH_TRAIN == 0:
        # With no batches there is nothing to train on; the epoch loss would
        # be NaN and the plot file name could not be built.
        raise ValueError("No batches found in " + path_batches +
                         "batches.pickle")

    # Siamese:
    if deep_model == "CNN":
        siamese = CNN_Siamese.CNN_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            margin_in_loss=margin_in_loss)
    else:  # "ResNet" (validated at the top of the function)
        siamese = ResNet_Siamese.ResNet_Siamese(
            loss_type=loss_type,
            feature_space_dimension=feature_space_dimension,
            n_triplets_per_batch=n_triplets_per_batch,
            n_classes=n_classes,
            n_samples_per_class_in_batch=n_samples_per_class_in_batch,
            n_res_blocks=n_res_blocks,
            margin_in_loss=margin_in_loss,
            is_train=True,
            batch_size=batch_size)
    # train_step = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(siamese.loss)
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        siamese.loss)
    # tf.initialize_all_variables().run()

    saver_ = tf.train.Saver(
        max_to_keep=None
    )  # https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        if load_saved_network_model:
            succesful_load, latest_epoch = load_network_model(
                saver_=saver_,
                session_=sess,
                checkpoint_dir=path_save_network_model +
                str(which_epoch_to_load_NN_model) + "/",
                model_dir_=model_dir_,
                model_name=deep_model)
            # Explicit raise instead of `assert` so the check is not stripped
            # under `python -O`; the exception type is kept for callers.
            if not succesful_load:
                raise AssertionError(
                    "Could not load the network model of epoch " +
                    str(which_epoch_to_load_NN_model))
            # Keep only the losses of the epochs that were already trained.
            saved_losses = np.load(path_save_loss + "loss.npy")
            loss_average_of_epochs = list(saved_losses[:latest_epoch + 1])
        else:
            latest_epoch = -1
            loss_average_of_epochs = []

        for epoch in range(latest_epoch + 1, num_epoch):
            losses_in_epoch = []
            print("============= epoch: " + str(epoch) + "/" +
                  str(num_epoch - 1))
            # Buffers holding the embeddings and labels of every sample that
            # is fed to the network during this epoch.
            embeddings_in_epoch = np.zeros(
                (STEPS_PER_EPOCH_TRAIN * batch_size, feature_space_dimension))
            labels_in_epoch = np.zeros((STEPS_PER_EPOCH_TRAIN * batch_size, ))
            for i in range(STEPS_PER_EPOCH_TRAIN):
                if i % 10 == 0:
                    print("STEPS_PER_EPOCH_TRAIN " + str(i) + "/" +
                          str(STEPS_PER_EPOCH_TRAIN) + "...")

                loaded_batch, loaded_labels = read_batches_data(
                    loaded_batch_names=loaded_batches_names[i],
                    batch_size=batch_size,
                    path_base_data_numpy=path_base_data_numpy)

                loaded_batch = loaded_batch.reshape(
                    (batch_size, image_height, image_width, image_n_channels))

                # One optimization step; also fetch the loss and embeddings.
                _, loss_v, embedding1 = sess.run(
                    [train_step, siamese.loss, siamese.o1],
                    feed_dict={
                        siamese.x1: loaded_batch,
                        siamese.labels1: loaded_labels
                    })

                # Rows of the epoch buffers that belong to this batch.
                row_start = i * batch_size
                row_stop = row_start + batch_size
                embeddings_in_epoch[row_start:row_stop, :] = embedding1
                labels_in_epoch[row_start:row_stop] = loaded_labels

                losses_in_epoch.append(loss_v)

            # report average loss of epoch:
            loss_average_of_epochs.append(
                np.average(np.asarray(losses_in_epoch)))
            print("Average loss of epoch " + str(epoch) + ": " +
                  str(loss_average_of_epochs[-1]))
            if not os.path.exists(path_save_loss):
                os.makedirs(path_save_loss)
            np.save(path_save_loss + "loss.npy",
                    np.asarray(loss_average_of_epochs))

            # plot the embedding space:
            if (epoch % save_embedding_every_how_many_epochs == 0):
                if save_points_in_embedding_space:
                    path_numpy = path_save_embedding_space + "numpy\\"
                    if not os.path.exists(path_numpy):
                        os.makedirs(path_numpy)
                    np.save(
                        path_numpy + "embeddings_in_epoch_" + str(epoch) +
                        ".npy", embeddings_in_epoch)
                    np.save(
                        path_numpy + "labels_in_epoch_" + str(epoch) + ".npy",
                        labels_in_epoch)
                if save_plot_embedding_space:
                    print("saving the plot of embedding space....")
                    plt.figure(200)
                    # fig.clf()
                    plot_embedding_of_points(embeddings_in_epoch,
                                             labels_in_epoch, n_samples_plot)
                    path_plots = path_save_embedding_space + "plots\\"
                    if not os.path.exists(path_plots):
                        os.makedirs(path_plots)
                    # The file name carries the index of the last step of the
                    # epoch (the value the step loop variable ends with).
                    last_step = STEPS_PER_EPOCH_TRAIN - 1
                    plt.savefig(path_plots + 'epoch' + str(epoch) + '_step' +
                                str(last_step) + '.png')
                    plt.clf()
                    plt.close()

            # save the network model:
            if (epoch % save_network_model_every_how_many_epochs == 0):
                # Every saved epoch gets its own checkpoint sub-directory.
                checkpoint_dir_of_epoch = path_save_network_model + str(
                    epoch) + "/"
                save_network_model(saver_=saver_,
                                   session_=sess,
                                   checkpoint_dir=checkpoint_dir_of_epoch,
                                   step=epoch,
                                   model_name=deep_model,
                                   model_dir_=model_dir_)
                # Report the directory that was actually written to.
                print("Model saved in path: %s" % checkpoint_dir_of_epoch)