# Example 1
class ExperimentBuilder:
    """Builds a matching-network experiment graph and provides helpers that
    run a single training, validation or testing epoch against a TF session.
    """

    def __init__(self, data):
        """
        Initialize the builder.

        The ExperimentBuilder takes care of setting up the experiment graph
        and provides helper functions such as run_training_epoch and
        run_validation_epoch to simplify the training and evaluation loops.
        :param data: A data provider object exposing get_*_batches generators
                     and a ``dataset`` attribute with image dimensions.
        """
        self.data = data

    def build_experiment(self, batch_size, classes_per_set, samples_per_class, fce, args, full_context_unroll_k=5,
                         num_gpus=1, data_augmentation=True):
        """
        Build the TF graph: input placeholders, the MatchingNetwork model and
        its training ops.

        :param batch_size: The experiment batch size
        :param classes_per_set: Number of classes per support set
        :param samples_per_class: Number of samples per class
        :param fce: Whether to use full context embeddings or not
        :param args: Run-time options namespace (reads args.dropout_rate_value)
        :param full_context_unroll_k: Unroll steps for the full-context embedding
        :param num_gpus: Number of GPU replicas the placeholders are shaped for
        :param data_augmentation: Whether training batches should be augmented
        :return: the matching_network object, the losses, the training op and the init op
        """
        dataset = self.data.dataset
        height = dataset.image_height
        width = dataset.image_width
        channels = dataset.image_channel

        support_shape = [num_gpus, batch_size, classes_per_set, samples_per_class, height, width, channels]
        self.support_set_images = tf.placeholder(tf.float32, support_shape, 'support_set_images')
        self.support_set_labels = tf.placeholder(tf.int32,
                                                 [num_gpus, batch_size, classes_per_set, samples_per_class],
                                                 'support_set_labels')
        self.target_image = tf.placeholder(tf.float32, [num_gpus, batch_size, height, width, channels],
                                           'target_image')
        self.target_label = tf.placeholder(tf.int32, [num_gpus, batch_size], 'target_label')
        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
        self.current_learning_rate = 1e-03
        self.learning_rate = tf.placeholder(tf.float32, name='learning-rate-set')
        self.args = args

        self.one_shot_omniglot = MatchingNetwork(batch_size=batch_size,
                                                 support_set_images=self.support_set_images,
                                                 support_set_labels=self.support_set_labels,
                                                 target_image=self.target_image,
                                                 target_label=self.target_label,
                                                 dropout_rate=self.dropout_rate,
                                                 num_channels=channels,
                                                 is_training=self.training_phase,
                                                 fce=fce,
                                                 num_classes_per_set=classes_per_set,
                                                 num_samples_per_class=samples_per_class,
                                                 learning_rate=self.learning_rate,
                                                 full_context_unroll_k=full_context_unroll_k)
        self.data_augmentation = data_augmentation
        summary, self.losses, self.c_error_opt_op = self.one_shot_omniglot.init_train()
        init = tf.global_variables_initializer()
        self.total_train_iter = 0
        return self.one_shot_omniglot, self.losses, self.c_error_opt_op, init

    def run_training_epoch(self, total_train_batches, sess):
        """
        Run one training epoch.

        :param total_train_batches: Number of batches to train on
        :param sess: Session object
        :return: (loss mean, loss std, accuracy mean, accuracy std) over the epoch
        """
        epoch_losses = []
        epoch_accuracies = []
        train_batches = self.data.get_train_batches(total_batches=total_train_batches,
                                                    augment_images=self.data_augmentation)
        with tqdm.tqdm(total=total_train_batches) as pbar:
            for batch in train_batches:
                support_images, target_image, support_labels, target_label = batch

                fetches = [self.c_error_opt_op,
                           self.losses[self.one_shot_omniglot.classify],
                           self.losses[self.one_shot_omniglot.dn]]
                feed = {self.dropout_rate: self.args.dropout_rate_value,
                        self.support_set_images: support_images[0],
                        self.support_set_labels: support_labels[0],
                        self.target_image: target_image[0],
                        self.target_label: target_label[0],
                        self.training_phase: True,
                        self.learning_rate: self.current_learning_rate}
                _, loss_value, accuracy = sess.run(fetches, feed_dict=feed)

                pbar.set_description("train_loss: {:.6f}, train_accuracy: {:.3f}".format(loss_value, accuracy))
                pbar.update(1)
                epoch_losses.append(loss_value)
                epoch_accuracies.append(accuracy)

                # Halve the learning rate every 2000 global training iterations.
                self.total_train_iter += 1
                if self.total_train_iter % 2000 == 0:
                    self.current_learning_rate /= 2
                    print("change learning rate", self.current_learning_rate)

        return (np.mean(epoch_losses), np.std(epoch_losses),
                np.mean(epoch_accuracies), np.std(epoch_accuracies))

    def _run_eval_epoch(self, total_batches, batches, description_fmt, sess):
        """Shared loop for validation/testing: run the losses (no optimizer)
        with training_phase=False and return (loss mean, loss std,
        accuracy mean, accuracy std)."""
        epoch_losses = []
        epoch_accuracies = []
        fetches = [self.losses[self.one_shot_omniglot.classify],
                   self.losses[self.one_shot_omniglot.dn]]
        with tqdm.tqdm(total=total_batches) as pbar:
            for batch in batches:
                support_images, target_image, support_labels, target_label = batch

                feed = {self.dropout_rate: self.args.dropout_rate_value,
                        self.support_set_images: support_images[0],
                        self.support_set_labels: support_labels[0],
                        self.target_image: target_image[0],
                        self.target_label: target_label[0],
                        self.training_phase: False,
                        self.learning_rate: self.current_learning_rate}
                loss_value, accuracy = sess.run(fetches, feed_dict=feed)

                pbar.set_description(description_fmt.format(loss_value, accuracy))
                pbar.update(1)
                epoch_losses.append(loss_value)
                epoch_accuracies.append(accuracy)

        return (np.mean(epoch_losses), np.std(epoch_losses),
                np.mean(epoch_accuracies), np.std(epoch_accuracies))

    def run_validation_epoch(self, total_val_batches, sess):
        """
        Run one validation epoch (no augmentation, no weight updates).

        :param total_val_batches: Number of batches to evaluate on
        :param sess: Session object
        :return: (loss mean, loss std, accuracy mean, accuracy std)
        """
        batches = self.data.get_val_batches(total_batches=total_val_batches, augment_images=False)
        return self._run_eval_epoch(total_val_batches, batches,
                                    "val_loss: {:.6f}, val_accuracy: {:.3f}", sess)

    def run_testing_epoch(self, total_test_batches, sess):
        """
        Run one testing epoch (no augmentation, no weight updates).

        :param total_test_batches: Number of batches to evaluate on
        :param sess: Session object
        :return: (loss mean, loss std, accuracy mean, accuracy std)
        """
        batches = self.data.get_test_batches(total_batches=total_test_batches, augment_images=False)
        return self._run_eval_epoch(total_test_batches, batches,
                                    "test_loss: {:.6f}, test_accuracy: {:.3f}", sess)
# Example 2
class ExperimentBuilder:
    """Experiment builder variant that operates on flat feature vectors
    (rank-1 inputs) instead of image tensors, and reports running-mean
    loss/accuracy per epoch."""

    def __init__(self, data):
        """
        Initialize the builder.

        The ExperimentBuilder takes care of setting up the experiment graph
        and provides helper functions such as run_training_epoch and
        run_validation_epoch to simplify the training and evaluation loops.
        :param data: A data provider object exposing get_*_batch methods and
                     an ``x_train`` array whose axis 2 is the feature size.
        """
        self.data = data

    def build_experiment(self, batch_size, classes_per_set, samples_per_class,
                         fce):
        """
        Build the TF graph: input placeholders, the MatchingNetwork model and
        its training ops.

        :param batch_size: The experiment batch size
        :param classes_per_set: Number of classes per support set
        :param samples_per_class: Number of samples per class
        :param fce: Whether to use full context embeddings or not
        :return: the matching_network object, the losses, the training op and the init op
        """
        # Inputs are flat feature vectors; only the feature dimension is
        # needed (no height/width/channels as in the image variant).
        feature = self.data.x_train.shape[2]

        support_shape = [batch_size, classes_per_set, samples_per_class, feature]
        self.support_set_images = tf.placeholder(tf.float32, support_shape,
                                                 'support_set_images')
        self.support_set_labels = tf.placeholder(
            tf.int32, [batch_size, classes_per_set, samples_per_class],
            'support_set_labels')
        # 'target_image' is really a target feature vector here.
        self.target_image = tf.placeholder(tf.float32, [batch_size, feature],
                                           'target_image')
        self.target_label = tf.placeholder(tf.int32, [batch_size],
                                           'target_label')
        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.rotate_flag = tf.placeholder(tf.bool, name='rotate-flag')
        self.keep_prob = tf.placeholder(tf.float32, name='dropout-prob')
        self.current_learning_rate = 1e-03
        self.learning_rate = tf.placeholder(tf.float32,
                                            name='learning-rate-set')

        self.one_shot_omniglot = MatchingNetwork(
            batch_size=batch_size,
            support_set_images=self.support_set_images,
            support_set_labels=self.support_set_labels,
            target_image=self.target_image,
            target_label=self.target_label,
            keep_prob=self.keep_prob,
            is_training=self.training_phase,
            fce=fce,
            rotate_flag=self.rotate_flag,
            num_classes_per_set=classes_per_set,
            num_samples_per_class=samples_per_class,
            learning_rate=self.learning_rate)

        summary, self.losses, self.c_error_opt_op = self.one_shot_omniglot.init_train()
        init = tf.global_variables_initializer()
        self.total_train_iter = 0
        return self.one_shot_omniglot, self.losses, self.c_error_opt_op, init

    def run_training_epoch(self, total_train_batches, sess):
        """
        Run one training epoch.

        :param total_train_batches: Number of batches to train on
        :param sess: Session object
        :return: (mean training loss, mean training accuracy)
        """
        loss_sum = 0.
        accuracy_sum = 0.
        fetches = [
            self.c_error_opt_op,
            self.losses[self.one_shot_omniglot.classify],
            self.losses[self.one_shot_omniglot.dn]
        ]

        with tqdm.tqdm(total=total_train_batches) as pbar:
            for _ in range(total_train_batches):
                x_support, y_support, x_target, y_target = self.data.get_train_batch(
                    augment=False)
                _, loss_value, accuracy = sess.run(
                    fetches,
                    feed_dict={
                        self.keep_prob: 0.5,
                        self.support_set_images: x_support,
                        self.support_set_labels: y_support,
                        self.target_image: x_target,
                        self.target_label: y_target,
                        self.training_phase: True,
                        self.rotate_flag: False,
                        self.learning_rate: self.current_learning_rate
                    })

                pbar.set_description("train_loss: {}, train_accuracy: {}".format(
                    loss_value, accuracy))
                pbar.update(1)
                loss_sum += loss_value
                accuracy_sum += accuracy

                # Halve the learning rate every 2000 global training iterations.
                self.total_train_iter += 1
                if self.total_train_iter % 2000 == 0:
                    self.current_learning_rate /= 2
                    print("change learning rate", self.current_learning_rate)

        return loss_sum / total_train_batches, accuracy_sum / total_train_batches

    def _run_eval_epoch(self, total_batches, next_batch, description_fmt, sess):
        """Shared loop for validation/testing: evaluate the losses (no
        optimizer, keep_prob=1.0, training_phase=False) over ``total_batches``
        batches drawn from ``next_batch`` and return (mean loss, mean accuracy)."""
        loss_sum = 0.
        accuracy_sum = 0.
        fetches = [
            self.losses[self.one_shot_omniglot.classify],
            self.losses[self.one_shot_omniglot.dn]
        ]

        with tqdm.tqdm(total=total_batches) as pbar:
            for _ in range(total_batches):
                x_support, y_support, x_target, y_target = next_batch(augment=False)
                loss_value, accuracy = sess.run(
                    fetches,
                    feed_dict={
                        self.keep_prob: 1.0,
                        self.support_set_images: x_support,
                        self.support_set_labels: y_support,
                        self.target_image: x_target,
                        self.target_label: y_target,
                        self.training_phase: False,
                        self.rotate_flag: False
                    })

                pbar.set_description(description_fmt.format(loss_value, accuracy))
                pbar.update(1)
                loss_sum += loss_value
                accuracy_sum += accuracy

        return loss_sum / total_batches, accuracy_sum / total_batches

    def run_validation_epoch(self, total_val_batches, sess):
        """
        Run one validation epoch (no augmentation, no weight updates).

        :param total_val_batches: Number of batches to evaluate on
        :param sess: Session object
        :return: (mean validation loss, mean validation accuracy)
        """
        return self._run_eval_epoch(total_val_batches, self.data.get_val_batch,
                                    "val_loss: {}, val_accuracy: {}", sess)

    def run_testing_epoch(self, total_test_batches, sess):
        """
        Run one testing epoch (no augmentation, no weight updates).

        :param total_test_batches: Number of batches to evaluate on
        :param sess: Session object
        :return: (mean testing loss, mean testing accuracy)
        """
        return self._run_eval_epoch(total_test_batches, self.data.get_test_batch,
                                    "test_loss: {}, test_accuracy: {}", sess)