Example #1
0
def random_classification_datasets(n_samples,
                                   features=100,
                                   classes=2,
                                   informative=.1,
                                   partition_proportions=(.5, .3),
                                   rnd=None,
                                   one_hot=True,
                                   **mk_cls_kwargs):
    """Generate a synthetic classification problem and split it into datasets.

    :param n_samples: total number of examples to generate
    :param features: number of input features
    :param classes: number of target classes
    :param informative: stored in the dataset's info dict.
        NOTE(review): unlike the regression variant, this is NOT forwarded to
        ``make_classification`` — confirm whether that is intended.
    :param partition_proportions: fractions passed to ``redivide_data``
    :param rnd: random seed or random state
    :param one_hot: if True, one-hot encode the targets
    :param mk_cls_kwargs: extra keyword arguments for ``make_classification``
    :return: an ``em.Datasets`` holding the generated partitions
    """
    state = em.get_rand_state(rnd)
    data, targets = make_classification(n_samples,
                                        features,
                                        n_classes=classes,
                                        random_state=state,
                                        **mk_cls_kwargs)
    if one_hot:
        targets = utils.to_one_hot_enc(targets)

    print('range of Y', np.min(targets), np.max(targets))
    info = utils.merge_dicts({'informative': informative, 'random_seed': rnd},
                             mk_cls_kwargs)
    dataset = em.Dataset(data,
                         targets,
                         name=em.utils.name_from_dict(info, 'w'),
                         info=info)
    splits = em.Datasets.from_list(
        redivide_data([dataset], partition_proportions))
    # report conditioning of the train design matrix as a sanity check
    print('conditioning of X^T X',
          np.linalg.cond(splits.train.data.T @ splits.train.data))
    return splits
Example #2
0
def random_regression_datasets(n_samples,
                               features=100,
                               outs=1,
                               informative=.1,
                               partition_proportions=(.5, .3),
                               rnd=None,
                               **mk_rgr_kwargs):
    """Generate a synthetic regression problem and split it into datasets.

    :param n_samples: total number of examples to generate
    :param features: number of input features
    :param outs: number of regression targets
    :param informative: fraction of informative features
        (``int(features * informative)`` is passed to ``make_regression``)
    :param partition_proportions: fractions passed to ``redivide_data``
    :param rnd: random seed or random state
    :param mk_rgr_kwargs: extra keyword arguments for ``make_regression``
    :return: an ``em.Datasets`` holding the generated partitions
    """
    state = em.get_rand_state(rnd)
    data, targets, coefs = make_regression(n_samples,
                                           features,
                                           int(features * informative),
                                           outs,
                                           random_state=state,
                                           coef=True,
                                           **mk_rgr_kwargs)
    # keep targets two-dimensional even in the single-output case
    if outs == 1:
        targets = np.reshape(targets, (n_samples, 1))

    print('range of Y', np.min(targets), np.max(targets))
    info = utils.merge_dicts(
        {'informative': informative, 'random_seed': rnd, 'w': coefs},
        mk_rgr_kwargs)
    dataset = em.Dataset(data,
                         targets,
                         name=em.utils.name_from_dict(info, 'w'),
                         info=info)
    splits = em.Datasets.from_list(
        redivide_data([dataset], partition_proportions))
    # report conditioning of the train design matrix as a sanity check
    print('conditioning of X^T X',
          np.linalg.cond(splits.train.data.T @ splits.train.data))
    return splits
Example #3
0
        def generate_datasets(self,
                              rand=None,
                              num_classes=None,
                              num_examples=None,
                              wait_for_n_min=None):
            """Sample one few-shot episode as a list of datasets.

            :param rand: random seed or RandomState, normalized via
                ``em.get_rand_state``
            :param num_classes: classes per episode; defaults to
                ``self.kwargs['num_classes']``
            :param num_examples: examples per dataset, or a list/tuple of
                sizes (one dataset per entry, e.g. episode train/test);
                defaults to ``self.kwargs['num_examples']``
            :param wait_for_n_min: if given, block until at least this many
                images have been loaded (``self.check_loaded_images``)
            :return: ``em.Datasets`` with one ``em.Dataset`` per requested size
            """

            rand = em.get_rand_state(rand)

            if wait_for_n_min:
                import time
                # poll until a background loader has enough images in memory
                while not self.check_loaded_images(wait_for_n_min):
                    time.sleep(5)

            if not num_examples: num_examples = self.kwargs['num_examples']
            if not num_classes: num_classes = self.kwargs['num_classes']

            # prefer in-memory images; otherwise fall back to the
            # class -> image-name mapping and read images from disk below
            clss = self._loaded_images if self._loaded_images else self.info[
                'classes']

            # draw the episode's classes without replacement and relabel
            # them 0 .. num_classes-1 for this episode
            random_classes = rand.choice(list(clss.keys()),
                                         size=(num_classes, ),
                                         replace=False)
            rand_class_dict = {rnd: k for k, rnd in enumerate(random_classes)}

            _dts = []
            for ns in em.as_tuple_or_list(num_examples):
                # class label of each of the ns examples, balanced over classes
                classes = balanced_choice_wr(random_classes, ns, rand)

                all_images = {cls: list(clss[cls]) for cls in classes}
                data, targets, sample_info = [], [], []
                for c in classes:
                    # pick one image of class c without replacement
                    rand.shuffle(all_images[c])
                    img_name = all_images[c][0]
                    all_images[c].remove(img_name)
                    sample_info.append({'name': img_name, 'label': c})

                    if self._loaded_images:
                        data.append(clss[c][img_name])
                    else:
                        # NOTE(review): scipy.misc.imread/imresize were removed
                        # in scipy >= 1.3; this branch requires an old scipy
                        from scipy.misc import imread, imresize
                        data.append(
                            imresize(imread(join(self.info['base_folder'],
                                                 join(c, img_name)),
                                            mode='RGB'),
                                     size=(self.info['resize'],
                                           self.info['resize'], 3)) / 255.)
                    targets.append(rand_class_dict[c])

                if self.info['one_hot_enc']:
                    targets = em.to_one_hot_enc(targets, dimension=num_classes)

                _dts.append(
                    em.Dataset(data=np.array(np.stack(data)),
                               target=targets,
                               sample_info=sample_info,
                               info={'all_classes': random_classes}))
            return em.Datasets.from_list(_dts)
Example #4
0
        def generate_datasets(self,
                              rand=None,
                              num_classes=None,
                              num_examples=None):
            """Sample one few-shot episode from the in-memory image pool.

            Draws ``num_classes`` classes without replacement, relabels them
            0..num_classes-1, then builds one dataset per entry of
            ``num_examples`` with examples balanced over the drawn classes.
            """
            rand = em.get_rand_state(rand)

            if not num_examples: num_examples = self.kwargs['num_examples']
            if not num_classes: num_classes = self.kwargs['num_classes']

            class_pool = (self._loaded_images
                          if self._loaded_images else self.info['classes'])

            episode_classes = rand.choice(list(class_pool.keys()),
                                          size=(num_classes, ),
                                          replace=False)
            label_of = {cls: idx for idx, cls in enumerate(episode_classes)}

            episode_datasets = []
            for size in em.as_tuple_or_list(num_examples):
                drawn = balanced_choice_wr(episode_classes, size, rand)

                remaining = {cls: list(class_pool[cls]) for cls in drawn}
                data, targets, sample_info = [], [], []
                for cls in drawn:
                    # draw one image of this class without replacement
                    rand.shuffle(remaining[cls])
                    picked = remaining[cls][0]
                    remaining[cls].remove(picked)
                    sample_info.append({'name': picked, 'label': cls})
                    data.append(class_pool[cls][picked])
                    targets.append(label_of[cls])

                if self.info['one_hot_enc']:
                    targets = em.to_one_hot_enc(targets, dimension=num_classes)

                episode_datasets.append(
                    em.Dataset(data=np.array(np.stack(data)),
                               target=targets,
                               sample_info=sample_info,
                               info={'all_classes': episode_classes}))
            return em.Datasets.from_list(episode_datasets)
Example #5
0
def train(metasets,
          ex_name,
          hyper_repr_model_builder,
          classifier_builder=None,
          saver=None,
          seed=0,
          MBS=4,
          available_devices=('/gpu:0', '/gpu:1'),
          mlr0=.001,
          mlr_decay=1.e-5,
          T=4,
          n_episodes_testing=600,
          print_every=1000,
          patience=40,
          restore_model=False,
          lr=0.1,
          learn_lr=True,
          process_fn=None):
    """
    Function for training an hyper-representation network.

    :param metasets: Datasets of MetaDatasets
    :param ex_name: name of the experiment
    :param hyper_repr_model_builder: builder for the representation model
    :param classifier_builder: optional builder for classifier model (if None then builds a linear model)
    :param saver: experiment_manager.Saver object
    :param seed:
    :param MBS: meta-batch size
    :param available_devices: distribute the computation among different GPUS!
    :param mlr0: initial meta learning rate
    :param mlr_decay:
    :param T: number of gradient steps for training ground models
    :param n_episodes_testing:
    :param print_every:
    :param patience:
    :param restore_model:
    :param lr: initial ground models learning rate
    :param learn_lr: True for optimizing the ground models learning rate
    :param process_fn: optinal hypergradient process function (like gradient clipping)

    :return: tuple: the saver object, the hyper-representation model and the list of experiments objects
    """
    if saver is None:
        saver = SAVER_EXP(metasets)

    T, ss, n_episodes_testing = setup(T, seed, n_episodes_testing, MBS)
    # one SLExperiment per episode of the meta-batch
    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    # shared representation network; each episode gets its own classifier head
    hyper_repr_model = hyper_repr_model_builder(exs[0].x, name=ex_name)
    if classifier_builder is None:
        classifier_builder = lambda inp, name: models.FeedForwardNet(
            inp, metasets.train.dim_target, name=name)

    io_optim, gs, meta_lr, oo_optim, farho = _optimizers(
        lr, mlr0, mlr_decay, learn_lr)

    # build per-episode graphs, round-robin over the available devices
    for k, ex in enumerate(exs):
        with tf.device(available_devices[k % len(available_devices)]):
            ex.model = classifier_builder(
                hyper_repr_model.for_input(ex.x).out, 'Classifier_%s' % k)
            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y,
                                                        logits=ex.model.out))
            # validation objective coincides with the training loss here
            ex.errors['validation'] = ex.errors['training']
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
                tf.float32),
                                                   name='accuracy')

            # inner problem: classifier weights; outer problem: representation
            optim_dict = farho.inner_problem(ex.errors['training'],
                                             io_optim,
                                             var_list=ex.model.var_list)
            farho.outer_problem(ex.errors['validation'],
                                optim_dict,
                                oo_optim,
                                global_step=gs)

    farho.finalize(process_fn=process_fn)

    feed_dicts, just_train_on_dataset, mean_acc_on, cond = _helper_function(
        exs, n_episodes_testing, MBS, ss, farho, T)

    rand = em.get_rand_state(0)

    with saver.record(*_records(metasets, saver, hyper_repr_model, cond, ss,
                                mean_acc_on, ex_name, meta_lr),
                      where='far',
                      every=print_every,
                      append_string=ex_name):
        tf.global_variables_initializer().run()
        if restore_model:
            saver.restore_model(hyper_repr_model)
        # ADD ONLY TESTING
        for _ in cond.early_stopping_sv(saver, patience):
            trfd, vfd = feed_dicts(
                metasets.train.generate_batch(MBS, rand=rand))

            farho.run(
                T[0], trfd, vfd
            )  # one iteration of optimization of representation variables (hyperparameters)

    return saver, hyper_repr_model, exs
Example #6
0
def meta_omniglot_v2(folder=OMNIGLOT_RESIZED,
                     std_num_classes=None,
                     std_num_examples=None,
                     one_hot_enc=True,
                     _rand=0,
                     n_splits=None):
    """
    Loading function for Omniglot dataset in learning-to-learn version. Use image data as obtained from
    https://github.com/cbfinn/maml/blob/master/data/omniglot_resized/resize_images.py

    :param folder: root folder name.
    :param std_num_classes: standard number of classes for N-way classification
    :param std_num_examples: standard number of examples (e.g. for 1-shot 5-way should be 5)
    :param one_hot_enc: one hot encoding
    :param _rand: random seed or RandomState for generate training, validation, testing meta-datasets
                    split
    :param n_splits: num classes per split
    :return: a Datasets of MetaDataset s
    """
    class OmniglotMetaDataset(em.MetaDataset):
        def __init__(self,
                     info=None,
                     rotations=None,
                     name='Omniglot',
                     num_classes=None,
                     num_examples=None):
            super().__init__(info,
                             name=name,
                             num_classes=num_classes,
                             num_examples=num_examples)
            # class name -> {image name -> row index into self._img_array}
            self._loaded_images = defaultdict(lambda: {})
            # each rotation of a character counts as a distinct class
            self._rotations = rotations or [0, 90, 180, 270]
            self._img_array = None
            self.load_all()

        def generate_datasets(self,
                              rand=None,
                              num_classes=None,
                              num_examples=None):
            """Sample one episode: draw classes, then balanced examples."""
            rand = em.get_rand_state(rand)

            if not num_examples: num_examples = self.kwargs['num_examples']
            if not num_classes: num_classes = self.kwargs['num_classes']

            clss = self._loaded_images if self._loaded_images else self.info[
                'classes']

            # episode classes, relabelled 0..num_classes-1
            random_classes = rand.choice(list(clss.keys()),
                                         size=(num_classes, ),
                                         replace=False)
            rand_class_dict = {rnd: k for k, rnd in enumerate(random_classes)}

            _dts = []
            for ns in em.as_tuple_or_list(num_examples):
                classes = balanced_choice_wr(random_classes, ns, rand)

                all_images = {cls: list(clss[cls]) for cls in classes}
                indices, targets = [], []
                for c in classes:
                    # one image per occurrence of c, without replacement
                    rand.shuffle(all_images[c])
                    img_name = all_images[c][0]
                    all_images[c].remove(img_name)
                    indices.append(clss[c][img_name])
                    targets.append(rand_class_dict[c])

                if self.info['one_hot_enc']:
                    targets = em.to_one_hot_enc(targets, dimension=num_classes)

                data = self._img_array[indices]

                _dts.append(em.Dataset(data=data, target=targets))
            return em.Datasets.from_list(_dts)

        def load_all(self):
            """Load every image once and precompute all rotated variants."""
            # NOTE(review): scipy.ndimage.imread was removed in scipy >= 1.2;
            # this requires an old scipy (or replace with imageio.imread)
            from scipy.ndimage import imread
            from scipy.ndimage.interpolation import rotate
            _cls = self.info['classes']
            _base_folder = self.info['base_folder']

            _id = 0
            flat_data = []
            for c in _cls:
                all_images = list(_cls[c])
                for img_name in all_images:
                    img = imread(join(_base_folder, join(c, img_name)))
                    img = 1. - np.reshape(img, (28, 28, 1)) / 255.
                    for rot in self._rotations:
                        # BUG FIX: rotate from the un-rotated base image.
                        # Previously `img` was reassigned each iteration, so
                        # rotations accumulated (0, 90, 270, 180) and the
                        # 'rot_180'/'rot_270' classes held swapped images.
                        rotated = rotate(img, rot, reshape=False)
                        self._loaded_images[c + os.path.sep + 'rot_' +
                                            str(rot)][img_name] = _id
                        _id += 1
                        flat_data.append(rotated)

            # single contiguous array; episodes index into it by row
            self._img_array = np.stack(flat_data)

            # end of class

    alphabets = os.listdir(folder)

    labels_and_images = OrderedDict()
    for alphabet in alphabets:
        base_folder = join(folder, alphabet)
        label_names = os.listdir(base_folder)  # all characters in one alphabet
        labels_and_images.update({
            alphabet + os.path.sep + ln: os.listdir(join(base_folder, ln))
            # all examples of each character
            for ln in label_names
        })

    # divide between training validation and test meta-datasets
    _rand = em.get_rand_state(_rand)
    all_clss = list(labels_and_images.keys())
    _rand.shuffle(all_clss)
    n_splits = n_splits or (0, 1200, 1300, len(all_clss))

    meta_dts = []
    for start, end in zip(n_splits, n_splits[1:]):
        meta_dts.append(
            OmniglotMetaDataset(info={
                'base_folder': folder,
                'classes':
                {k: labels_and_images[k]
                 for k in all_clss[start:end]},
                'one_hot_enc': one_hot_enc
            },
                                num_classes=std_num_classes,
                                num_examples=std_num_examples))

    return em.Datasets.from_list(meta_dts)
Example #7
0
def redivide_data(datasets,
                  partition_proportions=None,
                  shuffle=False,
                  filters=None,
                  maps=None,
                  balance_classes=False,
                  rand=None):
    """
    Function that redivides datasets. Can be use also to shuffle or filter or map examples.

    :param rand:
    :param balance_classes: # TODO RICCARDO
    :param datasets: original datasets, instances of class Dataset (works with get_data and get_targets for
                        compatibility with mnist datasets
    :param partition_proportions: (optional, default None)  list of fractions that can either sum up to 1 or less
                                    then one, in which case one additional partition is created with
                                    proportion 1 - sum(partition proportions).
                                    If None it will retain the same proportion of samples found in datasets
    :param shuffle: (optional, default False) if True shuffles the examples
    :param filters: (optional, default None) filter or list of filters: functions with signature
                        (data, target, index) -> boolean (accept or reject the sample)
    :param maps: (optional, default None) map or list of maps: functions with signature
                        (data, target, index) ->  (new_data, new_target) (maps the old sample to a new one,
                        possibly also to more
                        than one sample, for data augmentation)
    :return: a list of datasets of length equal to the (possibly augmented) partition_proportion
    """

    rnd = em.get_rand_state(rand)

    # concatenate all the input datasets into flat arrays
    all_data = vstack([get_data(d) for d in datasets])
    all_labels = stack_or_concat([get_targets(d) for d in datasets])

    all_infos = np.concatenate([d.sample_info for d in datasets])

    N = all_data.shape[0]

    if partition_proportions:  # argument check
        partition_proportions = list([partition_proportions] if isinstance(
            partition_proportions, float) else partition_proportions)
        sum_proportions = sum(partition_proportions)
        # BUG FIX: message used %d, which truncated the float sum
        assert sum_proportions <= 1, "partition proportions must sum up to at most one: %s" % sum_proportions
        if sum_proportions < 1.:
            partition_proportions += [1. - sum_proportions]
    else:
        # keep the original datasets' relative sizes
        partition_proportions = [
            1. * get_data(d).shape[0] / N for d in datasets
        ]

    if shuffle:
        if sp and isinstance(all_data, sp.sparse.csr.csr_matrix):
            raise NotImplementedError()
        # shuffle data, labels and sample infos with one shared permutation
        permutation = np.arange(all_data.shape[0])
        rnd.shuffle(permutation)

        all_data = all_data[permutation]
        all_labels = np.array(all_labels[permutation])
        all_infos = np.array(all_infos[permutation])

    if filters:
        if sp and isinstance(all_data, sp.sparse.csr.csr_matrix):
            raise NotImplementedError()
        filters = as_list(filters)
        data_triple = [(x, y, d)
                       for x, y, d in zip(all_data, all_labels, all_infos)]
        for fiat in filters:
            data_triple = [
                xy for i, xy in enumerate(data_triple)
                if fiat(xy[0], xy[1], xy[2], i)
            ]
        all_data = np.vstack([e[0] for e in data_triple])
        all_labels = np.vstack([e[1] for e in data_triple])
        all_infos = np.vstack([e[2] for e in data_triple])

    if maps:
        if sp and isinstance(all_data, sp.sparse.csr.csr_matrix):
            raise NotImplementedError()
        maps = as_list(maps)
        data_triple = [(x, y, d)
                       for x, y, d in zip(all_data, all_labels, all_infos)]
        for _map in maps:
            data_triple = [
                _map(xy[0], xy[1], xy[2], i)
                for i, xy in enumerate(data_triple)
            ]
        all_data = np.vstack([e[0] for e in data_triple])
        all_labels = np.vstack([e[1] for e in data_triple])
        all_infos = np.vstack([e[2] for e in data_triple])

    # N may have changed after filtering / mapping
    N = all_data.shape[0]
    assert N == all_labels.shape[0]

    # cumulative example counts delimiting each partition
    calculated_partitions = reduce(
        lambda v1, v2: v1 + [sum(v1) + v2],
        [int(N * prp) for prp in partition_proportions], [0])
    calculated_partitions[-1] = N

    print('datasets.redivide_data:, computed partitions numbers -',
          calculated_partitions,
          'len all',
          N,
          end=' ')

    # BUG FIX: the partitioning below used to live inside this loop and
    # returned during the FIRST iteration, so only the first dataset's info
    # dict was ever merged. Merge all infos first, then partition once.
    new_general_info_dict = {}
    for data in datasets:
        new_general_info_dict = {**new_general_info_dict, **data.info}

    if balance_classes:
        new_datasets = []
        forbidden_indices = np.empty(0, dtype=np.int64)
        for d1, d2 in zip(calculated_partitions[:-1],
                          calculated_partitions[1:-1]):
            indices = np.array(
                get_indices_balanced_classes(d2 - d1, all_labels,
                                             forbidden_indices))
            dataset = em.Dataset(data=all_data[indices],
                                 target=all_labels[indices],
                                 sample_info=all_infos[indices],
                                 info=new_general_info_dict)
            new_datasets.append(dataset)
            forbidden_indices = np.append(forbidden_indices, indices)
            test_if_balanced(dataset)
        # everything not used by the balanced partitions goes into the last one
        remaining_indices = np.array(
            list(set(list(range(N))) - set(forbidden_indices)))
        new_datasets.append(
            em.Dataset(data=all_data[remaining_indices],
                       target=all_labels[remaining_indices],
                       sample_info=all_infos[remaining_indices],
                       info=new_general_info_dict))
    else:
        new_datasets = [
            em.Dataset(data=all_data[d1:d2],
                       target=all_labels[d1:d2],
                       sample_info=all_infos[d1:d2],
                       info=new_general_info_dict) for d1, d2 in
            zip(calculated_partitions, calculated_partitions[1:])
        ]

    print('DONE')

    return new_datasets
def train(metasets, ex_name, hyper_repr_model_builder, classifier_builder=None, saver=None, seed=0, MBS=4,
          available_devices=('/gpu:0', '/gpu:1'),
          mlr0=.001, mlr_decay=1.e-5, T=4, n_episodes_testing=600,
          print_every=1000, patience=40, restore_model=False,
          lr=0.1, learn_lr=True, process_fn=None):
    """
    Function for training an hyper-representation network.

    :param metasets: Datasets of MetaDatasets
    :param ex_name: name of the experiment
    :param hyper_repr_model_builder: builder for the representation model,
                                        function (input, name) -> `experiment_manager.Network`
    :param classifier_builder: optional builder for classifier model (if None then builds a linear model)
    :param saver: experiment_manager.Saver object
    :param seed:
    :param MBS: meta-batch size
    :param available_devices: distribute the computation among different GPUS!
    :param mlr0: initial meta learning rate
    :param mlr_decay:
    :param T: number of gradient steps for training ground models
    :param n_episodes_testing:
    :param print_every:
    :param patience:
    :param restore_model:
    :param lr: initial ground models learning rate
    :param learn_lr: True for optimizing the ground models learning rate
    :param process_fn: optinal hypergradient process function (like gradient clipping)

    :return: tuple: the saver object, the hyper-representation model and the list of experiments objects
    """
    if saver is None:
        saver = SAVER_EXP(metasets)

    T, ss, n_episodes_testing = setup(T, seed, n_episodes_testing, MBS)
    # one SLExperiment per episode of the meta-batch
    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    # shared representation network; each episode gets its own classifier head
    hyper_repr_model = hyper_repr_model_builder(exs[0].x, name=ex_name)
    if classifier_builder is None: classifier_builder = lambda inp, name: models.FeedForwardNet(
        inp, metasets.train.dim_target, name=name)

    io_optim, gs, meta_lr, oo_optim, farho = _optimizers(lr, mlr0, mlr_decay, learn_lr)

    # build per-episode graphs, round-robin over the available devices
    for k, ex in enumerate(exs):
        with tf.device(available_devices[k % len(available_devices)]):
            ex.model = classifier_builder(hyper_repr_model.for_input(ex.x).out, 'Classifier_%s' % k)
            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y, logits=ex.model.out)
            )
            # validation objective coincides with the training loss here
            ex.errors['validation'] = ex.errors['training']
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)), tf.float32),
                name='accuracy')

            # inner problem: classifier weights; outer problem: representation
            optim_dict = farho.inner_problem(ex.errors['training'], io_optim, var_list=ex.model.var_list)
            farho.outer_problem(ex.errors['validation'], optim_dict, oo_optim, global_step=gs)

    farho.finalize(process_fn=process_fn)

    feed_dicts, just_train_on_dataset, mean_acc_on, cond = _helper_function(
        exs, n_episodes_testing, MBS, ss, farho, T)

    rand = em.get_rand_state(0)

    with saver.record(*_records(metasets, saver, hyper_repr_model, cond, ss, mean_acc_on, ex_name, meta_lr),
                      where='far', every=print_every, append_string=ex_name):
        tf.global_variables_initializer().run()
        if restore_model:
            saver.restore_model(hyper_repr_model)
        # ADD ONLY TESTING
        for _ in cond.early_stopping_sv(saver, patience):
            trfd, vfd = feed_dicts(metasets.train.generate_batch(MBS, rand=rand))

            farho.run(T[0], trfd, vfd)  # one iteration of optimization of representation variables (hyperparameters)

    return saver, hyper_repr_model, exs
Example #9
0
def meta_test_up_to_T(exp_dir,
                      metasets,
                      exs,
                      far_ho,
                      saver,
                      sess,
                      c_way,
                      k_shot,
                      lr,
                      n_test_episodes,
                      MBS,
                      seed,
                      T,
                      iterations=None):
    """Evaluate saved checkpoints, reporting accuracy after each inner step.

    For every meta-iteration whose checkpoint exists in ``exp_dir``, restores
    the model and measures accuracy on train/validation/test meta-datasets
    after each of the ``T`` inner gradient steps. Results are accumulated and
    pickled to ``exp_dir`` after each checkpoint.

    :param exp_dir: experiment directory with 'model<i>' checkpoints
    :param metasets: Datasets of MetaDatasets (train/validation/test)
    :param exs: list of experiment objects
    :param far_ho: far-ho optimization object
    :param saver: tf.train.Saver used to restore checkpoints
    :param sess: tf session
    :param c_way: N of N-way classification (naming only)
    :param k_shot: K of K-shot (naming only)
    :param lr: inner learning rate (naming only)
    :param n_test_episodes: total number of evaluation episodes
    :param MBS: meta-batch size
    :param seed: random seed for episode sampling
    :param T: number of inner gradient steps evaluated
    :param iterations: meta-iterations to probe for checkpoints; defaults to
        ``range(10000)``
    :return: dict of per-step mean/std accuracies and timing info
    """
    # BUG FIX: the default was the mutable `list(range(10000))`, shared across
    # calls; use a None sentinel instead (same behavior).
    if iterations is None:
        iterations = range(10000)

    meta_test_str = str(c_way) + 'way_' + str(k_shot) + 'shot_' + str(
        lr) + 'lr' + str(n_test_episodes) + 'ep'

    n_test_batches = n_test_episodes // MBS
    rand = em.get_rand_state(seed)

    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS,
                                   rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)
    train_batches = BatchQueueMock(metasets.train, n_test_batches, MBS, rand)

    print('\nMeta-testing {} (over {} eps)...'.format(meta_test_str,
                                                      n_test_episodes))

    test_results = {
        'valid_test': [],
        'test_test': [],
        'train_test': [],
        'time': [],
        'n_test_episodes': n_test_episodes,
        'episodes': [],
        'iterations': []
    }

    test_result_path = os.path.join(exp_dir,
                                    meta_test_str + 'noTrain_results.pickle')

    start_time = time.time()
    for i in iterations:
        model_file = os.path.join(exp_dir, 'model' + str(i))
        if tf.train.checkpoint_exists(model_file):
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

            test_results['iterations'].append(i)
            test_results['episodes'].append(i * MBS)

            # accuracy after each of the T inner steps, per meta-dataset
            valid_result = accuracy_on_up_to_T(valid_batches, exs, far_ho,
                                               sess, T)
            test_result = accuracy_on_up_to_T(test_batches, exs, far_ho, sess,
                                              T)
            train_result = accuracy_on_up_to_T(train_batches, exs, far_ho,
                                               sess, T)

            duration = time.time() - start_time

            test_results['time'].append(duration)

            for t in range(T):

                valid_test = (np.mean(valid_result[1][t]),
                              np.std(valid_result[1][t]))
                test_test = (np.mean(test_result[1][t]),
                             np.std(test_result[1][t]))
                train_test = (np.mean(train_result[1][t]),
                              np.std(train_result[1][t]))

                # lazily grow the per-step result lists up to step t
                if t >= len(test_results['valid_test']):
                    test_results['valid_test'].append({'mean': [], 'std': []})
                    test_results['test_test'].append({'mean': [], 'std': []})
                    test_results['train_test'].append({'mean': [], 'std': []})

                test_results['valid_test'][t]['mean'].append(valid_test[0])
                test_results['test_test'][t]['mean'].append(test_test[0])
                test_results['train_test'][t]['mean'].append(train_test[0])

                test_results['valid_test'][t]['std'].append(valid_test[1])
                test_results['test_test'][t]['std'].append(test_test[1])
                test_results['train_test'][t]['std'].append(train_test[1])

                print(
                    'valid-test_test acc T=%d (%d meta_it)(%.2fs): %.4f (%.4f), %.4f (%.4f),'
                    '  %.4f (%.4f)' %
                    (t + 1, i, duration, train_test[0], train_test[1],
                     valid_test[0], valid_test[1], test_test[0], test_test[1]))

            # persist partial results after every evaluated checkpoint
            save_obj(test_result_path, test_results)

    return test_results
Example #10
0
def meta_train(exp_dir, metasets, exs, far_ho, saver, sess, n_test_episodes,
               MBS, seed, resume, T, n_meta_iterations, print_interval,
               save_interval):
    """Run the meta-training loop, periodically evaluating and checkpointing.

    :param exp_dir: experiment directory for checkpoints and results.pickle
    :param metasets: Datasets of MetaDatasets (train/validation/test)
    :param exs: list of experiment objects
    :param far_ho: far-ho optimization object (runs T inner steps + meta step)
    :param saver: tf.train.Saver
    :param sess: tf session
    :param n_test_episodes: episodes used for each evaluation
    :param MBS: meta-batch size
    :param seed: random seed for episode sampling
    :param resume: if True, restore the latest checkpoint and prior results
    :param T: number of inner gradient steps
    :param n_meta_iterations: total meta-iterations to run
    :param print_interval: evaluate/log every this many meta-iterations
    :param save_interval: checkpoint every this many meta-iterations
    :return: dict of accuracy/loss statistics over training
    """
    # use workers to fill the batches queues (is it worth it?)

    result_path = os.path.join(exp_dir, 'results.pickle')

    tf.global_variables_initializer().run(session=sess)

    n_test_batches = n_test_episodes // MBS
    rand = em.get_rand_state(seed)

    # per-metric history; means and stds are appended at every evaluation
    results = {
        'train_train': {
            'mean': [],
            'std': []
        },
        'train_test': {
            'mean': [],
            'std': []
        },
        'test_test': {
            'mean': [],
            'std': []
        },
        'valid_test': {
            'mean': [],
            'std': []
        },
        'outer_losses': {
            'mean': [],
            'std': []
        },
        'learning_rate': [],
        'iterations': [],
        'episodes': [],
        'time': []
    }

    start_time = time.time()

    resume_itr = 0
    if resume:
        # pick up from the latest checkpoint and previously saved results
        model_file = tf.train.latest_checkpoint(exp_dir)
        if model_file:
            print("Restoring results from " + result_path)
            results = load_obj(result_path)
            start_time = results['time'][-1]

            # checkpoint files are named 'model<iteration>'
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:]) + 1
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    ''' Meta-Train '''
    train_batches = BatchQueueMock(metasets.train, 1, MBS, rand)
    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS,
                                   rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)

    print(
        '\nIteration quantities: train_train acc, train_test acc, valid_test, acc'
        ' test_test acc mean(std) over %d episodes' % n_test_episodes)
    with sess.as_default():
        inner_losses = []
        for meta_it in range(resume_itr, n_meta_iterations):
            tr_fd, v_fd = feed_dicts(train_batches.get()[0], exs)

            # T inner steps plus one meta (outer) step
            far_ho.run(T, tr_fd, v_fd)
            # inner_losses.append(far_ho.inner_losses)

            outer_losses = [
                sess.run(ex.errors['validation'], v_fd) for ex in exs
            ]
            outer_losses_moments = (np.mean(outer_losses),
                                    np.std(outer_losses))
            results['outer_losses']['mean'].append(outer_losses_moments[0])
            results['outer_losses']['std'].append(outer_losses_moments[1])

            # print('inner_losses: ', inner_losses[-1])

            # periodic evaluation on train/validation/test meta-datasets
            if meta_it % print_interval == 0 or meta_it == n_meta_iterations - 1:
                results['iterations'].append(meta_it)
                results['episodes'].append(meta_it * MBS)

                train_result = accuracy_on(train_batches, exs, far_ho, sess, T)
                test_result = accuracy_on(test_batches, exs, far_ho, sess, T)
                valid_result = accuracy_on(valid_batches, exs, far_ho, sess, T)

                train_train = (np.mean(train_result[0]),
                               np.std(train_result[0]))
                train_test = (np.mean(train_result[1]),
                              np.std(train_result[1]))
                valid_test = (np.mean(valid_result[1]),
                              np.std(valid_result[1]))
                test_test = (np.mean(test_result[1]), np.std(test_result[1]))

                duration = time.time() - start_time
                results['time'].append(duration)

                results['train_train']['mean'].append(train_train[0])
                results['train_test']['mean'].append(train_test[0])
                results['valid_test']['mean'].append(valid_test[0])
                results['test_test']['mean'].append(test_test[0])

                results['train_train']['std'].append(train_train[1])
                results['train_test']['std'].append(train_test[1])
                results['valid_test']['std'].append(valid_test[1])
                results['test_test']['std'].append(test_test[1])

                results['inner_losses'] = inner_losses

                print('mean outer losses: {}'.format(outer_losses_moments[1]))

                print('it %d, ep %d (%.2fs): %.3f, %.3f, %.3f, %.3f' %
                      (meta_it, meta_it * MBS, duration, train_train[0],
                       train_test[0], valid_test[0], test_test[0]))

                # fetch the (possibly learned) inner learning rate by name
                lr = sess.run(["lr:0"])[0]
                print('lr: {}'.format(lr))

                # do_plot(logdir, results)

            # periodic checkpointing of model weights and results
            if meta_it % save_interval == 0 or meta_it == n_meta_iterations - 1:
                saver.save(sess, exp_dir + '/model' + str(meta_it))
                save_obj(result_path, results)

        return results