def train(metasets, ex_name, hyper_repr_model_builder, classifier_builder=None,
          saver=None, seed=0, MBS=4, available_devices=('/gpu:0', '/gpu:1'),
          mlr0=.001, mlr_decay=1.e-5, T=4, n_episodes_testing=600,
          print_every=1000, patience=40, restore_model=False, lr=0.1,
          learn_lr=True, process_fn=None):
    """
    Function for training a hyper-representation network.

    :param metasets: Datasets of MetaDatasets
    :param ex_name: name of the experiment
    :param hyper_repr_model_builder: builder for the representation model
    :param classifier_builder: optional builder for the classifier model (if None, builds a linear model)
    :param saver: experiment_manager.Saver object
    :param seed: random seed
    :param MBS: meta-batch size
    :param available_devices: distribute the computation among different GPUs!
    :param mlr0: initial meta learning rate
    :param mlr_decay: meta learning rate (inverse time) decay rate
    :param T: number of gradient steps for training ground models
    :param n_episodes_testing: number of episodes used for testing
    :param print_every: frequency of recording and printing
    :param patience: early stopping patience
    :param restore_model: if True, restores a previously saved hyper-representation model
    :param lr: initial ground models learning rate
    :param learn_lr: True for optimizing the ground models learning rate
    :param process_fn: optional hypergradient process function (like gradient clipping)

    :return: tuple: the saver object, the hyper-representation model and the list of experiment objects
    """
    if saver is None:
        saver = SAVER_EXP(metasets)

    T, ss, n_episodes_testing = setup(T, seed, n_episodes_testing, MBS)

    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    hyper_repr_model = hyper_repr_model_builder(exs[0].x, name=ex_name)

    if classifier_builder is None:  # default: linear classifier on top of the representation
        classifier_builder = lambda inp, name: models.FeedForwardNet(
            inp, metasets.train.dim_target, name=name)

    io_optim, gs, meta_lr, oo_optim, farho = _optimizers(lr, mlr0, mlr_decay, learn_lr)

    for k, ex in enumerate(exs):
        # distribute the episode-specific graphs over the available devices
        with tf.device(available_devices[k % len(available_devices)]):
            ex.model = classifier_builder(
                hyper_repr_model.for_input(ex.x).out, 'Classifier_%s' % k)

            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y,
                                                        logits=ex.model.out))
            ex.errors['validation'] = ex.errors['training']
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
                tf.float32), name='accuracy')

            optim_dict = farho.inner_problem(ex.errors['training'], io_optim,
                                             var_list=ex.model.var_list)
            farho.outer_problem(ex.errors['validation'], optim_dict, oo_optim,
                                global_step=gs)

    farho.finalize(process_fn=process_fn)

    feed_dicts, just_train_on_dataset, mean_acc_on, cond = _helper_function(
        exs, n_episodes_testing, MBS, ss, farho, T)

    rand = em.get_rand_state(0)

    with saver.record(*_records(metasets, saver, hyper_repr_model, cond, ss,
                                mean_acc_on, ex_name, meta_lr),
                      where='far', every=print_every, append_string=ex_name):

        tf.global_variables_initializer().run()

        if restore_model:
            saver.restore_model(hyper_repr_model)  # ADD ONLY TESTING

        for _ in cond.early_stopping_sv(saver, patience):
            trfd, vfd = feed_dicts(metasets.train.generate_batch(MBS, rand=rand))

            # one iteration of optimization of representation variables (hyperparameters)
            farho.run(T[0], trfd, vfd)

    return saver, hyper_repr_model, exs
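

# -- Usage sketch (illustrative, not part of this module's API) -----------------
# A minimal example of how train() might be invoked; `load_metasets` and
# `build_hyper_repr` below are hypothetical names standing in for whatever
# meta-dataset loader and representation-model builder the surrounding project
# provides.
#
#   metasets = load_metasets()                      # hypothetical: Datasets of MetaDatasets
#   hr_builder = lambda x, name: build_hyper_repr(x, name=name)  # hypothetical builder
#   saver, hyper_repr_model, exs = train(metasets, 'hr_experiment', hr_builder,
#                                        MBS=4, T=4, lr=0.1, learn_lr=True)

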
def build(metasets, hyper_model_builder, learn_lr, lr0, MBS, mlr0, mlr_decay,
          batch_norm_before_classifier, weights_initializer, process_fn=None):
    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    hyper_repr_model = hyper_model_builder(exs[0].x, 'HyperRepr')

    if learn_lr:
        lr = far.get_hyperparameter('lr', lr0)
    else:
        lr = tf.constant(lr0, name='lr')

    gs = tf.get_variable('global_step', initializer=0, trainable=False)
    meta_lr = tf.train.inverse_time_decay(mlr0, gs, decay_steps=1.,
                                          decay_rate=mlr_decay)

    io_opt = far.GradientDescentOptimizer(lr)
    oo_opt = tf.train.AdamOptimizer(meta_lr)

    far_ho = far.HyperOptimizer()

    # NOTE: available_devices is assumed to be defined at module level
    # (e.g. a tuple such as ('/gpu:0', '/gpu:1'), as in train's default).
    for k, ex in enumerate(exs):
        # print(k)  # DEBUG
        with tf.device(available_devices[k % len(available_devices)]):
            repr_out = hyper_repr_model.for_input(ex.x).out

            other_train_vars = []
            if batch_norm_before_classifier:
                batch_mean, batch_var = tf.nn.moments(repr_out, [0])
                scale = tf.Variable(tf.ones_like(repr_out[0]))
                beta = tf.Variable(tf.zeros_like(repr_out[0]))
                other_train_vars.append(scale)
                other_train_vars.append(beta)
                repr_out = tf.nn.batch_normalization(repr_out, batch_mean,
                                                     batch_var, beta, scale, 1e-3)

            ex.model = em.models.FeedForwardNet(
                repr_out, metasets.train.dim_target,
                output_weight_initializer=weights_initializer,
                name='Classifier_%s' % k)

            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y,
                                                        logits=ex.model.out))
            ex.errors['validation'] = ex.errors['training']
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
                tf.float32), name='accuracy')

            # simple training step used for testing
            ex.optimizers['ts'] = tf.train.GradientDescentOptimizer(lr).minimize(
                ex.errors['training'], var_list=ex.model.var_list)

            optim_dict = far_ho.inner_problem(
                ex.errors['training'], io_opt,
                var_list=ex.model.var_list + other_train_vars)
            far_ho.outer_problem(
                ex.errors['validation'], optim_dict, oo_opt,
                hyper_list=tf.get_collection(far.GraphKeys.HYPERPARAMETERS),
                global_step=gs)

    far_ho.finalize(process_fn=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=240)
    return exs, far_ho, saver
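

# -- Usage sketch (illustrative, not part of this module's API) -----------------
# A minimal sketch of a meta-training loop on top of build(), mirroring the loop
# in train() above. `hyper_model_builder`, `make_feed_dicts`, `T` and
# `n_meta_iterations` are hypothetical names for pieces the caller provides.
#
#   exs, far_ho, saver = build(metasets, hyper_model_builder, learn_lr=True,
#                              lr0=0.1, MBS=4, mlr0=0.001, mlr_decay=1.e-5,
#                              batch_norm_before_classifier=False,
#                              weights_initializer=tf.zeros_initializer)
#   with tf.Session().as_default():
#       tf.global_variables_initializer().run()
#       for _ in range(n_meta_iterations):
#           trfd, vfd = make_feed_dicts(metasets.train.generate_batch(MBS))
#           far_ho.run(T, trfd, vfd)  # one hyper-iteration: T inner steps + one outer step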