示例#1
0
def random_regression_datasets(n_samples,
                               features=100,
                               outs=1,
                               informative=0.1,
                               partition_proportions=(0.5, 0.3),
                               rnd=None,
                               **mk_rgr_kwargs):
    """Create train/validation/test splits from a synthetic regression task.

    :param n_samples: total number of samples to generate
    :param features: number of input features
    :param outs: number of regression targets
    :param informative: fraction of features that are informative
    :param partition_proportions: proportions used by `redivide_data`
    :param rnd: random seed or RandomState
    :param mk_rgr_kwargs: extra keyword arguments forwarded to `make_regression`
    :return: a `dl.Datasets` with train/validation/test partitions
    """
    state = dl.get_rand_state(rnd)
    n_informative = int(features * informative)
    X, Y, w = make_regression(n_samples,
                              features,
                              n_informative,
                              outs,
                              random_state=state,
                              coef=True,
                              **mk_rgr_kwargs)
    if outs == 1:
        # keep targets two-dimensional even for a single output
        Y = np.reshape(Y, (n_samples, 1))

    print("range of Y", np.min(Y), np.max(Y))
    info = merge_dicts({"informative": informative,
                        "random_seed": rnd,
                        "w": w}, mk_rgr_kwargs)
    dataset = dl.Dataset(X, Y,
                         name=dl.em_utils.name_from_dict(info, "w"),
                         info=info)
    datasets = dl.Datasets.from_list(
        redivide_data([dataset], partition_proportions))
    print(
        "conditioning of X^T X",
        np.linalg.cond(datasets.train.data.T @ datasets.train.data),
    )
    return datasets
示例#2
0
def random_classification_datasets(n_samples,
                                   features=100,
                                   classes=2,
                                   informative=0.1,
                                   partition_proportions=(0.5, 0.3),
                                   rnd=None,
                                   one_hot=True,
                                   **mk_cls_kwargs):
    """Create train/validation/test splits from a synthetic classification task.

    :param n_samples: total number of samples to generate
    :param features: number of input features
    :param classes: number of classes
    :param informative: recorded in the dataset info (not forwarded here)
    :param partition_proportions: proportions used by `redivide_data`
    :param rnd: random seed or RandomState
    :param one_hot: whether to one-hot encode the labels
    :param mk_cls_kwargs: extra keyword arguments forwarded to `make_classification`
    :return: a `dl.Datasets` with train/validation/test partitions
    """
    state = dl.get_rand_state(rnd)
    X, Y = make_classification(n_samples,
                               features,
                               n_classes=classes,
                               random_state=state,
                               **mk_cls_kwargs)
    if one_hot:
        Y = to_one_hot_enc(Y)

    print("range of Y", np.min(Y), np.max(Y))
    info = merge_dicts({"informative": informative,
                        "random_seed": rnd}, mk_cls_kwargs)
    dataset = dl.Dataset(X, Y,
                         name=dl.em_utils.name_from_dict(info, "w"),
                         info=info)
    datasets = dl.Datasets.from_list(
        redivide_data([dataset], partition_proportions))
    print(
        "conditioning of X^T X",
        np.linalg.cond(datasets.train.data.T @ datasets.train.data),
    )
    return datasets
示例#3
0
    def generate_datasets(self,
                          rand=None,
                          num_classes=None,
                          num_examples=None):
        """Sample one episode: draw `num_classes` classes at random and build
        one `dl.Dataset` for each entry of `num_examples`.

        :param rand: random seed or RandomState
        :param num_classes: classes per episode; defaults to self.kwargs
        :param num_examples: examples per dataset (scalar or sequence);
            defaults to self.kwargs
        :return: a `dl.Datasets` built from the sampled datasets
        """
        rand = dl.get_rand_state(rand)

        # fall back to the construction-time configuration
        num_examples = num_examples or self.kwargs["num_examples"]
        num_classes = num_classes or self.kwargs["num_classes"]

        # prefer pre-loaded images; otherwise use the class listing from info
        clss = self._loaded_images or self.info["classes"]

        random_classes = rand.choice(list(clss.keys()),
                                     size=(num_classes, ),
                                     replace=False)
        # episode-local label for each sampled class
        class_to_label = {c: idx for idx, c in enumerate(random_classes)}

        episode_datasets = []
        for n_examples in as_tuple_or_list(num_examples):
            classes = balanced_choice_wr(random_classes, n_examples, rand)

            pools = {cls: list(clss[cls]) for cls in classes}
            data, targets, sample_info = [], [], []
            for c in classes:
                # draw one image per occurrence of the class, without repeats
                rand.shuffle(pools[c])
                img_name = pools[c].pop(0)
                sample_info.append({"name": img_name, "label": c})
                data.append(clss[c][img_name])
                targets.append(class_to_label[c])

            if self.info["one_hot_enc"]:
                targets = dl.to_one_hot_enc(targets, dimension=num_classes)

            episode_datasets.append(
                dl.Dataset(
                    data=np.array(np.stack(data)),
                    target=targets,
                    sample_info=sample_info,
                    info={"all_classes": random_classes},
                ))
        return dl.Datasets.from_list(episode_datasets)
示例#4
0
def meta_test_up_to_T(
    exp_dir,
    metasets,
    exs,
    pybml_ho,
    saver,
    sess,
    c_way,
    k_shot,
    lr,
    n_test_episodes,
    MBS,
    seed,
    T,
    iterations=None,
):
    """Meta-test every saved checkpoint in `exp_dir`, reporting accuracy after
    each of the first T inner-adaptation steps.

    For each index i in `iterations`, if a checkpoint 'model<i>' exists it is
    restored into `sess` and evaluated via `accuracy_on_up_to_T` on the
    train/validation/test splits of `metasets`. Per-step means and standard
    deviations are accumulated and pickled to disk after every checkpoint.

    :param exp_dir: directory holding 'model<i>' checkpoints; results are
        written there as '<tag>noTrain_results.pickle'
    :param metasets: meta-datasets with .train/.validation/.test splits
    :param exs: experiment objects forwarded to `accuracy_on_up_to_T`
    :param pybml_ho: hyper-optimizer forwarded to `accuracy_on_up_to_T`
    :param saver: tf.train.Saver used to restore checkpoints
    :param sess: TensorFlow session
    :param c_way: N of the N-way task (used only in the results file name)
    :param k_shot: K of the K-shot task (used only in the results file name)
    :param lr: learning rate (used only in the results file name)
    :param n_test_episodes: total episodes per evaluation
    :param MBS: meta-batch size
    :param seed: random seed or RandomState for the batch queues
    :param T: number of inner adaptation steps to evaluate
    :param iterations: iterable of checkpoint indices to probe; defaults to
        range(10000)
    :return: the populated results dict
    """
    # Fix: the previous default `list(range(10000))` was a mutable default
    # argument, built eagerly at import time and shared between calls.
    # A lazy `range` yields the same indices without that hazard.
    if iterations is None:
        iterations = range(10000)

    meta_test_str = (
        str(c_way)
        + "way_"
        + str(k_shot)
        + "shot_"
        + str(lr)
        + "lr"
        + str(n_test_episodes)
        + "ep"
    )

    n_test_batches = n_test_episodes // MBS
    rand = dl.get_rand_state(seed)

    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS, rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)
    train_batches = BatchQueueMock(metasets.train, n_test_batches, MBS, rand)

    print("\nMeta-testing {} (over {} eps)...".format(meta_test_str, n_test_episodes))

    test_results = {
        "valid_test": [],  # one {"mean": [...], "std": [...]} dict per step t
        "test_test": [],
        "train_test": [],
        "time": [],
        "n_test_episodes": n_test_episodes,
        "episodes": [],
        "iterations": [],
    }

    test_result_path = os.path.join(exp_dir, meta_test_str + "noTrain_results.pickle")

    start_time = time.time()
    for i in iterations:
        model_file = os.path.join(exp_dir, "model" + str(i))
        if tf.train.checkpoint_exists(model_file):
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

            test_results["iterations"].append(i)
            test_results["episodes"].append(i * MBS)

            valid_result = accuracy_on_up_to_T(valid_batches, exs, pybml_ho, sess, T)
            test_result = accuracy_on_up_to_T(test_batches, exs, pybml_ho, sess, T)
            train_result = accuracy_on_up_to_T(train_batches, exs, pybml_ho, sess, T)

            duration = time.time() - start_time

            test_results["time"].append(duration)

            for t in range(T):
                # moments of the per-episode accuracies after t+1 inner steps
                valid_test = (np.mean(valid_result[1][t]), np.std(valid_result[1][t]))
                test_test = (np.mean(test_result[1][t]), np.std(test_result[1][t]))
                train_test = (np.mean(train_result[1][t]), np.std(train_result[1][t]))

                # lazily grow the per-step accumulators on the first checkpoint
                if t >= len(test_results["valid_test"]):
                    test_results["valid_test"].append({"mean": [], "std": []})
                    test_results["test_test"].append({"mean": [], "std": []})
                    test_results["train_test"].append({"mean": [], "std": []})

                test_results["valid_test"][t]["mean"].append(valid_test[0])
                test_results["test_test"][t]["mean"].append(test_test[0])
                test_results["train_test"][t]["mean"].append(train_test[0])

                test_results["valid_test"][t]["std"].append(valid_test[1])
                test_results["test_test"][t]["std"].append(test_test[1])
                test_results["train_test"][t]["std"].append(train_test[1])

                print(
                    "valid-test_test acc T=%d (%d meta_it)(%.2fs): %.4f (%.4f), %.4f (%.4f),"
                    "  %.4f (%.4f)"
                    % (
                        t + 1,
                        i,
                        duration,
                        train_test[0],
                        train_test[1],
                        valid_test[0],
                        valid_test[1],
                        test_test[0],
                        test_test[1],
                    )
                )

            # persist partial results after every evaluated checkpoint
            save_obj(test_result_path, test_results)

    return test_results
示例#5
0
def meta_train(
    exp_dir,
    metasets,
    exs,
    pybml_ho,
    saver,
    sess,
    n_test_episodes,
    MBS,
    seed,
    resume,
    T,
    n_meta_iterations,
    print_interval,
    save_interval,
):
    """Run the meta-training loop, periodically evaluating and checkpointing.

    Each iteration draws one meta-batch from the training split, runs the
    hyper-optimizer step, records outer losses, and every `print_interval`
    iterations evaluates accuracy (via `accuracy_on`) on train/validation/test
    splits; every `save_interval` iterations saves the model and pickles the
    results dict to `exp_dir`.

    :param exp_dir: directory where 'model<i>' checkpoints and
        'results.pickle' are written
    :param metasets: meta-datasets with .train/.validation/.test splits
    :param exs: experiment objects whose 'validation' errors are tracked
    :param pybml_ho: hyper-optimizer driving the meta-update
    :param saver: tf.train.Saver for checkpointing
    :param sess: TensorFlow session
    :param n_test_episodes: episodes used per evaluation
    :param MBS: meta-batch size
    :param seed: random seed or RandomState for the batch queues
    :param resume: if truthy, restore the latest checkpoint and results
    :param T: number of inner steps (forwarded to `accuracy_on`)
    :param n_meta_iterations: total number of meta-iterations
    :param print_interval: iterations between evaluations/log lines
    :param save_interval: iterations between checkpoints
    :return: the populated results dict
    """
    # use workers to fill the batches queues (is it worth it?)

    result_path = os.path.join(exp_dir, "results.pickle")
    tf.global_variables_initializer().run(session=sess)
    n_test_batches = n_test_episodes // MBS
    rand = dl.get_rand_state(seed)

    results = {
        "train_train": {"mean": [], "std": []},
        "train_test": {"mean": [], "std": []},
        "test_test": {"mean": [], "std": []},
        "valid_test": {"mean": [], "std": []},
        "outer_losses": {"mean": [], "std": []},
        "learning_rate": [],
        "iterations": [],
        "episodes": [],
        "time": [],
        "alpha": [],
    }

    resume_itr = 0
    if resume:
        # restore both the results dict and the model weights, and continue
        # from the iteration encoded in the checkpoint file name ('model<i>')
        model_file = tf.train.latest_checkpoint(exp_dir)
        if model_file:
            print("Restoring results from " + result_path)
            results = load_obj(result_path)

            ind1 = model_file.index("model")
            resume_itr = int(model_file[ind1 + 5 :]) + 1
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    """ Meta-Train """
    train_batches = BatchQueueMock(metasets.train, 1, MBS, rand)
    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS, rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)

    start_time = time.time()
    print(
        "\nIteration quantities: train_train acc, train_test acc, valid_test, acc"
        " test_test acc mean(std) over %d episodes" % n_test_episodes
    )
    with sess.as_default():
        # NOTE(review): inner_losses is never appended to in this loop, so
        # results["inner_losses"] stays empty — confirm whether intentional.
        inner_losses = []
        for meta_it in range(resume_itr, n_meta_iterations):
            # one meta-batch: build feed dicts and run the hyper-optimizer step
            tr_fd, v_fd = utils.feed_dicts(train_batches.get_all_batches()[0], exs)
            pybml_ho.run(tr_fd, v_fd)

            duration = time.time() - start_time

            results["time"].append(duration)
            outer_losses = []
            for _, ex in enumerate(exs):
                outer_losses.append(
                    sess.run(
                        ex.errors["validation"], boml.utils.merge_dicts(tr_fd, v_fd)
                    )
                )
            outer_losses_moments = (np.mean(outer_losses), np.std(outer_losses))
            results["outer_losses"]["mean"].append(outer_losses_moments[0])
            results["outer_losses"]["std"].append(outer_losses_moments[1])

            if meta_it % print_interval == 0 or meta_it == n_meta_iterations - 1:
                results["iterations"].append(meta_it)
                results["episodes"].append(meta_it * MBS)
                # log optional learned hyper-parameters when present
                if "alpha" in pybml_ho.param_dict.keys():
                    alpha_moment = pybml_ho.param_dict["alpha"].eval()
                    print("alpha_itr" + str(meta_it) + ": ", alpha_moment)
                    results["alpha"].append(alpha_moment)
                if "s" in pybml_ho.param_dict.keys():
                    s = sess.run(["s:0"])[0]
                    print("s: {}".format(s))
                if "t" in pybml_ho.param_dict.keys():
                    t = sess.run(["t:0"])[0]
                    print("t: {}".format(t))

                # accuracies: [0] is on the train set of each episode, [1] on
                # the test set of each episode
                train_result = accuracy_on(train_batches, exs, pybml_ho, sess, T)
                test_result = accuracy_on(test_batches, exs, pybml_ho, sess, T)
                valid_result = accuracy_on(valid_batches, exs, pybml_ho, sess, T)
                train_train = (np.mean(train_result[0]), np.std(train_result[0]))
                train_test = (np.mean(train_result[1]), np.std(train_result[1]))
                valid_test = (np.mean(valid_result[1]), np.std(valid_result[1]))
                test_test = (np.mean(test_result[1]), np.std(test_result[1]))

                results["train_train"]["mean"].append(train_train[0])
                results["train_test"]["mean"].append(train_test[0])
                results["valid_test"]["mean"].append(valid_test[0])
                results["test_test"]["mean"].append(test_test[0])

                results["train_train"]["std"].append(train_train[1])
                results["train_test"]["std"].append(train_test[1])
                results["valid_test"]["std"].append(valid_test[1])
                results["test_test"]["std"].append(test_test[1])

                results["inner_losses"] = inner_losses

                print("mean outer losses: {}".format(outer_losses_moments[0]))

                print(
                    "it %d, ep %d (%.5fs): %.5f, %.5f, %.5f, %.5f"
                    % (
                        meta_it,
                        meta_it * MBS,
                        duration,
                        train_train[0],
                        train_test[0],
                        valid_test[0],
                        test_test[0],
                    )
                )

                lr = sess.run(["lr:0"])[0]
                print("lr: {}".format(lr))

                # do_plot(logdir, results)

            if meta_it % save_interval == 0 or meta_it == n_meta_iterations - 1:
                saver.save(sess, exp_dir + "/model" + str(meta_it))
                save_obj(result_path, results)

            start_time = time.time()

        return results
示例#6
0
    def generate_datasets(self,
                          rand=None,
                          num_classes=None,
                          num_examples=None,
                          wait_for_n_min=None):
        """Sample one episode: draw `num_classes` classes and build one
        `dl.Dataset` for each entry of `num_examples`.

        Images come from the in-memory cache when available, otherwise they
        are read from disk, resized, and scaled to [0, 1].

        :param rand: random seed or RandomState
        :param num_classes: classes per episode; defaults to self.kwargs
        :param num_examples: examples per dataset (scalar or sequence);
            defaults to self.kwargs
        :param wait_for_n_min: if set, block until at least this many images
            are loaded (checked every 5 seconds)
        :return: a `dl.Datasets` built from the sampled datasets
        """

        rand = dl.get_rand_state(rand)

        if wait_for_n_min:
            import time

            # background loader may still be filling the cache
            while not self.check_loaded_images(wait_for_n_min):
                time.sleep(5)

        # fall back to the construction-time configuration
        if not num_examples:
            num_examples = self.kwargs["num_examples"]
        if not num_classes:
            num_classes = self.kwargs["num_classes"]

        # prefer pre-loaded images; otherwise use the class listing from info
        clss = self._loaded_images if self._loaded_images else self.info[
            "classes"]

        random_classes = rand.choice(list(clss.keys()),
                                     size=(num_classes, ),
                                     replace=False)
        # episode-local label for each sampled class
        rand_class_dict = {rnd: k for k, rnd in enumerate(random_classes)}

        _dts = []
        for ns in dl.as_tuple_or_list(num_examples):
            classes = balanced_choice_wr(random_classes, ns, rand)

            all_images = {cls: list(clss[cls]) for cls in classes}
            data, targets, sample_info = [], [], []
            for c in classes:
                # draw one image per occurrence of the class, without repeats
                rand.shuffle(all_images[c])
                img_name = all_images[c][0]
                all_images[c].remove(img_name)
                sample_info.append({"name": img_name, "label": c})

                if self._loaded_images:
                    data.append(clss[c][img_name])
                else:
                    from imageio import imread

                    # load from disk, resize to the configured size, and
                    # scale pixel values to [0, 1]
                    data.append(
                        np.array(
                            Image.fromarray(
                                imread(
                                    join(self.info["base_folder"],
                                         join(c, img_name)))).resize(
                                             size=(self.info["resize"],
                                                   self.info["resize"]))) /
                        255.0)
                targets.append(rand_class_dict[c])

            if self.info["one_hot_enc"]:
                targets = to_one_hot_enc(targets, dimension=num_classes)

            _dts.append(
                dl.Dataset(
                    data=np.array(np.stack(data)),
                    target=targets,
                    sample_info=sample_info,
                    info={"all_classes": random_classes},
                ))
        return dl.Datasets.from_list(_dts)
示例#7
0
def meta_omniglot(
    folder=OMNIGLOT_RESIZED,
    std_num_classes=None,
    examples_train=0,
    examples_test=0,
    one_hot_enc=True,
    _rand=0,
    n_splits=None,
):
    """Load Omniglot in its learning-to-learn (meta-dataset) form.

    Uses image data as obtained from
    https://github.com/cbfinn/maml/blob/master/data/omniglot_resized/resize_images.py

    :param folder: root folder containing one sub-folder per alphabet
    :param std_num_classes: standard number of classes for N-way episodes
    :param examples_train: examples picked per class for training
        (e.g. 1-shot: examples_train=1); must be > 0
    :param examples_test: examples picked per class for testing
    :param one_hot_enc: whether episode labels are one-hot encoded
    :param _rand: random seed or RandomState used to split classes into
        train/validation/test meta-datasets
    :param n_splits: boundaries of the class splits
    :return: a Datasets of MetaDataset s
    """
    assert (examples_train >
            0), "Wrong initialization for number of examples used for training"
    # episode sizes: (train, test) pair when testing examples are requested
    train_examples_total = examples_train * std_num_classes
    if examples_test > 0:
        std_num_examples = (train_examples_total,
                            examples_test * std_num_classes)
    else:
        std_num_examples = train_examples_total

    # one entry per character class: "alphabet/character" -> list of images
    labels_and_images = OrderedDict()
    for alphabet in os.listdir(folder):
        alphabet_dir = join(folder, alphabet)
        # all characters in one alphabet
        for character in os.listdir(alphabet_dir):
            labels_and_images[alphabet + os.path.sep +
                              character] = os.listdir(
                                  join(alphabet_dir, character))

    # divide between training, validation, and test meta-datasets
    _rand = dl.get_rand_state(_rand)
    all_clss = list(labels_and_images.keys())
    _rand.shuffle(all_clss)
    n_splits = n_splits or (0, 1200, 1300, len(all_clss))

    meta_dts = [
        OmniglotMetaDataset(
            info={
                "base_folder": folder,
                "classes":
                {k: labels_and_images[k]
                 for k in all_clss[start:end]},
                "one_hot_enc": one_hot_enc,
            },
            num_classes=std_num_classes,
            num_examples=std_num_examples,
        ) for start, end in zip(n_splits, n_splits[1:])
    ]

    return dl.Datasets.from_list(meta_dts)
示例#8
0
def meta_omniglot_v2(
    folder=OMNIGLOT_RESIZED,
    std_num_classes=None,
    examples_train=None,
    examples_test=None,
    one_hot_enc=True,
    _rand=0,
    n_splits=None,
):
    """Load Omniglot in its learning-to-learn (meta-dataset) form, v2.

    Unlike `meta_omniglot`, this variant eagerly loads every image (plus
    rotated copies) into one flat array at construction time, so episode
    generation only indexes into memory. Uses image data as obtained from
    https://github.com/cbfinn/maml/blob/master/data/omniglot_resized/resize_images.py

    :param folder: root folder containing one sub-folder per alphabet
    :param std_num_classes: standard number of classes for N-way episodes
    :param examples_train: examples picked per class for training
        (e.g. 1-shot: examples_train=1)
    :param examples_test: examples picked per class for testing
    :param one_hot_enc: whether episode labels are one-hot encoded
    :param _rand: random seed or RandomState used to split classes into
        train/validation/test meta-datasets
    :param n_splits: boundaries of the class splits
    :return: a Datasets of MetaDataset s
    """

    class OmniglotMetaDataset(MetaDataset):
        # Local MetaDataset variant that pre-loads all images (with
        # rotations) into self._img_array and maps each (class, image name)
        # to an index into it.
        def __init__(
            self,
            info=None,
            rotations=None,
            name="Omniglot",
            num_classes=None,
            num_examples=None,
        ):
            super().__init__(
                info, name=name, num_classes=num_classes, num_examples=num_examples
            )
            # class key -> {image name: index into self._img_array}
            self._loaded_images = defaultdict(lambda: {})
            self._rotations = rotations or [0, 90, 180, 270]
            self.num_classes = num_classes
            assert len(num_examples) > 0
            self.examples_train = int(num_examples[0] / num_classes)
            self._img_array = None
            self.load_all()

        def generate_datasets(self, rand=None, num_classes=None, num_examples=None):
            """Sample one episode: draw `num_classes` classes and build one
            `dl.Dataset` per entry of `num_examples`, indexing images out of
            the pre-loaded array."""
            rand = dl.get_rand_state(rand)

            # fall back to the construction-time configuration
            if not num_examples:
                num_examples = self.kwargs["num_examples"]
            if not num_classes:
                num_classes = self.kwargs["num_classes"]

            clss = self._loaded_images if self._loaded_images else self.info["classes"]

            random_classes = rand.choice(
                list(clss.keys()), size=(num_classes,), replace=False
            )
            # episode-local label for each sampled class
            rand_class_dict = {rnd: k for k, rnd in enumerate(random_classes)}

            _dts = []
            for ns in dl.as_tuple_or_list(num_examples):
                classes = balanced_choice_wr(random_classes, ns, rand)

                all_images = {cls: list(clss[cls]) for cls in classes}
                indices, targets = [], []
                for c in classes:
                    # draw one image per occurrence of the class, no repeats
                    rand.shuffle(all_images[c])
                    img_name = all_images[c][0]
                    all_images[c].remove(img_name)
                    # sample_info.append({'name': img_name, 'label': c})
                    indices.append(clss[c][img_name])
                    targets.append(rand_class_dict[c])

                if self.info["one_hot_enc"]:
                    targets = dl.to_one_hot_enc(targets, dimension=num_classes)

                # gather the episode's images in one indexing operation
                data = self._img_array[indices]

                _dts.append(dl.Dataset(data=data, target=targets))
            return dl.Datasets.from_list(_dts)

        def load_all(self):
            """Read every image from disk, add rotated copies, and stack them
            all into self._img_array (indexed via self._loaded_images)."""
            # NOTE(review): scipy.ndimage.imread was removed in SciPy >= 1.2;
            # this requires an old SciPy (or switching to imageio) — confirm.
            from scipy.ndimage import imread
            from scipy.ndimage.interpolation import rotate

            _cls = self.info["classes"]
            _base_folder = self.info["base_folder"]

            _id = 0
            flat_data = []
            flat_targets = []
            for c in _cls:
                all_images = list(_cls[c])
                for img_name in all_images:
                    img = imread(join(_base_folder, join(c, img_name)))
                    # invert and normalize to [0, 1], single channel
                    img = 1.0 - np.reshape(img, (28, 28, 1)) / 255.0
                    for rot in self._rotations:
                        # NOTE(review): `img` is re-rotated each pass, so the
                        # angles accumulate (entries are at 0, 90, 270, 180
                        # degrees, not the nominal rot values). Each derived
                        # class is still internally consistent — confirm
                        # whether cumulative rotation is intended.
                        img = rotate(img, rot, reshape=False)
                        self._loaded_images[c + os.path.sep + "rot_" + str(rot)][
                            img_name
                        ] = _id
                        _id += 1
                        flat_data.append(img)
                        # flat_targets maybe... no flat targets... they depend on the episode!!

            self._img_array = np.stack(flat_data)

            # end of class

    # episode sizes: (train, test) totals across classes
    std_num_examples = (
        examples_train * std_num_classes,
        examples_test * std_num_classes,
    )
    alphabets = os.listdir(folder)

    # one entry per character class: "alphabet/character" -> list of images
    labels_and_images = OrderedDict()
    for alphabet in alphabets:
        base_folder = join(folder, alphabet)
        label_names = os.listdir(base_folder)  # all characters in one alphabet
        labels_and_images.update(
            {
                alphabet + os.path.sep + ln: os.listdir(join(base_folder, ln))
                # all examples of each character
                for ln in label_names
            }
        )

    # divide between training validation and test meta-datasets
    _rand = dl.get_rand_state(_rand)
    all_clss = list(labels_and_images.keys())
    _rand.shuffle(all_clss)
    n_splits = n_splits or (0, 1200, 1300, len(all_clss))

    meta_dts = []
    for start, end in zip(n_splits, n_splits[1:]):
        meta_dts.append(
            OmniglotMetaDataset(
                info={
                    "base_folder": folder,
                    "classes": {k: labels_and_images[k] for k in all_clss[start:end]},
                    "one_hot_enc": one_hot_enc,
                },
                num_classes=std_num_classes,
                num_examples=std_num_examples,
            )
        )

    return dl.Datasets.from_list(meta_dts)
示例#9
0
def meta_test_up_to_T(exp_dir,
                      metasets,
                      exs,
                      pybml_ho,
                      saver,
                      sess,
                      c_way,
                      k_shot,
                      lr,
                      n_test_episodes,
                      MBS,
                      seed,
                      T,
                      iterations=None):
    """Meta-test every saved checkpoint in `exp_dir`, reporting accuracy
    after each of the first T inner-adaptation steps.

    For each index i in `iterations`, if a checkpoint 'model<i>' exists it
    is restored into `sess` and evaluated via `accuracy_on_up_to_T` on the
    train/validation/test splits of `metasets`. Per-step means and standard
    deviations are accumulated and pickled to disk after every checkpoint.

    :param exp_dir: directory holding 'model<i>' checkpoints; results are
        written there as '<tag>noTrain_results.pickle'
    :param metasets: meta-datasets with .train/.validation/.test splits
    :param exs: experiment objects forwarded to `accuracy_on_up_to_T`
    :param pybml_ho: hyper-optimizer forwarded to `accuracy_on_up_to_T`
    :param saver: tf.train.Saver used to restore checkpoints
    :param sess: TensorFlow session
    :param c_way: N of the N-way task (used only in the results file name)
    :param k_shot: K of the K-shot task (used only in the results file name)
    :param lr: learning rate (used only in the results file name)
    :param n_test_episodes: total episodes per evaluation
    :param MBS: meta-batch size
    :param seed: random seed or RandomState for the batch queues
    :param T: number of inner adaptation steps to evaluate
    :param iterations: iterable of checkpoint indices to probe; defaults to
        range(10000)
    :return: the populated results dict
    """
    # Fix: the previous default `list(range(10000))` was a mutable default
    # argument, built eagerly at import time and shared between calls.
    # A lazy `range` yields the same indices without that hazard.
    if iterations is None:
        iterations = range(10000)

    meta_test_str = str(c_way) + 'way_' + str(k_shot) + 'shot_' + str(
        lr) + 'lr' + str(n_test_episodes) + 'ep'

    n_test_batches = n_test_episodes // MBS
    rand = dl.get_rand_state(seed)

    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS,
                                   rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)
    train_batches = BatchQueueMock(metasets.train, n_test_batches, MBS, rand)

    print('\nMeta-testing {} (over {} eps)...'.format(meta_test_str,
                                                      n_test_episodes))

    test_results = {
        'valid_test': [],  # one {'mean': [...], 'std': [...]} dict per step t
        'test_test': [],
        'train_test': [],
        'time': [],
        'n_test_episodes': n_test_episodes,
        'episodes': [],
        'iterations': []
    }

    test_result_path = os.path.join(exp_dir,
                                    meta_test_str + 'noTrain_results.pickle')

    start_time = time.time()
    for i in iterations:
        model_file = os.path.join(exp_dir, 'model' + str(i))
        if tf.train.checkpoint_exists(model_file):
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

            test_results['iterations'].append(i)
            test_results['episodes'].append(i * MBS)

            valid_result = accuracy_on_up_to_T(valid_batches, exs, pybml_ho,
                                               sess, T)
            test_result = accuracy_on_up_to_T(test_batches, exs, pybml_ho,
                                              sess, T)
            train_result = accuracy_on_up_to_T(train_batches, exs, pybml_ho,
                                               sess, T)

            duration = time.time() - start_time

            test_results['time'].append(duration)

            for t in range(T):
                # moments of per-episode accuracies after t+1 inner steps
                valid_test = (np.mean(valid_result[1][t]),
                              np.std(valid_result[1][t]))
                test_test = (np.mean(test_result[1][t]),
                             np.std(test_result[1][t]))
                train_test = (np.mean(train_result[1][t]),
                              np.std(train_result[1][t]))

                # lazily grow the per-step accumulators on first checkpoint
                if t >= len(test_results['valid_test']):
                    test_results['valid_test'].append({'mean': [], 'std': []})
                    test_results['test_test'].append({'mean': [], 'std': []})
                    test_results['train_test'].append({'mean': [], 'std': []})

                test_results['valid_test'][t]['mean'].append(valid_test[0])
                test_results['test_test'][t]['mean'].append(test_test[0])
                test_results['train_test'][t]['mean'].append(train_test[0])

                test_results['valid_test'][t]['std'].append(valid_test[1])
                test_results['test_test'][t]['std'].append(test_test[1])
                test_results['train_test'][t]['std'].append(train_test[1])

                print(
                    'valid-test_test acc T=%d (%d meta_it)(%.2fs): %.4f (%.4f), %.4f (%.4f),'
                    '  %.4f (%.4f)' %
                    (t + 1, i, duration, train_test[0], train_test[1],
                     valid_test[0], valid_test[1], test_test[0], test_test[1]))

            # persist partial results after every evaluated checkpoint
            save_obj(test_result_path, test_results)

    return test_results
示例#10
0
def meta_train(exp_dir, metasets, exs, pybml_ho, saver, sess, n_test_episodes,
               MBS, seed, resume, T, n_meta_iterations, print_interval,
               save_interval):
    """Run the meta-training loop, periodically evaluating and checkpointing.

    Each iteration draws one meta-batch from the training split, runs the
    hyper-optimizer step, records outer losses, and every `print_interval`
    iterations evaluates accuracy (via `accuracy_on`) on train/validation/test
    splits; every `save_interval` iterations saves the model and pickles the
    results dict to `exp_dir`.

    :param exp_dir: directory where 'model<i>' checkpoints and
        'results.pickle' are written
    :param metasets: meta-datasets with .train/.validation/.test splits
    :param exs: experiment objects whose 'validation' errors are tracked
    :param pybml_ho: hyper-optimizer driving the meta-update
    :param saver: tf.train.Saver for checkpointing
    :param sess: TensorFlow session
    :param n_test_episodes: episodes used per evaluation
    :param MBS: meta-batch size
    :param seed: random seed or RandomState for the batch queues
    :param resume: if truthy, restore the latest checkpoint and results
    :param T: number of inner steps (forwarded to `accuracy_on`)
    :param n_meta_iterations: total number of meta-iterations
    :param print_interval: iterations between evaluations/log lines
    :param save_interval: iterations between checkpoints
    :return: the populated results dict
    """
    # use workers to fill the batches queues (is it worth it?)

    result_path = os.path.join(exp_dir, 'results.pickle')
    tf.global_variables_initializer().run(session=sess)
    n_test_batches = n_test_episodes // MBS
    rand = dl.get_rand_state(seed)

    results = {
        'train_train': {
            'mean': [],
            'std': []
        },
        'train_test': {
            'mean': [],
            'std': []
        },
        'test_test': {
            'mean': [],
            'std': []
        },
        'valid_test': {
            'mean': [],
            'std': []
        },
        'outer_losses': {
            'mean': [],
            'std': []
        },
        'learning_rate': [],
        'iterations': [],
        'episodes': [],
        'time': [],
        'alpha': []
    }

    resume_itr = 0
    if resume:
        # restore both the results dict and the model weights, and continue
        # from the iteration encoded in the checkpoint file name ('model<i>')
        model_file = tf.train.latest_checkpoint(exp_dir)
        if model_file:
            print("Restoring results from " + result_path)
            results = load_obj(result_path)

            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:]) + 1
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    ''' Meta-Train '''
    train_batches = BatchQueueMock(metasets.train, 1, MBS, rand)
    valid_batches = BatchQueueMock(metasets.validation, n_test_batches, MBS,
                                   rand)
    test_batches = BatchQueueMock(metasets.test, n_test_batches, MBS, rand)

    start_time = time.time()
    print(
        '\nIteration quantities: train_train acc, train_test acc, valid_test, acc'
        ' test_test acc mean(std) over %d episodes' % n_test_episodes)
    with sess.as_default():
        # NOTE(review): inner_losses is never appended to in this loop, so
        # results['inner_losses'] stays empty — confirm whether intentional.
        inner_losses = []
        for meta_it in range(resume_itr, n_meta_iterations):
            # one meta-batch: build feed dicts and run the hyper-optimizer step
            tr_fd, v_fd = feed_dicts(train_batches.get_all_batches()[0], exs)
            pybml_ho.run(tr_fd, v_fd)

            duration = time.time() - start_time

            results['time'].append(duration)
            outer_losses = []
            for _, ex in enumerate(exs):
                outer_losses.append(
                    sess.run(ex.errors['validation'],
                             boml.utils.merge_dicts(tr_fd, v_fd)))
            outer_losses_moments = (np.mean(outer_losses),
                                    np.std(outer_losses))
            results['outer_losses']['mean'].append(outer_losses_moments[0])
            results['outer_losses']['std'].append(outer_losses_moments[1])

            if meta_it % print_interval == 0 or meta_it == n_meta_iterations - 1:
                results['iterations'].append(meta_it)
                results['episodes'].append(meta_it * MBS)
                # log optional learned hyper-parameters when present
                if 'alpha' in pybml_ho.param_dict.keys():
                    alpha_moment = pybml_ho.param_dict['alpha'].eval()
                    print('alpha_itr' + str(meta_it) + ': ', alpha_moment)
                    results['alpha'].append(alpha_moment)
                if 's' in pybml_ho.param_dict.keys():
                    s = sess.run(["s:0"])[0]
                    print('s: {}'.format(s))
                if 't' in pybml_ho.param_dict.keys():
                    t = sess.run(["t:0"])[0]
                    print('t: {}'.format(t))

                # accuracies: [0] is on the train set of each episode, [1] on
                # the test set of each episode
                train_result = accuracy_on(train_batches, exs, pybml_ho, sess,
                                           T)
                test_result = accuracy_on(test_batches, exs, pybml_ho, sess, T)
                valid_result = accuracy_on(valid_batches, exs, pybml_ho, sess,
                                           T)
                train_train = (np.mean(train_result[0]),
                               np.std(train_result[0]))
                train_test = (np.mean(train_result[1]),
                              np.std(train_result[1]))
                valid_test = (np.mean(valid_result[1]),
                              np.std(valid_result[1]))
                test_test = (np.mean(test_result[1]), np.std(test_result[1]))

                results['train_train']['mean'].append(train_train[0])
                results['train_test']['mean'].append(train_test[0])
                results['valid_test']['mean'].append(valid_test[0])
                results['test_test']['mean'].append(test_test[0])

                results['train_train']['std'].append(train_train[1])
                results['train_test']['std'].append(train_test[1])
                results['valid_test']['std'].append(valid_test[1])
                results['test_test']['std'].append(test_test[1])

                results['inner_losses'] = inner_losses

                print('mean outer losses: {}'.format(outer_losses_moments[0]))

                print('it %d, ep %d (%.5fs): %.5f, %.5f, %.5f, %.5f' %
                      (meta_it, meta_it * MBS, duration, train_train[0],
                       train_test[0], valid_test[0], test_test[0]))

                lr = sess.run(["lr:0"])[0]
                print('lr: {}'.format(lr))

                # do_plot(logdir, results)

            if meta_it % save_interval == 0 or meta_it == n_meta_iterations - 1:
                saver.save(sess, exp_dir + '/model' + str(meta_it))
                save_obj(result_path, results)

            start_time = time.time()

        return results