Example #1
    def __init__(self, path="./data/"):
        self.__path = path
        self.__train_folder = "train_alphabets"
        self.__test_folder = "test_alphabets"

        self.__data_loader = OmniglotLoader(self.__path, self.__train_folder,
                                            self.__test_folder)

        self.train_set = ()
        self.val_set = ()
        self.test_set = ()
        self.data_shape = ()
Example #2
    def __init__(self, dataset_path, learning_rate, batch_size,
                 use_augmentation, learning_rate_multipliers,
                 l2_regularization_penalization, tensorboard_log_path):
        """Inits SiameseNetwork with the provided values for the attributes.

        It also constructs the siamese network architecture, creates a dataset 
        loader and opens the log file.

        Arguments:
            dataset_path: path of Omniglot dataset    
            learning_rate: SGD learning rate
            batch_size: size of the batch to be used in training
            use_augmentation: boolean indicating whether data augmentation is used
            learning_rate_multipliers: learning-rate multipliers (relative to the chosen
                learning_rate) that will be applied to each of the conv and dense layers
                for example:
                    # Setting the Learning rate multipliers
                    LR_mult_dict = {}
                    LR_mult_dict['conv1']=1
                    LR_mult_dict['conv2']=1
                    LR_mult_dict['dense1']=2
                    LR_mult_dict['dense2']=2
            l2_regularization_penalization: l2 penalization for each layer.
                for example:
                    # Setting the L2 penalization
                    L2_dictionary = {}
                    L2_dictionary['conv1']=0.1
                    L2_dictionary['conv2']=0.001
                    L2_dictionary['dense1']=0.001
                    L2_dictionary['dense2']=0.01
            tensorboard_log_path: path to store the logs                
        """
        use_CNN = True
        if use_CNN:
            self.input_shape = (32, 24, 1)
        else:
            # Size of the sentence vector: 768 dimensions after BERT encoding
            self.input_shape = (768,)
        self.model = []
        self.learning_rate = learning_rate
        self.omniglot_loader = OmniglotLoader(
            dataset_path=dataset_path,
            use_augmentation=use_augmentation,
            batch_size=batch_size)
        self.summary_writer = tf.summary.FileWriter(tensorboard_log_path)
        self._construct_siamese_architecture(learning_rate_multipliers,
                                             l2_regularization_penalization)
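
A minimal instantiation sketch based on the signature and docstring above; the dataset path, learning rate, batch size, and log directory are illustrative assumptions rather than values from the original project:

lr_multipliers = {'conv1': 1, 'conv2': 1, 'dense1': 2, 'dense2': 2}
l2_penalization = {'conv1': 0.1, 'conv2': 0.001, 'dense1': 0.001, 'dense2': 0.01}

siamese = SiameseNetwork(
    dataset_path='./omniglot',                  # assumed dataset location
    learning_rate=1e-3,                         # assumed SGD learning rate
    batch_size=32,
    use_augmentation=True,
    learning_rate_multipliers=lr_multipliers,
    l2_regularization_penalization=l2_penalization,
    tensorboard_log_path='./logs/siamese')      # assumed log directory
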
Example #3
    def __init__(self, dataset_path, learning_rate, batch_size, use_augmentation,
                 learning_rate_multipliers, l2_regularization_penalization, tensorboard_log_path):
        """Inits SiameseNetwork with the provided values for the attributes.

        It also constructs the siamese network architecture, creates a dataset 
        loader and opens the log file.
        """

        self.input_shape = (1, 128)  # Size of images
        self.model = []
        self.learning_rate = learning_rate
        self.omniglot_loader = OmniglotLoader(
            dataset_path=dataset_path, use_augmentation=use_augmentation, batch_size=batch_size)
        self.summary_writer = tf.summary.FileWriter(tensorboard_log_path)
        self._construct_siamese_architecture(learning_rate_multipliers,
                                              l2_regularization_penalization)
Example #4
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import scikitplot as skplt
import sklearn

from omniglot_loader import OmniglotLoader
from siamese_network import SiameseNetwork

model_wo_tf = './trained_models/wo_transform/model.h5'
model_w_tf = './trained_models/w_transform/model.h5'

if __name__ == '__main__':

    # First test case, training without transformations, testing without transformations
    omg = OmniglotLoader(use_transformations=False)
    network = SiameseNetwork(model_location=model_wo_tf)

    y_pred, te_y = network.get_predictions(omg)
    y_pred = np.array(y_pred)
    te_y = np.array(te_y)
    # roc_curve lives in sklearn.metrics (scikit-plot only provides plotting helpers)
    fpr, tpr, _ = sklearn.metrics.roc_curve(te_y.flatten(), y_pred.flatten())
    auc = sklearn.metrics.roc_auc_score(te_y.flatten(), y_pred.flatten())

    # Second test case, training without transformations, testing with transformations
    omg = OmniglotLoader(use_transformations=True)
    y_pred, te_y = network.get_predictions(omg)
    y_pred = np.array(y_pred)
    te_y = np.array(te_y)
    fpr_te_tf, tpr_te_tf, _ = sklearn.metrics.roc_curve(te_y.flatten(),
                                                        y_pred.flatten())
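
The snippet stops after computing the second set of ROC points; as a hedged sketch (not the original script's continuation), the matplotlib imports above could be used to plot both curves like this:

    # Sketch only: labels, styling, and the call to show() are illustrative assumptions
    plt.plot(fpr, tpr, label='test w/o transformations (AUC = %.3f)' % auc)
    plt.plot(fpr_te_tf, tpr_te_tf, label='test w/ transformations')
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')  # chance level
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.legend(loc='lower right')
    plt.show()
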
Example #5
# coding: utf-8
"""Siamese network."""
from __future__ import print_function
import tensorflow as tf
import numpy as np
from omniglot_loader import OmniglotLoader
import sys

dataset_path = sys.argv[1]

omniglot_loader = OmniglotLoader(
    dataset_path=dataset_path, use_augmentation=True, batch_size=32)
omniglot_loader.split_train_datasets()

images, labels = omniglot_loader.get_train_batch()

print([i.shape for i in images], labels.shape)

images, labels = omniglot_loader.get_one_shot_batch(32, False)

print([i.shape for i in images], labels.shape)

from keras.models import Sequential, Model
from keras.layers import Conv2D, Dense, MaxPool2D, GlobalAveragePooling2D, BatchNormalization, Activation, Input, Lambda
from keras.regularizers import l2
from keras.optimizers import Adam
import keras.backend as K


input_shape = (105, 105, 1)
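
A minimal sketch (an assumption, not the original file's code) of how the layers imported above are typically assembled into a siamese model for 105x105 grayscale Omniglot pairs; the filter counts and Dense sizes are illustrative:

# Shared convolutional encoder applied to both branches (illustrative sizes)
twin = Sequential([
    Conv2D(64, (10, 10), activation='relu', input_shape=input_shape,
           kernel_regularizer=l2(1e-4)),
    MaxPool2D(),
    Conv2D(128, (7, 7), activation='relu', kernel_regularizer=l2(1e-4)),
    MaxPool2D(),
    GlobalAveragePooling2D(),
    Dense(256, activation='sigmoid'),
])

left_input = Input(input_shape)
right_input = Input(input_shape)

# L1 distance between the two embeddings, followed by a sigmoid similarity score
l1_distance = Lambda(lambda t: K.abs(t[0] - t[1]))([twin(left_input), twin(right_input)])
prediction = Dense(1, activation='sigmoid')(l1_distance)

siamese_model = Model([left_input, right_input], prediction)
siamese_model.compile(loss='binary_crossentropy', optimizer=Adam(0.00006))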