def lstm_meta_learner(learner: Model,
                      eigenvals_callback: TopKEigenvaluesBatched,
                      configuration: TrainingConfiguration) -> MetaLearnerModel:
    # initialize weights so that in the beginning the model resembles SGD:
    # the forget rate is close to 1 and the learning rate is set to a constant value
    lr_bias = inverse_sigmoid(configuration.initial_lr)
    f_bias = inverse_sigmoid(0.9999)

    common_layers = get_common_lstm_model_layers(configuration.hidden_state_size,
                                                 configuration.constant_lr_model,
                                                 lr_bias, f_bias, 0.001)
    meta_batch_size = get_trainable_params_count(learner)

    train_meta_learner = lstm_train_meta_learner(configuration, meta_batch_size,
                                                 common_layers)
    predict_meta_learner = lstm_predict_meta_learner(learner, eigenvals_callback,
                                                     configuration, meta_batch_size,
                                                     common_layers)

    return MetaLearnerModel(predict_meta_learner, train_meta_learner,
                            configuration.debug_mode)
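
# The helper 'inverse_sigmoid' used above is defined elsewhere in the project; a
# minimal sketch of what it is assumed to compute (the logit function), so that
# passing its result as a bias makes sigmoid(bias) equal to the desired initial
# learning rate / forget rate:
#
#   def inverse_sigmoid(y: float) -> float:
#       # logit: returns x such that 1 / (1 + exp(-x)) == y
#       return float(np.log(y / (1.0 - y)))
#
#   inverse_sigmoid(0.9999)  # ~9.21, so sigmoid(bias) starts near 1 (forget gate)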
def __init__(self, learner: Model, configuration: TrainingConfiguration,
             train_mode: bool, states_outputs: List[tf.Tensor],
             input_tensors: List[tf.Tensor],
             intermediate_outputs: List[tf.Tensor],
             inputs, outputs, name=None):
    Model.__init__(self, inputs, outputs, name)

    self.debug_mode = configuration.debug_mode
    self.learner = learner
    self.learner_grads = K.concatenate([
        K.flatten(g) for g in K.gradients(
            self.learner.total_loss,
            self.learner._collected_trainable_weights)
    ])
    self.learner_inputs = get_input_tensors(self.learner)
    self.intermediate_outputs = intermediate_outputs
    self.output_size = get_trainable_params_count(self.learner)
    self.hessian_eigenvals = configuration.hessian_eigenvalue_features

    # inspect how some 'random' parameters of the learner evolve during training
    n_inspect = 20
    self.inspect_parameters = np.linspace(1, self.output_size, endpoint=False,
                                          dtype=int, num=n_inspect)

    self.input_tensors = input_tensors
    self.states_outputs = states_outputs
    self.backprop_depth = configuration.backpropagation_depth
    self.train_mode = train_mode
    self.state_tensors = []

    # for BPTT (many-to-one) we need to store the values of the last inputs
    # together with the value of the last output; to save memory, fixed-shape
    # tensors are reused in a circular fashion, with the current index marked
    self.current_backprop_index = tf.Variable(0, dtype=tf.int32)
    self.last_inputs = None
    self.initial_states = []
    self.states_history = []
    self.inputs_history = []
    self.selected_inter_outputs = None
    self.intermediate_outputs_history = []
    self.learner_weights_history = []
    self.current_output = tf.Variable(tf.zeros(shape=(1, self.output_size)),
                                      name='current_meta_output')
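
# A minimal sketch (an assumption, not the actual update ops defined elsewhere) of
# how a fixed-shape history buffer can be written in a circular way for truncated
# BPTT; 'history_buffer' is a hypothetical tf.Variable of shape (backprop_depth, ...)
# holding the last 'backprop_depth' values:
#
#   write_index = tf.mod(self.current_backprop_index, self.backprop_depth)
#   store_op = tf.scatter_update(history_buffer, write_index, new_value)
#   advance_op = tf.assign_add(self.current_backprop_index, 1)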
def test_meta_learner(learner: Model,
                      initial_learning_rate: float = 0.05,
                      backpropagation_depth: int = 20) -> MetaLearnerModel:
    # initialize weights so that in the beginning the model resembles SGD:
    # the forget rate is close to 1 and the learning rate is set to a constant value
    lr_bias = inverse_sigmoid(initial_learning_rate)
    meta_batch_size = get_trainable_params_count(learner)

    train_meta_learner = test_train_meta_learner(backpropagation_depth,
                                                 meta_batch_size, lr_bias)
    predict_meta_learner = test_predict_meta_learner(learner, backpropagation_depth,
                                                     meta_batch_size)

    return MetaLearnerModel(predict_meta_learner, train_meta_learner)
def run_meta_learning(conf: TrainingConfiguration, x: np.ndarray, y: np.ndarray):
    meta_dataset_path = conf.meta_dataset_path
    logger = conf.logger

    if not os.path.isfile(meta_dataset_path):
        factory = MetaLearningDatasetFactory(
            x=x,
            y=y,
            meta_test_ratio=conf.meta_test_class_ratio,
            learner_train_size=conf.learner_train_size,
            learner_test_size=conf.learner_test_size,
            classes_per_learner_set=conf.classes_per_learner_set,
            n_train_sets=conf.n_train_sets,
            n_test_sets=conf.n_test_sets,
            logger=logger)
        meta_dataset = factory.get()

        logger.info("Saving generated dataset to {}".format(meta_dataset_path))
        meta_dataset.save(meta_dataset_path)
    else:
        logger.info("Loading previously generated dataset from {}".format(
            meta_dataset_path))
        meta_dataset = load_meta_dataset(meta_dataset_path, x)

    def learner_factory():
        return build_simple_cnn(cifar_input_shape, conf.classes_per_learner_set)

    def meta_learner_factory(learner: Model,
                             eigenvals_callback: TopKEigenvaluesBatched):
        return lstm_meta_learner(learner=learner,
                                 eigenvals_callback=eigenvals_callback,
                                 configuration=conf)

    # build a dummy learner just to display its summary
    dummy_learner = learner_factory()
    n_params = get_trainable_params_count(dummy_learner)
    dummy_learner.compile(loss='categorical_crossentropy',
                          optimizer=SGD(lr=0.0),
                          metrics=['accuracy'])
    logger.info("Using Learner with {} parameters".format(n_params))
    dummy_learner.summary()

    log_dir = os.environ['LOG_DIR']
    meta_learning_task = MetaLearningTask(
        configuration=conf,
        task_checkpoint_path=os.path.join(log_dir, 'checkpoint.txt'),
        meta_dataset=meta_dataset,
        learner_factory=learner_factory,
        meta_learner_factory=meta_learner_factory,
        training_history_path=os.path.join(log_dir, "meta_training_history.txt"),
        meta_learner_weights_path=os.path.join(log_dir, "meta_weights.h5"),
        meta_learner_weights_history_dir=os.path.join(log_dir,
                                                      "meta_weights_history"),
        best_meta_learner_weights_path=os.path.join(log_dir,
                                                    "meta_weights_best.h5"))

    meta_learning_task.meta_train(
        n_meta_epochs=conf.n_meta_epochs,
        meta_early_stopping=conf.meta_early_stopping,
        n_learner_batches=conf.n_learner_batches,
        meta_batch_size=conf.meta_batch_size,
        n_meta_train_steps=conf.n_train_sets // conf.meta_batch_size,
        n_meta_valid_steps=conf.n_meta_valid_steps,
        learner_batch_size=conf.learner_batch_size)
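
# A minimal invocation sketch ('load_configuration' and 'load_cifar_data' are
# hypothetical helper names used for illustration only; the actual entry point
# may differ):
#
#   conf = load_configuration(os.path.join(os.environ['CONF_DIR'],
#                                          'training_configuration.yml'))
#   x, y = load_cifar_data()
#   run_meta_learning(conf, x, y)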
def gradient_check(meta_model: MetaLearnerModel,
                   training_sample: MetaTrainingSample,
                   logger: Logger,
                   epsilon: float = 10e-7) -> bool:
    """
    Performs a gradient check on a single meta-training sample.
    Warning: this method is very slow for big models!

    :param meta_model: MetaLearnerModel
    :param training_sample: training sample to gradient-check
    :param logger: Logger instance
    :param epsilon: epsilon factor used in gradient checking
    :return: True if the gradient check passes, otherwise False
    """
    if training_sample.final_output is None:
        raise ValueError("For gradient check, 'final_output' must not be None")
    if training_sample.learner_training_batches is None:
        raise ValueError(
            "For gradient check, 'learner_training_batches' must not be None")
    if training_sample.learner_validation_batch is None:
        raise ValueError(
            "For gradient check, 'learner_validation_batch' must not be None")
    if training_sample.initial_learner_weights is None:
        raise ValueError(
            "For gradient check, 'initial_learner_weights' must not be None")

    state_tensors = meta_model.predict_model.state_tensors
    input_tensors = get_input_tensors(meta_model.train_model)
    learner = meta_model.predict_model.learner
    sess = K.get_session()

    # the first step is to evaluate the gradients of the meta-learner parameters
    # using our method; the 'train_model' version of the meta-learner is used here

    # initialize meta-learner (train) states
    assert len(state_tensors) == len(training_sample.initial_states)
    feed_dict = dict(
        zip(meta_model.states_placeholder, training_sample.initial_states))
    sess.run(meta_model.init_train_states_updates, feed_dict=feed_dict)

    # standardize input for the current meta-training sample
    inputs = standardize_predict_inputs(meta_model.train_model,
                                        training_sample.inputs)

    # compute gradients for the current meta-learner parameters and training sample
    feed_dict = dict(zip(input_tensors, inputs))
    feed_dict[meta_model.learner_grad_placeholder] = training_sample.learner_grads

    # our method of computing the meta-learner gradients - this is what is checked here
    evaluation = sess.run(fetches=meta_model.chained_grads, feed_dict=feed_dict)
    evaluated_meta_grads = np.concatenate(
        [grad.flatten() for grad in evaluation])

    # gradient check for each meta-learner weight;
    # the 'predict_model' version of the meta-learner (the one used for
    # training the Learner) is used for the numerical approximation
    n_meta_learner_params = get_trainable_params_count(meta_model.train_model)
    approximated_meta_grads = np.zeros(shape=n_meta_learner_params)

    valid_x, valid_y = training_sample.learner_validation_batch
    learner_valid_ins = standardize_train_inputs(learner, valid_x, valid_y)

    # tensors used for updating meta-learner weights
    trainable_meta_weights = sess.run(meta_model.predict_model.trainable_weights)
    meta_weights_placeholder = [
        tf.placeholder(shape=w.get_shape(), dtype=tf.float32)
        for w in meta_model.predict_model.trainable_weights
    ]
    meta_weights_updates = [
        tf.assign(w, new_w)
        for w, new_w in zip(meta_model.predict_model.trainable_weights,
                            meta_weights_placeholder)
    ]

    def calculate_loss(new_weights):
        # update weights of the meta-learner ('predict_model')
        f_dict = dict(zip(meta_weights_placeholder, new_weights))
        sess.run(meta_weights_updates, feed_dict=f_dict)

        # initialize learner parameters
        learner.set_weights(training_sample.initial_learner_weights)

        # initialize meta-learner (predict) states
        f_dict = dict(
            zip(meta_model.states_placeholder, training_sample.initial_states))
        sess.run(meta_model.init_predict_states_updates, feed_dict=f_dict)

        # train the learner using the same batches as in the sample
        # (the meta 'predict_model' is used here)
        for x, y in training_sample.learner_training_batches:
            learner.train_on_batch(x, y)

        # calculate the new learner loss on the validation set after training
        f_dict = dict(
            zip(meta_model.predict_model.learner_inputs, learner_valid_ins))
        new_loss = sess.run(fetches=[learner.total_loss], feed_dict=f_dict)[0]

        return new_loss

    grad_ind = 0
    for i, w in enumerate(trainable_meta_weights):
        # set meta-learner ('predict_model') params to new values,
        # where only one weight is changed by epsilon
        if w.ndim == 2:
            for j in range(w.shape[0]):
                for k in range(w.shape[1]):
                    changed_meta_learner_weights = [
                        weights.copy() for weights in trainable_meta_weights
                    ]
                    changed_meta_learner_weights[i][j][k] += epsilon
                    loss1 = calculate_loss(changed_meta_learner_weights)
                    changed_meta_learner_weights[i][j][k] -= 2 * epsilon
                    loss2 = calculate_loss(changed_meta_learner_weights)
                    approximated_meta_grads[grad_ind] = (loss1 - loss2) / (2 * epsilon)
                    grad_ind += 1
        elif w.ndim == 1:
            for j in range(w.shape[0]):
                changed_meta_learner_weights = [
                    weights.copy() for weights in trainable_meta_weights
                ]
                changed_meta_learner_weights[i][j] += epsilon
                loss1 = calculate_loss(changed_meta_learner_weights)
                changed_meta_learner_weights[i][j] -= 2 * epsilon
                loss2 = calculate_loss(changed_meta_learner_weights)
                approximated_meta_grads[grad_ind] = (loss1 - loss2) / (2 * epsilon)
                grad_ind += 1
        else:
            raise ValueError(
                "Only weights with ndim == 1 or ndim == 2 are supported in grad check")

    approximated_grad_diff = np.linalg.norm(approximated_meta_grads - evaluated_meta_grads) / \
        (np.linalg.norm(approximated_meta_grads) + np.linalg.norm(evaluated_meta_grads))

    if approximated_grad_diff > epsilon:
        logger.error("Grad-check failed (epsilon={}, dist={})!".format(
            epsilon, approximated_grad_diff))
        return False
    else:
        logger.debug("Grad-check passed (epsilon={}, dist={}).".format(
            epsilon, approximated_grad_diff))

    return True
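
# The check above relies on the standard centered-difference approximation
#   dL/dw ~= (L(w + eps) - L(w - eps)) / (2 * eps)
# and compares analytic vs. numeric gradients via the relative distance
#   ||g_num - g_an|| / (||g_num|| + ||g_an||).
# A tiny self-contained illustration on the toy function f(w) = sum(w**2)
# (toy names and values only, unrelated to the meta-learner graph):
#
#   w = np.array([1.0, -2.0, 3.0])
#   analytic = 2 * w
#   eps = 1e-6
#   numeric = np.array([
#       (np.sum((w + eps * np.eye(3)[i]) ** 2) -
#        np.sum((w - eps * np.eye(3)[i]) ** 2)) / (2 * eps)
#       for i in range(3)])
#   rel_dist = np.linalg.norm(numeric - analytic) / \
#       (np.linalg.norm(numeric) + np.linalg.norm(analytic))
#   assert rel_dist < 1e-6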
def meta_train(self, n_meta_epochs: int, meta_batch_size: int,
               n_learner_batches: int, meta_early_stopping: int,
               n_meta_train_steps: int, n_meta_valid_steps: int,
               learner_batch_size: int):
    """
    Trains the meta-learning model for a number of meta-epochs.

    :param n_meta_epochs: number of meta-epochs
    :param meta_batch_size: size of the meta-batch of Learners per meta-optimizer weight update
    :param n_learner_batches: number of batches used to train each Learner
    :param meta_early_stopping: early stopping patience (in meta-epochs) for meta-training
    :param n_meta_train_steps: number of meta-training batches per epoch
    :param n_meta_valid_steps: number of meta-validation batches per epoch
    :param learner_batch_size: batch size used when training a Learner
    """
    lr = self.initial_meta_lr
    self.logger.info("Starting with meta-learning rate: {}".format(lr))

    epochs_with_no_gain = 0
    epochs_with_no_gain_lr = 0

    self._compile(lr, self.starting_epoch, learner_batch_size)

    if self.starting_epoch == 0:
        n_params = get_trainable_params_count(self.meta_learner.train_model)
        self.logger.info(
            "Using Meta-Learner with {} parameters".format(n_params))
        self.meta_learner.train_model.summary()

        # save initial meta-learner weights
        self.meta_learner.train_model.save(
            os.path.join(self.meta_learner_weights_history_dir,
                         'meta_weights_epoch_0.h5'))

        # copy the training configuration to the log dir
        copyfile(
            os.path.join(os.environ['CONF_DIR'], 'training_configuration.yml'),
            os.path.join(os.environ['LOG_DIR'], 'training_configuration.yml'))

        self.logger.info("Validating Meta-Learner on start...")
        self.meta_validate_epoch(n_meta_valid_steps=n_meta_valid_steps,
                                 n_learner_batches=n_learner_batches,
                                 learner_batch_size=learner_batch_size,
                                 epoch_number=-1)

    for i in tqdm(range(n_meta_epochs - self.starting_epoch),
                  desc='Running Meta-Training'):
        epoch = self.starting_epoch + i
        epoch_start = time.time()

        # reset the backend session each epoch to avoid memory leaks etc.
        self._compile(lr, epoch, learner_batch_size)

        if self.best_loss is None:
            self.logger.info("Starting meta-epoch {:d}".format(epoch + 1))
        else:
            self.logger.info(
                "Starting meta-epoch {:d} (best loss: {:.5f})".format(
                    epoch + 1, self.best_loss))

        if self.optimizer_weights is not None:
            self.meta_learner.train_model.optimizer.set_weights(
                self.optimizer_weights)

        if epochs_with_no_gain_lr >= self.configuration.meta_lr_early_stopping:
            new_lr = lr / self.configuration.meta_lr_divisor
            self.logger.info(
                "Changing meta-learning rate from {} to {} (meta-epoch {})".format(
                    lr, new_lr, epoch + 1))
            lr = new_lr
            self.meta_learner.train_model.optimizer.update_lr(lr)
            epochs_with_no_gain_lr = 0

        self.meta_train_epoch(meta_batch_size=meta_batch_size,
                              n_learner_batches=n_learner_batches,
                              n_meta_train_steps=n_meta_train_steps,
                              learner_batch_size=learner_batch_size,
                              epoch_number=epoch)

        self.meta_learner.train_model.save(
            os.path.join(self.meta_learner_weights_history_dir,
                         'meta_weights_epoch_{}.h5'.format(epoch + 1)))

        valid_metrics = self.meta_validate_epoch(
            n_learner_batches=n_learner_batches,
            n_meta_valid_steps=n_meta_valid_steps,
            learner_batch_size=learner_batch_size,
            epoch_number=epoch)

        epoch_duration = round(time.time() - epoch_start, 2)
        self.logger.info("Duration of epoch {}: {} s".format(
            epoch + 1, epoch_duration))
        self.logger.info("*" * 50)

        loss_ind = self.learner.metrics_names.index('loss')
        if self.best_loss is None or valid_metrics[loss_ind] < self.best_loss:
            self.best_loss = valid_metrics[loss_ind]
            epochs_with_no_gain = 0
            epochs_with_no_gain_lr = 0
            self.meta_learner.train_model.save(self.best_meta_learner_weights_path)
        else:
            epochs_with_no_gain += 1
            epochs_with_no_gain_lr += 1

        with open(self.task_checkpoint_path, 'w') as f:
            f.write(str(epoch + 1))
            f.write('\n')
            f.write(str(self.best_loss))

        with open(self.lr_history_path, 'a') as f:
            f.write(str(lr))
            f.write('\n')

        if epochs_with_no_gain >= meta_early_stopping:
            self.logger.info(
                "Early stopping after {} meta-epochs".format(epoch + 1))
            self.logger.info("*" * 30)
            break