Example #1
0
def build(metasets, learn_lr, lr0, MBS, T, mlr0, mlr_decay=1.e-5, process_fn=None, method='MetaInit',
          inner_method='Simple', outer_method='Simple', use_T=False, use_Warp=False, first_order=False,
          inner_opt=None, outer_opt=None):
    """Build a BOML bilevel meta-learning graph over ``MBS`` experiments.

    One ``dl.BOMLExperiment`` is created per slot. All of them share a single
    meta-learner (``meta_model='V1'``, name ``'HyperRepr'``). For every
    experiment this function does the following:

    * attaches a base learner named ``'Task_Net_<k>'``;
    * attaches training and validation cross-entropy losses;
    * attaches an accuracy score;
    * attaches a plain ``BOMLOptSGD`` update op in
      ``ex.optimizers['apply_updates']``;
    * registers the lower-level (``ll_problem``) and upper-level
      (``ul_problem``) problems on the shared ``boml.BOMLOptimizer``.

    Finally the problems are aggregated and a checkpoint saver is created.

    Args:
        metasets: Dataset handed to every ``dl.BOMLExperiment`` and to the
            meta-learner (``dataset=``).
        learn_lr: Forwarded to ``ll_problem`` as ``learn_lr``.
        lr0: Inner learning rate. It is used both for the per-experiment
            ``BOMLOptSGD`` and as ``learning_rate`` of ``ll_problem``.
        MBS: Number of experiments to build (presumably the meta-batch size).
        T: Forwarded to ``ll_problem``. Presumably the number of inner
            iterations; confirm against BOML.
        mlr0: ``meta_learning_rate`` passed to ``ul_problem``.
        mlr_decay: ``mlr_decay`` passed to ``ul_problem``.
        process_fn: Passed to ``aggregate_all`` as ``gradient_clip``.
        method, inner_method, outer_method: ``BOMLOptimizer`` configuration.
            ``method`` is also passed to ``boml.utils.cross_entropy``.
        use_T, use_Warp: Forwarded to ``boml_ho.meta_learner``.
        first_order: Forwarded to ``ll_problem``.
        inner_opt: ``inner_objective_optimizer`` for ``ll_problem``. When
            ``None`` (default) the module-level ``args.inner_opt`` is used,
            which matches the previous behaviour.
        outer_opt: ``outer_objective_optimizer`` for ``ul_problem``. When
            ``None`` (default) the module-level ``args.outer_opt`` is used.

    Returns:
        A tuple ``(exs, boml_ho, saver)``:

        * ``exs``: the list of experiments;
        * ``boml_ho``: the configured ``BOMLOptimizer``;
        * ``saver``: a ``tf.train.Saver`` over all global variables
          (``max_to_keep=10``).
    """
    exs = [dl.BOMLExperiment(metasets) for _ in range(MBS)]

    boml_ho = boml.BOMLOptimizer(method=method, inner_method=inner_method, outer_method=outer_method,
                                 experiments=exs)
    meta_model = boml_ho.meta_learner(_input=exs[0].x, dataset=metasets, meta_model='V1',
                                      name='HyperRepr', use_T=use_T, use_Warp=use_Warp)

    for k, ex in enumerate(exs):
        ex.model = boml_ho.base_learner(_input=ex.x, meta_learner=meta_model,
                                        name='Task_Net_%s' % k)
        ex.errors['training'] = boml.utils.cross_entropy(pred=ex.model.out, label=ex.y, method=method)
        ex.scores['accuracy'] = boml.utils.classification_acc(pred=ex.model.out, label=ex.y)
        ex.optimizers['apply_updates'], _ = boml.BOMLOptSGD(learning_rate=lr0).minimize(
            ex.errors['training'], var_list=ex.model.var_list)
        # The global ``args`` is only consulted when no explicit optimizer was
        # given, so callers that still rely on it keep working unchanged.
        optim_dict = boml_ho.ll_problem(
            inner_objective=ex.errors['training'], learning_rate=lr0,
            inner_objective_optimizer=inner_opt if inner_opt is not None else args.inner_opt,
            T=T, experiment=ex, var_list=ex.model.var_list, learn_lr=learn_lr,
            first_order=first_order)
        # Validation loss re-uses the same base learner on the query inputs ``ex.x_``.
        ex.errors['validation'] = boml.utils.cross_entropy(pred=ex.model.re_forward(ex.x_).out,
                                                           label=ex.y_, method=method)
        boml_ho.ul_problem(
            outer_objective=ex.errors['validation'], meta_learning_rate=mlr0, inner_grad=optim_dict,
            outer_objective_optimizer=outer_opt if outer_opt is not None else args.outer_opt,
            mlr_decay=mlr_decay,
            meta_param=tf.get_collection(boml.extension.GraphKeys.METAPARAMETERS))

    # Diagnostic output of the optimizer objects held by ``boml_ho``
    # (presumably the inner and outer optimizers).
    print(boml_ho.io_opt)
    print(boml_ho.oo_opt)
    boml_ho.aggregate_all(gradient_clip=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=10)
    return exs, boml_ho, saver
Example #2
0
def build(metasets,
          learn_lr,
          learn_alpha,
          learn_alpha_itr,
          gamma,
          lr0,
          MBS,
          T,
          mlr0,
          mlr_decay,
          weights_initializer,
          process_fn=None,
          alpha_itr=0.0,
          method=None,
          inner_method=None,
          outer_method=None,
          use_T=False,
          use_Warp=True,
          truncate_iter=-1,
          inner_opt=None,
          outer_opt=None):
    """Build a BOML graph where a shared representation feeds per-task classifiers.

    ``MBS`` experiments are created and share one meta-learner (the
    "hyper-representation", ``meta_model='V1'``, named ``method``). The
    representation is computed for both the training inputs ``ex.x`` and the
    validation inputs ``ex.x_``. A per-experiment classifier
    (``'Classifier_<k>'``) is built on top of it.

    Each experiment gets:

    * training and validation cross-entropy losses;
    * an accuracy score;
    * a plain ``BOMLOptSGD`` update op;
    * lower- and upper-level problems registered on the shared
      ``boml.BOMLOptimizer``.

    Args:
        metasets: Dataset handed to every ``dl.BOMLExperiment`` and to the
            meta-learner.
        learn_lr, learn_alpha, learn_alpha_itr, gamma: Forwarded to
            ``ll_problem`` under the same names.
        lr0: Inner learning rate. It is used for the per-experiment
            ``BOMLOptSGD`` and as ``learning_rate`` of ``ll_problem``.
        MBS: Number of experiments to build (presumably the meta-batch size).
        T: Forwarded to ``ll_problem``. Presumably the number of inner
            iterations; confirm against BOML.
        mlr0: ``meta_learning_rate`` passed to ``ul_problem``.
        mlr_decay: ``mlr_decay`` passed to ``ul_problem``.
        weights_initializer: Passed to ``boml_ho.base_learner``.
        process_fn: Passed to ``aggregate_all`` as ``gradient_clip``.
        alpha_itr: Passed to ``ll_problem`` as ``alpha_init``.
        method, inner_method, outer_method: ``BOMLOptimizer`` configuration.
            ``method`` is also used as the meta-learner's ``name`` and passed
            to ``boml.utils.cross_entropy``.
        use_T, use_Warp: Forwarded to ``boml_ho.meta_learner``.
        truncate_iter: Passed to ``boml.BOMLOptimizer``.
        inner_opt: ``inner_objective_optimizer`` for ``ll_problem``. When
            ``None`` (default) the module-level ``args.inner_opt`` is used,
            which matches the previous behaviour.
        outer_opt: ``outer_objective_optimizer`` for ``ul_problem``. When
            ``None`` (default) the module-level ``args.outer_opt`` is used.

    Returns:
        A tuple ``(exs, boml_ho, saver)``:

        * ``exs``: the list of experiments;
        * ``boml_ho``: the configured ``BOMLOptimizer``;
        * ``saver``: a ``tf.train.Saver`` over all global variables
          (``max_to_keep=10``).
    """
    exs = [dl.BOMLExperiment(metasets) for _ in range(MBS)]
    boml_ho = boml.BOMLOptimizer(method=method,
                                 inner_method=inner_method,
                                 outer_method=outer_method,
                                 truncate_iter=truncate_iter,
                                 experiments=exs)

    hyper_repr_model = boml_ho.meta_learner(_input=exs[0].x,
                                            dataset=metasets,
                                            meta_model='V1',
                                            name=method,
                                            use_T=use_T,
                                            use_Warp=use_Warp)

    for k, ex in enumerate(exs):
        # Shared representation of the training and validation inputs.
        repr_out = hyper_repr_model.re_forward(ex.x).out
        repr_out_val = hyper_repr_model.re_forward(ex.x_).out
        ex.model = boml_ho.base_learner(
            _input=repr_out,
            meta_learner=hyper_repr_model,
            weights_initializer=weights_initializer,
            name='Classifier_%s' % k)

        ex.errors['training'] = boml.utils.cross_entropy(pred=ex.model.out,
                                                         label=ex.y,
                                                         method=method)
        ex.errors['validation'] = boml.utils.cross_entropy(
            label=ex.y_,
            pred=ex.model.re_forward(repr_out_val).out,
            method=method)
        # Fraction of samples whose argmax prediction matches the argmax
        # label (assumes one-hot labels in ``ex.y``; confirm).
        ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
            tf.float32),
                                               name='accuracy')
        inner_objective = ex.errors['training']
        ex.optimizers['apply_updates'], _ = boml.BOMLOptSGD(
            learning_rate=lr0).minimize(ex.errors['training'],
                                        var_list=ex.model.var_list)
        # The global ``args`` is only consulted when no explicit optimizer was
        # given, so callers that still rely on it keep working unchanged.
        optim_dict = boml_ho.ll_problem(
            inner_objective=inner_objective,
            learning_rate=lr0,
            inner_objective_optimizer=(inner_opt if inner_opt is not None
                                       else args.inner_opt),
            outer_objective=ex.errors['validation'],
            alpha_init=alpha_itr,
            T=T,
            experiment=ex,
            gamma=gamma,
            learn_lr=learn_lr,
            learn_alpha_itr=learn_alpha_itr,
            learn_alpha=learn_alpha,
            var_list=ex.model.var_list)

        boml_ho.ul_problem(outer_objective=ex.errors['validation'],
                           inner_grad=optim_dict,
                           outer_objective_optimizer=(outer_opt if outer_opt is not None
                                                      else args.outer_opt),
                           meta_learning_rate=mlr0,
                           mlr_decay=mlr_decay,
                           meta_param=tf.get_collection(
                               boml.extension.GraphKeys.METAPARAMETERS))

    # Diagnostic output of the optimizer objects held by ``boml_ho``
    # (presumably the inner and outer optimizers) and of the inner
    # optimizer's learning-rate / parameter attributes.
    print(boml_ho.io_opt)
    print(boml_ho.oo_opt)
    print(boml_ho.io_opt.learning_rate)
    print(boml_ho.io_opt.learning_rate_tensor)
    print(boml_ho.io_opt.optimizer_params_tensor)
    boml_ho.aggregate_all(gradient_clip=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=10)
    return exs, boml_ho, saver