def build(metasets, learn_lr, lr0, MBS, T, mlr0, mlr_decay=1.e-5,
          process_fn=None, method='MetaInit', inner_method='Simple',
          outer_method='Simple', use_T=False, use_Warp=False,
          first_order=False):
    """Construct the bilevel meta-training graph for a batch of tasks.

    Creates ``MBS`` task experiments over ``metasets``, one shared meta
    learner, and per-task base learners; wires each task's lower-level
    (training) and upper-level (validation) objectives into a
    ``boml.BOMLOptimizer``, then aggregates all gradients.

    Args:
        metasets: meta-dataset the experiments sample from.
        learn_lr: whether the inner learning rate is itself learned.
        lr0: initial inner-loop learning rate.
        MBS: meta batch size (number of parallel task experiments).
        T: number of inner-loop adaptation steps.
        mlr0: initial meta (outer) learning rate.
        mlr_decay: decay applied to the meta learning rate.
        process_fn: optional gradient-clipping / post-processing function.
        method / inner_method / outer_method: BOML strategy selectors.
        use_T, use_Warp: enable transformation / warp layers in the meta model.
        first_order: use the first-order approximation in the inner problem.

    Returns:
        (experiments, optimizer, saver) tuple.
    """
    experiments = [dl.BOMLExperiment(metasets) for _ in range(MBS)]
    optimizer = boml.BOMLOptimizer(method=method,
                                   inner_method=inner_method,
                                   outer_method=outer_method,
                                   experiments=experiments)
    # One meta learner shared across every task; fed with the first task's
    # placeholder only to fix the input shape.
    shared_meta = optimizer.meta_learner(_input=experiments[0].x,
                                         dataset=metasets,
                                         meta_model='V1',
                                         name='HyperRepr',
                                         use_T=use_T,
                                         use_Warp=use_Warp)
    for idx, task in enumerate(experiments):
        task.model = optimizer.base_learner(_input=task.x,
                                            meta_learner=shared_meta,
                                            name='Task_Net_%s' % idx)
        task.errors['training'] = boml.utils.cross_entropy(
            pred=task.model.out, label=task.y, method=method)
        task.scores['accuracy'] = boml.utils.classification_acc(
            pred=task.model.out, label=task.y)
        # Plain SGD update op kept alongside the bilevel machinery.
        task.optimizers['apply_updates'], _ = boml.BOMLOptSGD(
            learning_rate=lr0).minimize(task.errors['training'],
                                        var_list=task.model.var_list)
        # Lower-level (inner) problem: adapt task parameters for T steps.
        inner_grad = optimizer.ll_problem(
            inner_objective=task.errors['training'],
            learning_rate=lr0,
            inner_objective_optimizer=args.inner_opt,
            T=T,
            experiment=task,
            var_list=task.model.var_list,
            learn_lr=learn_lr,
            first_order=first_order)
        # Upper-level (outer) objective: loss on the task's validation split.
        task.errors['validation'] = boml.utils.cross_entropy(
            pred=task.model.re_forward(task.x_).out,
            label=task.y_,
            method=method)
        optimizer.ul_problem(
            outer_objective=task.errors['validation'],
            meta_learning_rate=mlr0,
            inner_grad=inner_grad,
            outer_objective_optimizer=args.outer_opt,
            mlr_decay=mlr_decay,
            meta_param=tf.get_collection(
                boml.extension.GraphKeys.METAPARAMETERS))
    # Convenience handles (accessed for their side effects / debugging).
    meta_learner = optimizer.meta_model
    meta_learning_rate = optimizer.meta_learning_rate
    apply_updates = optimizer.outergradient.apply_updates
    inner_objectives = optimizer.inner_objectives
    iteration = optimizer.innergradient.iteration
    print(optimizer.io_opt)
    print(optimizer.oo_opt)
    # Combine the hypergradients of all tasks into a single update.
    optimizer.aggregate_all(gradient_clip=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=10)
    return experiments, optimizer, saver
def build(metasets, learn_lr, learn_alpha, learn_alpha_itr, gamma, lr0, MBS,
          T, mlr0, mlr_decay, weights_initializer, process_fn=None,
          alpha_itr=0.0, method=None, inner_method=None, outer_method=None,
          use_T=False, use_Warp=True, truncate_iter=-1):
    """Construct the bilevel graph for the shared-representation variant.

    Each task classifies on top of features produced by one shared
    hyper-representation network; the inner problem adapts only the
    per-task classifier, optionally with learned per-iteration alpha.

    Args:
        metasets: meta-dataset the experiments sample from.
        learn_lr / learn_alpha / learn_alpha_itr: flags for learning the
            inner rate and the (per-iteration) aggregation weights.
        gamma: regularization / aggregation coefficient for the LL problem.
        lr0: initial inner-loop learning rate.
        MBS: meta batch size (number of parallel task experiments).
        T: number of inner-loop adaptation steps.
        mlr0, mlr_decay: initial meta learning rate and its decay.
        weights_initializer: initializer for the per-task classifiers.
        process_fn: optional gradient-clipping / post-processing function.
        alpha_itr: initial alpha value.
        method / inner_method / outer_method: BOML strategy selectors.
        use_T, use_Warp: enable transformation / warp layers.
        truncate_iter: truncated-backprop horizon (-1 means no truncation).

    Returns:
        (experiments, optimizer, saver) tuple.
    """
    experiments = [dl.BOMLExperiment(metasets) for _ in range(MBS)]
    optimizer = boml.BOMLOptimizer(method=method,
                                   inner_method=inner_method,
                                   outer_method=outer_method,
                                   truncate_iter=truncate_iter,
                                   experiments=experiments)
    repr_model = optimizer.meta_learner(_input=experiments[0].x,
                                        dataset=metasets,
                                        meta_model='V1',
                                        name=method,
                                        use_T=use_T,
                                        use_Warp=use_Warp)
    for idx, task in enumerate(experiments):
        # Shared-representation features for the train and validation splits.
        feat_train = repr_model.re_forward(task.x).out
        feat_val = repr_model.re_forward(task.x_).out
        task.model = optimizer.base_learner(
            _input=feat_train,
            meta_learner=repr_model,
            weights_initializer=weights_initializer,
            name='Classifier_%s' % idx)
        task.errors['training'] = boml.utils.cross_entropy(
            pred=task.model.out, label=task.y, method=method)
        task.errors['validation'] = boml.utils.cross_entropy(
            label=task.y_,
            pred=task.model.re_forward(feat_val).out,
            method=method)
        task.scores['accuracy'] = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(task.y, 1),
                             tf.argmax(task.model.out, 1)),
                    tf.float32),
            name='accuracy')
        train_loss = task.errors['training']
        # Plain SGD update op kept alongside the bilevel machinery.
        task.optimizers['apply_updates'], _ = boml.BOMLOptSGD(
            learning_rate=lr0).minimize(train_loss,
                                        var_list=task.model.var_list)
        # Lower-level problem over the task classifier only.
        inner_grad = optimizer.ll_problem(
            inner_objective=train_loss,
            learning_rate=lr0,
            inner_objective_optimizer=args.inner_opt,
            outer_objective=task.errors['validation'],
            alpha_init=alpha_itr,
            T=T,
            experiment=task,
            gamma=gamma,
            learn_lr=learn_lr,
            learn_alpha_itr=learn_alpha_itr,
            learn_alpha=learn_alpha,
            var_list=task.model.var_list)
        optimizer.ul_problem(
            outer_objective=task.errors['validation'],
            inner_grad=inner_grad,
            outer_objective_optimizer=args.outer_opt,
            meta_learning_rate=mlr0,
            mlr_decay=mlr_decay,
            meta_param=tf.get_collection(
                boml.extension.GraphKeys.METAPARAMETERS))
    # Convenience handles (accessed for their side effects / debugging).
    meta_learning_rate = optimizer.meta_learning_rate
    apply_updates = optimizer.outergradient.apply_updates
    inner_objectives = optimizer.inner_objectives
    iteration = optimizer.innergradient.iteration
    print(optimizer.io_opt)
    print(optimizer.oo_opt)
    print(optimizer.io_opt.learning_rate)
    print(optimizer.io_opt.learning_rate_tensor)
    print(optimizer.io_opt.optimizer_params_tensor)
    # Combine the hypergradients of all tasks into a single update.
    optimizer.aggregate_all(gradient_clip=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=10)
    return experiments, optimizer, saver