Example #1
def train(metasets,
          ex_name,
          hyper_repr_model_builder,
          classifier_builder=None,
          saver=None,
          seed=0,
          MBS=4,
          available_devices=('/gpu:0', '/gpu:1'),
          mlr0=.001,
          mlr_decay=1.e-5,
          T=4,
          n_episodes_testing=600,
          print_every=1000,
          patience=40,
          restore_model=False,
          lr=0.1,
          learn_lr=True,
          process_fn=None):
    """
    Function for training an hyper-representation network.

    :param metasets: Datasets of MetaDatasets
    :param ex_name: name of the experiment
    :param hyper_repr_model_builder: builder for the representation model
    :param classifier_builder: optional builder for classifier model (if None then builds a linear model)
    :param saver: experiment_manager.Saver object
    :param seed:
    :param MBS: meta-batch size
    :param available_devices: distribute the computation among different GPUS!
    :param mlr0: initial meta learning rate
    :param mlr_decay:
    :param T: number of gradient steps for training ground models
    :param n_episodes_testing:
    :param print_every:
    :param patience:
    :param restore_model:
    :param lr: initial ground models learning rate
    :param learn_lr: True for optimizing the ground models learning rate
    :param process_fn: optinal hypergradient process function (like gradient clipping)

    :return: tuple: the saver object, the hyper-representation model and the list of experiments objects
    """
    if saver is None:
        saver = SAVER_EXP(metasets)

    T, ss, n_episodes_testing = setup(T, seed, n_episodes_testing, MBS)
    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    hyper_repr_model = hyper_repr_model_builder(exs[0].x, name=ex_name)
    if classifier_builder is None:
        classifier_builder = lambda inp, name: models.FeedForwardNet(
            inp, metasets.train.dim_target, name=name)

    io_optim, gs, meta_lr, oo_optim, farho = _optimizers(
        lr, mlr0, mlr_decay, learn_lr)

    for k, ex in enumerate(exs):
        with tf.device(available_devices[k % len(available_devices)]):
            ex.model = classifier_builder(
                hyper_repr_model.for_input(ex.x).out, 'Classifier_%s' % k)
            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y,
                                                        logits=ex.model.out))
            ex.errors['validation'] = ex.errors['training']
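            # note: same graph node; training vs. validation differ only in
            # the feed dicts passed to farho.run in the main loop below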
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
                tf.float32),
                                                   name='accuracy')

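            # bilevel setup: the inner problem trains this episode's classifier
            # weights, while the outer problem updates the shared representation
            # (and, if learn_lr, the inner learning rate) via the hypergradient
            # of the validation error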
            optim_dict = farho.inner_problem(ex.errors['training'],
                                             io_optim,
                                             var_list=ex.model.var_list)
            farho.outer_problem(ex.errors['validation'],
                                optim_dict,
                                oo_optim,
                                global_step=gs)

    farho.finalize(process_fn=process_fn)

    feed_dicts, just_train_on_dataset, mean_acc_on, cond = _helper_function(
        exs, n_episodes_testing, MBS, ss, farho, T)

    rand = em.get_rand_state(0)

    with saver.record(*_records(metasets, saver, hyper_repr_model, cond, ss,
                                mean_acc_on, ex_name, meta_lr),
                      where='far',
                      every=print_every,
                      append_string=ex_name):
        tf.global_variables_initializer().run()
        if restore_model:
            saver.restore_model(hyper_repr_model)
        # TODO: add a testing-only mode
        for _ in cond.early_stopping_sv(saver, patience):
            trfd, vfd = feed_dicts(
                metasets.train.generate_batch(MBS, rand=rand))

            # one hyperiteration: T[0] inner gradient steps on the training
            # batch, then one meta-update of the representation variables
            # (hyperparameters) from the validation batch
            farho.run(T[0], trfd, vfd)

    return saver, hyper_repr_model, exs
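
A minimal sketch of how `train` might be called (hypothetical: the loader
`em.load.meta_omniglot` and the builder body are assumptions; any callable
returning a model that exposes `.out`, `.var_list` and `.for_input(...)`,
as used above, would work):

import tensorflow as tf
import experiment_manager as em

def hyper_repr_builder(x, name):
    # hypothetical builder: `train` relies on the returned model exposing
    # `.out`, `.var_list` and `.for_input(tensor)`
    return em.models.FeedForwardNet(x, 64, name=name)

metasets = em.load.meta_omniglot()  # assumed loader; any MetaDatasets object works

saver, hyper_model, exs = train(metasets, 'omniglot_hr', hyper_repr_builder,
                                MBS=4, T=4, lr=0.1, learn_lr=True)
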
Example #2
def build(metasets,
          hyper_model_builder,
          learn_lr,
          lr0,
          MBS,
          mlr0,
          mlr_decay,
          batch_norm_before_classifier,
          weights_initializer,
          process_fn=None):
    exs = [em.SLExperiment(metasets) for _ in range(MBS)]

    hyper_repr_model = hyper_model_builder(exs[0].x, 'HyperRepr')

    if learn_lr:
        lr = far.get_hyperparameter('lr', lr0)
    else:
        lr = tf.constant(lr0, name='lr')

    gs = tf.get_variable('global_step', initializer=0, trainable=False)
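    # inverse time decay: meta_lr = mlr0 / (1 + mlr_decay * global_step)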
    meta_lr = tf.train.inverse_time_decay(mlr0,
                                          gs,
                                          decay_steps=1.,
                                          decay_rate=mlr_decay)

    io_opt = far.GradientDescentOptimizer(lr)
    oo_opt = tf.train.AdamOptimizer(meta_lr)
    far_ho = far.HyperOptimizer()

    for k, ex in enumerate(exs):
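        # NOTE: `available_devices` is not an argument here; it is assumed to
        # be defined in the enclosing scope (in Example #1 it is passed to
        # `train` directly)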
        with tf.device(available_devices[k % len(available_devices)]):
            repr_out = hyper_repr_model.for_input(ex.x).out

            other_train_vars = []
            if batch_norm_before_classifier:
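                # batch statistics are recomputed on every batch (no moving
                # averages); scale and beta are appended to the inner
                # problem's variable list below, so they are trained per-episode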
                batch_mean, batch_var = tf.nn.moments(repr_out, [0])
                scale = tf.Variable(tf.ones_like(repr_out[0]))
                beta = tf.Variable(tf.zeros_like(repr_out[0]))
                other_train_vars.append(scale)
                other_train_vars.append(beta)
                repr_out = tf.nn.batch_normalization(repr_out, batch_mean,
                                                     batch_var, beta, scale,
                                                     1e-3)

            ex.model = em.models.FeedForwardNet(
                repr_out,
                metasets.train.dim_target,
                output_weight_initializer=weights_initializer,
                name='Classifier_%s' % k)

            ex.errors['training'] = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=ex.y,
                                                        logits=ex.model.out))
            ex.errors['validation'] = ex.errors['training']
            ex.scores['accuracy'] = tf.reduce_mean(tf.cast(
                tf.equal(tf.argmax(ex.y, 1), tf.argmax(ex.model.out, 1)),
                tf.float32),
                                                   name='accuracy')

            # simple (plain SGD) training step, used at test time to fit the
            # classifier on top of the learned representation
            ex.optimizers['ts'] = tf.train.GradientDescentOptimizer(
                lr).minimize(ex.errors['training'], var_list=ex.model.var_list)

            optim_dict = far_ho.inner_problem(ex.errors['training'],
                                              io_opt,
                                              var_list=ex.model.var_list +
                                              other_train_vars)
            far_ho.outer_problem(ex.errors['validation'],
                                 optim_dict,
                                 oo_opt,
                                 hyper_list=tf.get_collection(
                                     far.GraphKeys.HYPERPARAMETERS),
                                 global_step=gs)

    far_ho.finalize(process_fn=process_fn)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),
                           max_to_keep=240)
    return exs, far_ho, saver
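
Both examples accept a `process_fn` hook applied to the hypergradients before
the meta-update. A sketch of a clipping hook is below; the one-tensor-in,
one-tensor-out signature is an assumption based on the docstring of `train`
("like gradient clipping"), and `metasets` / `hyper_model_builder` are assumed
to be defined as in Example #1:

import tensorflow as tf

def clip_hypergrads(hg):
    # hypothetical hook: clip each hypergradient tensor elementwise
    return tf.clip_by_value(hg, -1., 1.)

exs, far_ho, saver = build(metasets, hyper_model_builder,
                           learn_lr=True, lr0=0.1, MBS=4,
                           mlr0=0.001, mlr_decay=1e-5,
                           batch_norm_before_classifier=False,
                           weights_initializer=tf.zeros_initializer(),
                           process_fn=clip_hypergrads)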