def test_train_f(self):
    """
    The true dynamics are x -> x + 1.
    Parametrize the model f with a single parameter theta, encoding the
    function x -> x + theta, and check that training reduces the loss.
    """
    tf.reset_default_graph()
    x_shape = (1, )
    z_shape = x_shape
    n_timesteps = 5

    # f adds theta
    z_1 = Input(shape=z_shape)
    with tf.variable_scope('test_train_f'):
        theta = tf.get_variable('theta_2', shape=(), dtype=tf.float32,
                                initializer=tf.initializers.constant(0.0))
    f_layer = Lambda(
        lambda x: x + theta,
        trainable=True,
    )
    # Lambda layers do not track externally created variables, so theta has
    # to be registered as a trainable weight manually.
    f_layer.trainable_weights = [theta]
    f = Model(z_1, f_layer(z_1), name='f')
    print('f.trainable_weights', f.trainable_weights)

    learning_rate = 0.01
    train_step_counter = make_count_variable('train_step_counter', 0)
    predictor = ASIModel(
        x_shape,
        f=f,
        delta_t_bounds=(1, 1),
        exploration_schedule=annealing_schedules.constant_zero,
        schd_sampling_schedule=annealing_schedules.constant_zero,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        z_loss_fn=lambda a, b: tf.square(a - b),
        f_optimizer=tf.train.GradientDescentOptimizer(learning_rate),
        forgiving=True)

    # Observed trajectory 0, 1, ..., n_timesteps - 1 with a singleton batch
    # dimension: shape (1, n_timesteps, 1).
    x_observed = (np.zeros(
        (n_timesteps, ) + x_shape, dtype=np.float32) + np.arange(
            0, n_timesteps).reshape((1, -1) + tuple([1] * len(x_shape))))
    x_1 = x_observed[:, 0]
    batch_size = x_observed.shape[0]

    init_vars = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_vars)
        # With theta == 0, f is the identity, so every prediction stays at 0.
        _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)
        print('_z_hat:', _z_hat)
        desired_z_hat = np.zeros((batch_size, n_timesteps, 1))
        assert _z_hat.tolist() == desired_z_hat.tolist()

        predictor.train_on_trajectories(sess, x_observed,
                                        trajectory_lengths=[n_timesteps])
        # Targets are 0..4 while every prediction stays at 0, so the mean
        # squared error over the four predicted steps is
        # (1 + 4 + 9 + 16) / 4 = 7.5.
        assert predictor.train_analytics['z_loss'] == 7.5
        print('predictor.train_analytics', predictor.train_analytics)
        new_theta = sess.run(theta)
        print('new_theta', new_theta)
        # One gradient descent step: with z_hat_t = t * theta, the gradient
        # at theta = 0 is -2 * (1 + 4 + 9 + 16) / 4 = -15, so
        # theta <- 0 + 0.01 * 15 = 0.15.
        assert np.isclose(new_theta, 0.15)
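
# ---------------------------------------------------------------------------
# The counter helpers used above and in `run` below (make_count_variable,
# make_count_incrementer, make_count_resetter) are defined elsewhere in the
# repo. A minimal sketch of what they could look like, assuming they wrap
# non-trainable integer tf.Variables so that the counters are checkpointed
# and restored together with the model; the actual implementation may differ.
# ---------------------------------------------------------------------------

def make_count_variable(name, initial_value=0):
    # Non-trainable, so optimizers never touch it; as a graph variable it is
    # saved and restored by tf.train.Saver along with the model weights.
    return tf.get_variable(
        name, shape=(), dtype=tf.int32, trainable=False,
        initializer=tf.constant_initializer(initial_value))


def make_count_incrementer(count_variable):
    # Op that adds one to the counter each time it is run.
    return tf.assign_add(count_variable, 1)


def make_count_resetter(count_variable):
    # Op that sets the counter back to zero when run.
    return tf.assign(count_variable, 0)
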
def run(self,
        train_iterator_init_op,
        valid_iterator_init_op,
        predictor: ASIModel,
        n_epochs_max,
        task_specific_metrics,
        progress_filepath,
        step_data_filepath,
        callbacks: List[TrainCallback] = None,
        saver_max_to_keep=5,
        model_path=None):
    """
    :param train_iterator_init_op: tf operation that initializes the
        training iterator
    :param valid_iterator_init_op: tf operation that initializes the
        validation iterator
    :param predictor: Instance of ASIModel
    :param n_epochs_max: Maximum number of epochs to train for
    :param task_specific_metrics: Task-specific metrics, passed through to
        validation_step
    :param progress_filepath: Path to the progress file
    :param step_data_filepath: Path prefix for the per-step validation data
        files written every fourth epoch
    :param callbacks: TrainCallback objects whose methods will be invoked at
        the appropriate times
    :param saver_max_to_keep: max_to_keep argument for tf.train.Saver
    :param model_path: Path from which to restore the model and where to
        save checkpoints
    :return: The tf.Session used for training
    """
    self.model_path = model_path
    if callbacks is None:
        callbacks = []

    global_step = make_count_variable('global_step')
    i_epoch = make_count_variable('i_epoch')
    i_batch = make_count_variable('i_batch')
    increment_global_step = make_count_incrementer(global_step)
    increment_epoch = make_count_incrementer(i_epoch)
    increment_batch = make_count_incrementer(i_batch)
    reset_batch_counter = make_count_resetter(i_batch)

    init_var_op = tf.global_variables_initializer()
    self.saver = tf.train.Saver(max_to_keep=saver_max_to_keep)
    try:
        self.saver.restore(self.sess, model_path)
        self.logger.info('Restored model from path {}'.format(model_path))
        self.logger.info('Resuming training from '
                         'i_epoch={}, '
                         'i_batch={}'.format(
                             self.sess.run(i_epoch),
                             self.sess.run(i_batch),
                         ))
    except tf.errors.NotFoundError:
        self.logger.info('No model to restore at path {}'.format(model_path))
        self.sess.run(init_var_op)
        self.logger.info('Initialized fresh model.')

    with self.sess.as_default():
        # Freeze the graph so that no ops are accidentally added inside the
        # training loop.
        self.sess.graph.finalize()
        try:
            while i_epoch.eval() < n_epochs_max:
                self.logger.info('===============================================')
                self.logger.info('Train i_epoch={}'.format(i_epoch.eval()))
                self.logger.info('===============================================')
                training_durations = []
                all_train_metrics = []
                self.sess.run(train_iterator_init_op)

                # Skip batches if i_batch > 0 (i.e. we are resuming from a
                # checkpoint taken in the middle of an epoch).
                n_skip = i_batch.eval()
                for i_skip in range(n_skip):
                    self.sess.run([predictor.trajectory_lengths])
                if n_skip > 0:
                    self.logger.info('Done skipping to batch {}. '
                                     'Resuming training now.'.format(n_skip))

                while True:
                    try:
                        tic = time.time()
                        predictor.train_on_trajectories(self.sess)
                        toc = time.time()
                        # self.logger.info('Batch duration: {:.4f}'.format(toc - tic))
                        all_train_metrics.append(predictor.train_analytics)
                        training_durations.append(toc - tic)
                        self.sess.run(increment_global_step)
                        self.sess.run(increment_batch)
                        for callback in callbacks:
                            callback.on_batch_end(trainer=self,
                                                  predictor=predictor,
                                                  i_batch=i_batch.eval())
                    except tf.errors.OutOfRangeError:
                        # The train iterator is exhausted: the epoch is over.
                        self.logger.info('Epoch finished. '
                                         'i_batch={}'.format(i_batch.eval()))
                        break
                self.sess.run(reset_batch_counter)

                i_finished_epoch = i_epoch.eval()
                for callback in callbacks:
                    callback.on_epoch_end(trainer=self,
                                          predictor=predictor,
                                          i_epoch=i_finished_epoch)
                self.sess.run(increment_epoch)
                self.save_model()

                self.logger.info('Evaluating accuracy on validation data ...')
                batch_sizes = []
                all_valid_metrics = []
                all_valid_jump_timesteps = []
                all_valid_z_step_losses = []

                # Cap the prediction depth for validation at the 99th
                # percentile of the effective prediction lengths observed
                # during this training epoch.
                all_effective_z_hat_lengths = [m['effective_z_hat_lengths']
                                               for m in all_train_metrics]
                new_max_depth = int(round(
                    np.percentile(all_effective_z_hat_lengths, 99.))) + 1
                self.logger.info('new_max_depth: {}'.format(new_max_depth))
                all_valid_metrics.append(
                    {'prediction_max_depth': new_max_depth})
                predictor.prediction_max_depth = new_max_depth

                self.sess.run(valid_iterator_init_op)
                while True:
                    try:
                        validation_step(predictor,
                                        batch_sizes,
                                        all_valid_metrics,
                                        all_valid_jump_timesteps,
                                        all_valid_z_step_losses,
                                        task_specific_metrics,
                                        self.sess)
                        if len(all_valid_metrics) % 100 == 0:
                            self.logger.info(
                                'validation_step {} completed'.format(
                                    len(all_valid_metrics)))
                    except tf.errors.OutOfRangeError:
                        # The valid iterator is exhausted: the pass over the
                        # tfrecord is complete.
                        break
                self.logger.info('Done evaluating accuracy on validation data.')

                all_train_metrics.append(
                    {'batch_training_time': float(np.mean(training_durations))})

                if i_epoch.eval() % 4 == 0:
                    write_step_data(all_valid_jump_timesteps,
                                    all_valid_z_step_losses,
                                    (step_data_filepath
                                     + '_ep_{}.txt'.format(i_epoch.eval())))
                write_metrics(global_step.eval(), i_finished_epoch,
                              all_train_metrics, all_valid_metrics,
                              training_durations, progress_filepath)

                for callback in callbacks:
                    callback.on_validation_end(trainer=self,
                                               predictor=predictor,
                                               train_metrics=all_train_metrics,
                                               valid_metrics=all_valid_metrics,
                                               i_epoch=i_finished_epoch)
        except KeyboardInterrupt:
            self.logger.info('Training interrupted')
            self.save_model()
            for callback in callbacks:
                callback.on_keyboard_interrupt(self)

    for callback in callbacks:
        callback.on_training_end(self)
    return self.sess
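
# ---------------------------------------------------------------------------
# For reference, the callback hooks that `run` invokes. This is a no-op
# sketch of the TrainCallback interface, reconstructed from the call sites
# above; the real class lives elsewhere in the repo and may define further
# hooks or default behavior.
# ---------------------------------------------------------------------------

class TrainCallback:

    def on_batch_end(self, trainer, predictor, i_batch):
        # Called after every training batch.
        pass

    def on_epoch_end(self, trainer, predictor, i_epoch):
        # Called once per epoch, before the epoch counter is incremented and
        # the model is saved.
        pass

    def on_validation_end(self, trainer, predictor, train_metrics,
                          valid_metrics, i_epoch):
        # Called after each validation pass, with the metrics collected
        # during training and validation.
        pass

    def on_keyboard_interrupt(self, trainer):
        # Called when training is stopped with Ctrl-C, after the final save.
        pass

    def on_training_end(self, trainer):
        # Called once when `run` finishes, whether or not it was interrupted.
        pass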