def lstm_meta_learner(
        learner: Model, eigenvals_callback: TopKEigenvaluesBatched,
        configuration: TrainingConfiguration) -> MetaLearnerModel:
    # Initialize weights so that, at the beginning, the model resembles SGD:
    # the forget rate is close to 1 and the learning rate is a fixed constant.

    lr_bias = inverse_sigmoid(configuration.initial_lr)
    f_bias = inverse_sigmoid(0.9999)

    common_layers = get_common_lstm_model_layers(
        configuration.hidden_state_size, configuration.constant_lr_model,
        lr_bias, f_bias, 0.001)
    meta_batch_size = get_trainable_params_count(learner)

    train_meta_learner = lstm_train_meta_learner(configuration,
                                                 meta_batch_size,
                                                 common_layers)
    predict_meta_learner = lstm_predict_meta_learner(learner,
                                                     eigenvals_callback,
                                                     configuration,
                                                     meta_batch_size,
                                                     common_layers)

    return MetaLearnerModel(predict_meta_learner, train_meta_learner,
                            configuration.debug_mode)
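# A minimal sketch of the bias-initialisation trick used above; the real
# `inverse_sigmoid` helper is defined elsewhere, so the implementation below is
# an assumption. Seeding the gate biases with logit(value) makes the
# sigmoid-activated gates start out at that value, i.e. the meta-learner
# initially behaves like SGD with a constant learning rate and a forget rate ~1.
import numpy as np


def inverse_sigmoid_sketch(y: float) -> float:
    # logit(y): sigmoid(logit(y)) == y for 0 < y < 1
    return float(np.log(y / (1.0 - y)))


# sigmoid(inverse_sigmoid_sketch(0.05))   ~= 0.05   (initial learning rate)
# sigmoid(inverse_sigmoid_sketch(0.9999)) ~= 0.9999 (forget gate ~ keep weights)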
    def __init__(self,
                 learner: Model,
                 configuration: TrainingConfiguration,
                 train_mode: bool,
                 states_outputs: List[tf.Tensor],
                 input_tensors: List[tf.Tensor],
                 intermediate_outputs: List[tf.Tensor],
                 inputs,
                 outputs,
                 name=None):
        Model.__init__(self, inputs, outputs, name)
        self.debug_mode = configuration.debug_mode
        self.learner = learner
        self.learner_grads = K.concatenate([
            K.flatten(g)
            for g in K.gradients(self.learner.total_loss,
                                 self.learner._collected_trainable_weights)
        ])
        self.learner_inputs = get_input_tensors(self.learner)
        self.intermediate_outputs = intermediate_outputs
        self.output_size = get_trainable_params_count(self.learner)
        self.hessian_eigenvals = configuration.hessian_eigenvalue_features

        # track how a small sample of evenly spaced learner parameters evolves during training
        n_inspect = 20
        self.inspect_parameters = np.linspace(1,
                                              self.output_size,
                                              endpoint=False,
                                              dtype=int,
                                              num=n_inspect)

        self.input_tensors = input_tensors
        self.states_outputs = states_outputs

        self.backprop_depth = configuration.backpropagation_depth
        self.train_mode = train_mode

        self.state_tensors = []

        # For BPTT (many-to-one) we need to store the values of the last inputs
        # together with the value of the last output. To save memory, fixed-shape
        # tensors are reused in a circular fashion, with a variable marking the
        # current index (a standalone sketch of this idea follows this method).
        self.current_backprop_index = tf.Variable(0, dtype=tf.int32)

        self.last_inputs = None
        self.initial_states = []
        self.states_history = []
        self.inputs_history = []
        self.selected_inter_outputs = None
        self.intermediate_outputs_history = []
        self.learner_weights_history = []
        self.current_output = tf.Variable(tf.zeros(shape=(1,
                                                          self.output_size)),
                                          name='current_meta_output')
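# Illustrative standalone sketch (not part of the original class) of the
# circular history buffer described in the comments above: a fixed-shape
# buffer holds the last `backprop_depth` entries and an index variable marks
# where the next write goes, so the truncated-BPTT history uses constant memory.
import numpy as np


class CircularHistorySketch:
    def __init__(self, depth: int, width: int):
        self.buffer = np.zeros((depth, width))
        self.index = 0  # plays the role of current_backprop_index

    def push(self, values: np.ndarray) -> None:
        # overwrite the oldest slot once the buffer is full
        self.buffer[self.index % len(self.buffer)] = values
        self.index += 1

    def ordered(self) -> np.ndarray:
        # entries from oldest to newest (zero rows pad until the buffer is full)
        shift = -(self.index % len(self.buffer))
        return np.roll(self.buffer, shift, axis=0)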
def test_meta_learner(learner: Model,
                      initial_learning_rate: float = 0.05,
                      backpropagation_depth: int = 20) -> MetaLearnerModel:
    # Initialize weights so that, at the beginning, the model resembles SGD:
    # the forget rate is close to 1 and the learning rate is a fixed constant.

    lr_bias = inverse_sigmoid(initial_learning_rate)

    meta_batch_size = get_trainable_params_count(learner)

    train_meta_learner = test_train_meta_learner(backpropagation_depth,
                                                 meta_batch_size, lr_bias)
    predict_meta_learner = test_predict_meta_learner(learner,
                                                     backpropagation_depth,
                                                     meta_batch_size)

    return MetaLearnerModel(predict_meta_learner, train_meta_learner)
def run_meta_learning(conf: TrainingConfiguration, x: np.ndarray,
                      y: np.ndarray):
    meta_dataset_path = conf.meta_dataset_path
    logger = conf.logger

    if not os.path.isfile(meta_dataset_path):
        factory = MetaLearningDatasetFactory(
            x=x,
            y=y,
            meta_test_ratio=conf.meta_test_class_ratio,
            learner_train_size=conf.learner_train_size,
            learner_test_size=conf.learner_test_size,
            classes_per_learner_set=conf.classes_per_learner_set,
            n_train_sets=conf.n_train_sets,
            n_test_sets=conf.n_test_sets,
            logger=logger)

        meta_dataset = factory.get()
        logger.info("Saving generated dataset to {}".format(meta_dataset_path))
        meta_dataset.save(meta_dataset_path)
    else:
        logger.info("Loading previously generated dataset from {}".format(
            meta_dataset_path))
        meta_dataset = load_meta_dataset(meta_dataset_path, x)

    def learner_factory():
        return build_simple_cnn(cifar_input_shape,
                                conf.classes_per_learner_set)

    def meta_learner_factory(learner: Model,
                             eigenvals_callback: TopKEigenvaluesBatched):
        return lstm_meta_learner(learner=learner,
                                 eigenvals_callback=eigenvals_callback,
                                 configuration=conf)

    # build a dummy Learner just to display its summary and parameter count
    dummy_learner = learner_factory()
    n_params = get_trainable_params_count(dummy_learner)

    dummy_learner.compile(loss='categorical_crossentropy',
                          optimizer=SGD(lr=0.0),
                          metrics=['accuracy'])

    logger.info("Using Learner with {} parameters".format(n_params))
    dummy_learner.summary()

    log_dir = os.environ['LOG_DIR']

    meta_learning_task = MetaLearningTask(
        configuration=conf,
        task_checkpoint_path=os.path.join(log_dir, 'checkpoint.txt'),
        meta_dataset=meta_dataset,
        learner_factory=learner_factory,
        meta_learner_factory=meta_learner_factory,
        training_history_path=os.path.join(log_dir,
                                           "meta_training_history.txt"),
        meta_learner_weights_path=os.path.join(log_dir, "meta_weights.h5"),
        meta_learner_weights_history_dir=os.path.join(log_dir,
                                                      "meta_weights_history"),
        best_meta_learner_weights_path=os.path.join(log_dir,
                                                    "meta_weights_best.h5"))

    meta_learning_task.meta_train(n_meta_epochs=conf.n_meta_epochs,
                                  meta_early_stopping=conf.meta_early_stopping,
                                  n_learner_batches=conf.n_learner_batches,
                                  meta_batch_size=conf.meta_batch_size,
                                  n_meta_train_steps=conf.n_train_sets //
                                  conf.meta_batch_size,
                                  n_meta_valid_steps=conf.n_meta_valid_steps,
                                  learner_batch_size=conf.learner_batch_size)
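# Hypothetical driver for run_meta_learning: the configuration loader and data
# files below are assumptions for illustration only; in the real project the
# data is CIFAR-like (cf. cifar_input_shape in learner_factory) and the
# configuration comes from training_configuration.yml.
import numpy as np


def main():
    conf = load_training_configuration('training_configuration.yml')  # assumed helper
    x = np.load('cifar_images.npy')   # e.g. images of shape (N, 32, 32, 3)
    y = np.load('cifar_labels.npy')   # one-hot or integer class labels for the N images
    run_meta_learning(conf, x, y)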
def gradient_check(meta_model: MetaLearnerModel,
                   training_sample: MetaTrainingSample,
                   logger: Logger,
                   epsilon: float = 10e-7) -> bool:
    """
    Performs gradient check on a single meta-training-sample.
    Warning: This method is very slow for big models!
    :param meta_model: MetaLearnerModel
    :param training_sample: training sample to gradient-check
    :param logger: Logger instance
    :param epsilon: epsilon factor used in gradient checking
    :return: True if gradient check passes, otherwise False
    """
    if training_sample.final_output is None:
        raise ValueError("For gradient check, 'final_output' must not be None")
    if training_sample.learner_training_batches is None:
        raise ValueError(
            "For gradient check, 'learner_training_batches' must not be None")
    if training_sample.learner_validation_batch is None:
        raise ValueError(
            "For gradient check, 'learner_validation_batch' must not be None")
    if training_sample.initial_learner_weights is None:
        raise ValueError(
            "For gradient check, 'initial_learner_weights' must not be None")

    state_tensors = meta_model.predict_model.state_tensors
    input_tensors = get_input_tensors(meta_model.train_model)
    learner = meta_model.predict_model.learner

    sess = K.get_session()

    # The first step is to evaluate the gradients of the meta-learner parameters
    # using our method; the 'train_model' version of the meta-learner is used for that.

    # initialize meta-learner (train) states
    assert len(state_tensors) == len(training_sample.initial_states)
    feed_dict = dict(
        zip(meta_model.states_placeholder, training_sample.initial_states))
    sess.run(meta_model.init_train_states_updates, feed_dict=feed_dict)

    # standardize input for current meta-training sample
    inputs = standardize_predict_inputs(meta_model.train_model,
                                        training_sample.inputs)

    # compute gradients on current meta-learner parameters and training sample
    feed_dict = dict(zip(input_tensors, inputs))
    feed_dict[
        meta_model.learner_grad_placeholder] = training_sample.learner_grads

    # our method of computing the meta-learner gradients - this is what we want to verify here
    evaluation = sess.run(fetches=meta_model.chained_grads,
                          feed_dict=feed_dict)
    evaluated_meta_grads = np.concatenate(
        [grad.flatten() for grad in evaluation])

    # Gradient check for each meta-learner weight. The 'predict_model' version of the
    # meta-learner (the one used for training the Learner) is used for the approximation.
    n_meta_learner_params = get_trainable_params_count(meta_model.train_model)
    approximated_meta_grads = np.zeros(shape=n_meta_learner_params)

    valid_x, valid_y = training_sample.learner_validation_batch
    learner_valid_ins = standardize_train_inputs(learner, valid_x, valid_y)

    # tensors used for updating meta-learner weights
    trainable_meta_weights = sess.run(
        meta_model.predict_model.trainable_weights)
    meta_weights_placeholder = [
        tf.placeholder(shape=w.get_shape(), dtype=tf.float32)
        for w in meta_model.predict_model.trainable_weights
    ]
    meta_weights_updates = [
        tf.assign(w, new_w)
        for w, new_w in zip(meta_model.predict_model.trainable_weights,
                            meta_weights_placeholder)
    ]

    def calculate_loss(new_weights):
        # update weights of meta-learner ('predict_model')
        f_dict = dict(zip(meta_weights_placeholder, new_weights))
        sess.run(meta_weights_updates, feed_dict=f_dict)

        # initialize learner parameters
        learner.set_weights(training_sample.initial_learner_weights)

        # initialize meta-learner (predict) states
        f_dict = dict(
            zip(meta_model.states_placeholder, training_sample.initial_states))
        sess.run(meta_model.init_predict_states_updates, feed_dict=f_dict)

        # train learner using same batches as in the sample (meta 'predict_model' is used here)
        for x, y in training_sample.learner_training_batches:
            learner.train_on_batch(x, y)

        # calculate new learner loss on validation set after training
        f_dict = dict(
            zip(meta_model.predict_model.learner_inputs, learner_valid_ins))
        new_loss = sess.run(fetches=[learner.total_loss], feed_dict=f_dict)[0]

        return new_loss

    grad_ind = 0
    for i, w in enumerate(trainable_meta_weights):
        # perturb a single meta-learner ('predict_model') weight by +/- epsilon and re-evaluate the loss
        if w.ndim == 2:
            for j in range(w.shape[0]):
                for k in range(w.shape[1]):
                    changed_meta_learner_weights = [
                        weight.copy() for weight in trainable_meta_weights
                    ]
                    changed_meta_learner_weights[i][j][k] += epsilon
                    loss1 = calculate_loss(changed_meta_learner_weights)
                    changed_meta_learner_weights[i][j][k] -= 2 * epsilon
                    loss2 = calculate_loss(changed_meta_learner_weights)
                    approximated_meta_grads[grad_ind] = (loss1 -
                                                         loss2) / (2 * epsilon)
                    grad_ind += 1
        elif w.ndim == 1:
            for j in range(w.shape[0]):
                changed_meta_learner_weights = [
                    weight.copy() for weight in trainable_meta_weights
                ]
                changed_meta_learner_weights[i][j] += epsilon
                loss1 = calculate_loss(changed_meta_learner_weights)
                changed_meta_learner_weights[i][j] -= 2 * epsilon
                loss2 = calculate_loss(changed_meta_learner_weights)
                approximated_meta_grads[grad_ind] = (loss1 - loss2) / (2 *
                                                                       epsilon)
                grad_ind += 1
        else:
            raise ValueError(
                "Only weights with ndim == 1 or ndim == 2 are supported in grad check"
            )

    approximated_grad_diff = np.linalg.norm(approximated_meta_grads - evaluated_meta_grads) / \
                             (np.linalg.norm(approximated_meta_grads) + np.linalg.norm(evaluated_meta_grads))

    if approximated_grad_diff > epsilon:
        logger.error("GRAD-CHECK: (epsilon={}, dist={})!".format(
            epsilon, approximated_grad_diff))
        return False
    else:
        logger.debug("Grad-Check passed. (epsilon={}, dist={})".format(
            epsilon, approximated_grad_diff))

    return True
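# Self-contained toy version of the check above, on a simple quadratic instead
# of the meta-learner: each parameter is nudged by +/- epsilon, the central
# difference of the loss approximates the gradient, and the same relative-error
# norm decides pass/fail. Everything here is illustrative, not project code.
import numpy as np


def toy_gradient_check(epsilon: float = 1e-7) -> bool:
    w = np.array([0.3, -1.2, 2.0])

    def loss(v):
        return float(np.sum(v ** 2))   # L(w) = sum(w_i^2)

    analytic = 2.0 * w                 # dL/dw = 2 * w

    numerical = np.zeros_like(w)
    for i in range(w.size):
        plus, minus = w.copy(), w.copy()
        plus[i] += epsilon
        minus[i] -= epsilon
        numerical[i] = (loss(plus) - loss(minus)) / (2 * epsilon)

    diff = np.linalg.norm(numerical - analytic) / (
        np.linalg.norm(numerical) + np.linalg.norm(analytic))
    return diff < epsilon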
    def meta_train(self, n_meta_epochs: int, meta_batch_size: int,
                   n_learner_batches: int, meta_early_stopping: int,
                   n_meta_train_steps: int, n_meta_valid_steps: int,
                   learner_batch_size: int):
        """
        Trains meta-learning model for a few epochs
        :param n_meta_epochs: number of meta-epochs
        :param meta_batch_size: size of meta-batch of Learners per one meta-optimizer weight update
        :param n_learner_batches: number of training batches
        :param meta_early_stopping: early stopping patience for Learner training
        :param n_meta_train_steps: number of meta-training batches per epoch
        :param n_meta_valid_steps: number of meta-validation batches per epoch
        :param learner_batch_size: batch size when training Learner
        """
        lr = self.initial_meta_lr

        self.logger.info("Starting with meta-learning rate: {}".format(lr))

        epochs_with_no_gain = 0
        epochs_with_no_gain_lr = 0

        self._compile(lr, self.starting_epoch, learner_batch_size)

        if self.starting_epoch == 0:
            n_params = get_trainable_params_count(
                self.meta_learner.train_model)

            self.logger.info(
                "Using Meta-Learner with {} parameters".format(n_params))
            self.meta_learner.train_model.summary()

            # save initial meta-learner weights
            self.meta_learner.train_model.save(
                os.path.join(self.meta_learner_weights_history_dir,
                             'meta_weights_epoch_0.h5'))

            # copy training configuration to log dir
            copyfile(
                os.path.join(os.environ['CONF_DIR'],
                             'training_configuration.yml'),
                os.path.join(os.environ['LOG_DIR'],
                             'training_configuration.yml'))
            self.logger.info("Validating Meta-Learner on start...")
            self.meta_validate_epoch(n_meta_valid_steps=n_meta_valid_steps,
                                     n_learner_batches=n_learner_batches,
                                     learner_batch_size=learner_batch_size,
                                     epoch_number=-1)

        for i in tqdm(range(n_meta_epochs - self.starting_epoch),
                      desc='Running Meta-Training'):
            epoch = self.starting_epoch + i
            epoch_start = time.time()

            # reset backend session each epoch to avoid memory leaks etc
            self._compile(lr, epoch, learner_batch_size)

            if self.best_loss is None:
                self.logger.info("Starting meta-epoch {:d}".format(epoch + 1))
            else:
                self.logger.info(
                    "Starting meta-epoch {:d} (best loss: {:.5f})".format(
                        epoch + 1, self.best_loss))

            if self.optimizer_weights is not None:
                self.meta_learner.train_model.optimizer.set_weights(
                    self.optimizer_weights)

            if epochs_with_no_gain_lr >= self.configuration.meta_lr_early_stopping:
                new_lr = lr / self.configuration.meta_lr_divisor

                self.logger.info(
                    "Changing meta-learning rate from {} to {} (meta-epoch {})"
                    .format(lr, new_lr, epoch + 1))

                lr = new_lr
                self.meta_learner.train_model.optimizer.update_lr(lr)
                epochs_with_no_gain_lr = 0

            self.meta_train_epoch(meta_batch_size=meta_batch_size,
                                  n_learner_batches=n_learner_batches,
                                  n_meta_train_steps=n_meta_train_steps,
                                  learner_batch_size=learner_batch_size,
                                  epoch_number=epoch)

            self.meta_learner.train_model.save(
                os.path.join(self.meta_learner_weights_history_dir,
                             'meta_weights_epoch_{}.h5'.format(epoch + 1)))

            valid_metrics = self.meta_validate_epoch(
                n_learner_batches=n_learner_batches,
                n_meta_valid_steps=n_meta_valid_steps,
                learner_batch_size=learner_batch_size,
                epoch_number=epoch)

            epoch_duration = round(time.time() - epoch_start, 2)
            self.logger.info("Duration of epoch {}: {} s".format(
                epoch + 1, epoch_duration))
            self.logger.info("*" * 50)

            loss_ind = self.learner.metrics_names.index('loss')
            if self.best_loss is None or valid_metrics[
                    loss_ind] < self.best_loss:
                self.best_loss = valid_metrics[loss_ind]
                epochs_with_no_gain = 0
                epochs_with_no_gain_lr = 0
                self.meta_learner.train_model.save(
                    self.best_meta_learner_weights_path)
            else:
                epochs_with_no_gain += 1
                epochs_with_no_gain_lr += 1

            with open(self.task_checkpoint_path, 'w') as f:
                f.write(str(epoch + 1))
                f.write('\n')
                f.write(str(self.best_loss))

            with open(self.lr_history_path, 'a') as f:
                f.write(str(lr))
                f.write('\n')

            if epochs_with_no_gain >= meta_early_stopping:
                self.logger.info(
                    "Early stopping after {} meta-epochs".format(epoch + 1))
                self.logger.info("*" * 30)
                break
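# Condensed sketch of the scheduling logic in meta_train above: one patience
# counter drives learning-rate decay, another drives early stopping, and both
# are reset whenever the validation loss improves. Parameter defaults here are
# illustrative only.
def meta_lr_schedule_sketch(valid_losses, lr=1e-3, lr_patience=2,
                            stop_patience=4, divisor=2.0):
    best, no_gain, no_gain_lr = None, 0, 0
    for loss in valid_losses:
        if no_gain_lr >= lr_patience:
            lr, no_gain_lr = lr / divisor, 0          # decay the meta-learning rate
        if best is None or loss < best:
            best, no_gain, no_gain_lr = loss, 0, 0    # improvement: reset both counters
        else:
            no_gain += 1
            no_gain_lr += 1
        if no_gain >= stop_patience:
            break                                     # early stopping
    return best, lr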