def test_predict_no_errors(self):
    """
    Predictor with delta_t_bounds (1, 1) - should be equivalent to the
    version without dynamic time matching.

    Assertions ensure that the behavior stays the same (the expected values
    are just the outputs of the predictor from 2018-03-22).
    """
    tf.reset_default_graph()
    rng = np.random.RandomState(1234)
    x_shape = (5, 3)
    x_1 = rng.uniform(0, 1, (14, ) + x_shape)
    train_step_counter = make_count_variable('train_step_counter', 0)
    predictor = ASIModel(
        x_shape,
        f_optimizer=tf.train.AdamOptimizer(),
        f=make_example_f(list(x_shape)),
        delta_t_bounds=(1, 1),
        exploration_schedule=None,  # unused
        schd_sampling_schedule=None,  # unused
        parallel_iterations=10,
        train_step_counter=train_step_counter,
    )
    init_vars = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_vars)
        z_hat = predictor.predict_n_steps(x_1, n=4, sess=sess)
        print('z_hat.shape', z_hat.shape)
        assert z_hat.shape == (14, 4, 5, 3)
        print('z_hat.mean()', z_hat.mean())
        assert np.isclose(z_hat.mean(), 0.4999843)
def validation_step(predictor: ASIModel, batch_sizes, all_valid_metrics,
                    all_valid_jump_timesteps, all_valid_z_step_losses,
                    task_specific_metrics, sess):
    (frames_valid, trajectory_lengths) = sess.run(
        [predictor.x, predictor.trajectory_lengths])
    batch_sizes.append(len(frames_valid))
    valid_metrics = predictor.analyze_batch(
        sess,
        x_padded=frames_valid,
        trajectory_lengths=trajectory_lengths,
        fetch_jump_steps=True,
        fetch_z_step_losses=True,
        fetch_z_hat=True,
    )
    all_valid_jump_timesteps.append(valid_metrics['jump_timesteps'])
    all_valid_z_step_losses.append(valid_metrics['z_step_losses'])

    # ### Get predicted batch ###
    predicted_batch = []
    for trajectory in frames_valid:
        predicted_frames = predictor.predict_n_steps(
            trajectory[0][np.newaxis],
            predictor.prediction_max_depth,
            sess)[0]
        predicted_batch.append(predicted_frames)
    predicted_batch = np.asarray(predicted_batch)
    # ###########################

    for key, metric_fn in task_specific_metrics.items():
        valid_metrics[key + '_matched'] = metric_fn(
            frames_valid, valid_metrics['z_hat'], trajectory_lengths)
        valid_metrics[key] = metric_fn(
            frames_valid, predicted_batch, trajectory_lengths)

    # Clean up unwanted output from the metrics
    del valid_metrics['z_hat']
    del valid_metrics['jump_timesteps']
    del valid_metrics['z_step_losses']
    all_valid_metrics.append(valid_metrics)
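
# A minimal sketch of a task-specific metric usable in the
# task_specific_metrics dict above: each entry maps a metric name to a
# function of (ground-truth frames, predicted frames, trajectory lengths).
# The function below is a hypothetical example, not part of the codebase; it
# assumes predictions correspond to frames 1..n of each trajectory and masks
# out padded timesteps before averaging.
def mse_over_valid_steps(frames, predicted, trajectory_lengths):
    errors = []
    for truth, pred, length in zip(frames, predicted, trajectory_lengths):
        # Compare only real (non-padded) frames that have a prediction.
        n = min(length - 1, len(pred))
        if n > 0:
            errors.append(np.mean((truth[1:1 + n] - pred[:n]) ** 2))
    return float(np.mean(errors))

# Usage: task_specific_metrics = {'mse': mse_over_valid_steps}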
def make_iterative_predictor(cmd_args, x_shape, x, trajectory_lengths):
    train_step_counter = make_count_variable('train_step_counter',
                                             init_value=0)
    f_learning_rate = tf.train.exponential_decay(
        cmd_args.f_init_learning_rate,
        global_step=train_step_counter,
        decay_steps=cmd_args.f_learning_rate_decay_steps,
        decay_rate=cmd_args.f_learning_rate_decay_rate,
        staircase=True)

    if cmd_args.optimizer == 'adam':
        f_optimizer = tf.train.AdamOptimizer(learning_rate=f_learning_rate)
    elif cmd_args.optimizer == 'sgd10':
        f_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=10 * f_learning_rate)
    else:
        raise ValueError('Unknown optimizer: {}'.format(cmd_args.optimizer))

    f = get_f(cmd_args.data_format, cmd_args.f_architecture, x_shape,
              cmd_args)

    if cmd_args.z_loss_fn == 'log_loss':
        z_loss_fn = asi.helpers.z_loss_fns.log_loss
    else:
        raise ValueError('Unknown z_loss_fn {}'.format(cmd_args.z_loss_fn))

    if cmd_args.data_format == 'channels_first':
        input_shape = (x_shape[2], x_shape[0], x_shape[1])
    elif cmd_args.data_format == 'channels_last':
        input_shape = x_shape
    else:
        raise ValueError('Illegal data format')

    additional_metrics = {
        'f_learning_rate': f_learning_rate,
    }

    schd_sampling_schedule = annealing_schedules.get_reciprocal(
        decay=1. / (0.2 * cmd_args.schd_sampling_steps))
    exploration_schedule = annealing_schedules.get_linear(
        cmd_args.exploration_steps, 1.0, 0)

    latent_predictor = ASIModel(
        x_shape=input_shape,
        f_optimizer=f_optimizer,
        f=f,
        delta_t_bounds=(cmd_args.delta_t_lower_bound,
                        cmd_args.delta_t_upper_bound),
        exploration_schedule=exploration_schedule,
        schd_sampling_schedule=schd_sampling_schedule,
        additional_metrics=additional_metrics,
        x=x,
        trajectory_lengths=trajectory_lengths,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        z_loss_fn=z_loss_fn,
    )
    return latent_predictor
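
# For reference: with staircase=True, tf.train.exponential_decay as used
# above computes lr = init_lr * decay_rate ** floor(global_step / decay_steps),
# so the learning rate drops in discrete steps rather than continuously.
# A plain-Python sketch of that curve (the function name is local to this
# sketch, not part of the codebase):
def staircase_decay(init_lr, global_step, decay_steps, decay_rate):
    # Integer division reproduces the floor() of the staircase schedule.
    return init_lr * decay_rate ** (global_step // decay_steps)

# e.g. with init_lr=0.1, decay_steps=1000, decay_rate=0.5:
#   staircase_decay(0.1, 999, 1000, 0.5)  -> 0.1
#   staircase_decay(0.1, 1000, 1000, 0.5) -> 0.05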
def test_predict_simple_models(self):
    tf.reset_default_graph()
    x_shape = (2, 3)
    z_shape = x_shape
    n_timesteps = 4

    # f adds one
    z_1 = Input(shape=z_shape)
    f = Model(z_1, Lambda(lambda x: x + 1)(z_1), name='f')

    train_step_counter = make_count_variable('train_step_counter', 0)
    predictor = ASIModel(
        x_shape,
        f_optimizer=tf.train.AdamOptimizer(),
        f=f,
        delta_t_bounds=(1, 1),
        exploration_schedule=annealing_schedules.constant_zero,
        schd_sampling_schedule=annealing_schedules.constant_zero,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        forgiving=True)

    x_bias = 10.0
    x_1 = np.stack((np.ones(x_shape, dtype=np.float32),
                    x_bias + np.ones(x_shape, dtype=np.float32)),
                   axis=0)
    print('x_1.shape', x_1.shape)

    init_vars = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_vars)
        _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)

    # Starting from all-ones inputs, repeatedly adding one yields
    # 2, 3, 4, 5 (plus x_bias for the second trajectory).
    desired_z_hat_a = (np.arange(2, 2 + n_timesteps)[:, None, None] +
                       np.zeros((n_timesteps, ) + x_shape))
    desired_z_hat_b = (x_bias +
                       np.arange(2, 2 + n_timesteps)[:, None, None] +
                       np.zeros((n_timesteps, ) + x_shape))
    assert _z_hat.tolist() == [
        desired_z_hat_a.tolist(),
        desired_z_hat_b.tolist()
    ]
def test_train_multiple_trajectories_no_errors(self):
    tf.reset_default_graph()
    rng = np.random.RandomState(1234)
    x_shape = (2, 3)
    trajectory_lengths = [4, 3, 5, 2]
    x_list = [
        rng.uniform(0, 1, (trajectory_length, ) + x_shape)
        for trajectory_length in trajectory_lengths
    ]
    train_step_counter = make_count_variable('train_step_counter', 0)
    predictor = ASIModel(
        x_shape,
        f_optimizer=tf.train.AdamOptimizer(),
        f=make_example_f(list(x_shape)),
        delta_t_bounds=(1, 1),
        exploration_schedule=annealing_schedules.constant_zero,
        schd_sampling_schedule=annealing_schedules.constant_zero,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        z_loss_fn=z_loss_fns.symmetric_log_loss)

    init_vars = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_vars)
        predictor.train_on_trajectory_list(sess, x_list)
        _loss = get_batch_loss(predictor)
        print('_loss 1', _loss)
        assert np.isclose(_loss, 1.7272503)

        predictor.train_on_trajectory_list(sess, x_list)
        _loss = get_batch_loss(predictor)
        print('_loss 2', _loss)
        assert np.isclose(_loss, 1.7265959)
def test_train_f(self):
    """
    The real dynamics is: x -> x + 1.
    Parametrize model f with one parameter theta, to encode the function
    x -> x + theta.
    Check that the loss can be reduced.
    """
    tf.reset_default_graph()
    x_shape = (1, )
    z_shape = x_shape
    n_timesteps = 5

    # f adds theta
    z_1 = Input(shape=z_shape)
    with tf.variable_scope('test_train_f'):
        theta = tf.get_variable('theta_2',
                                shape=(),
                                dtype=tf.float32,
                                initializer=tf.initializers.constant(0.0))
    f_layer = Lambda(
        lambda x: x + theta,
        trainable=True,
    )
    # Register theta as a trainable weight of the Lambda layer.
    f_layer.trainable_weights = [theta]
    f = Model(z_1, f_layer(z_1), name='f')
    print('f.trainable_weights', f.trainable_weights)

    learning_rate = 0.01
    train_step_counter = make_count_variable('train_step_counter', 0)
    predictor = ASIModel(
        x_shape,
        f=f,
        delta_t_bounds=(1, 1),
        exploration_schedule=annealing_schedules.constant_zero,
        schd_sampling_schedule=annealing_schedules.constant_zero,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        z_loss_fn=lambda a, b: tf.square(a - b),
        f_optimizer=tf.train.GradientDescentOptimizer(learning_rate),
        forgiving=True)

    x_observed = (np.zeros((n_timesteps, ) + x_shape, dtype=np.float32) +
                  np.arange(0, n_timesteps).reshape(
                      (1, -1) + tuple([1] * len(x_shape))))
    x_1 = x_observed[:, 0]
    batch_size = x_observed.shape[0]

    init_vars = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_vars)
        _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)
        print('_z_hat:', _z_hat)
        desired_z_hat = np.zeros((batch_size, n_timesteps, 1))
        assert _z_hat.tolist() == desired_z_hat.tolist()

        predictor.train_on_trajectories(sess, x_observed,
                                        trajectory_lengths=[n_timesteps])
        assert predictor.train_analytics['z_loss'] == 7.5
        print('predictor.train_analytics', predictor.train_analytics)
        new_theta = sess.run(theta)
        print('new_theta', new_theta)
        assert np.isclose(new_theta, 0.15)
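
# Worked check of the numbers asserted in test_train_f above, a sketch under
# the assumption that the predictor unrolls from x_1 without teacher forcing,
# so with theta = 0 every prediction stays 0 while the targets are 1, 2, 3, 4:
#   z_loss    = mean_t (t*theta - t)^2 = (1 + 4 + 9 + 16) / 4       = 7.5
#   dL/dtheta = mean_t 2*t*(t*theta - t) at theta=0 = -2 * 30 / 4   = -15
#   SGD step:   theta <- 0 - 0.01 * (-15)                           = 0.15
def _expected_test_train_f_numbers(learning_rate=0.01, n_pred_steps=4):
    # Helper local to this sketch, not part of the test suite.
    sum_sq = sum(t ** 2 for t in range(1, n_pred_steps + 1))  # 30
    z_loss = sum_sq / n_pred_steps                            # 7.5
    grad = -2.0 * sum_sq / n_pred_steps                       # -15.0
    new_theta = 0.0 - learning_rate * grad                    # 0.15
    return z_loss, new_theta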
def run(self, train_iterator_init_op, valid_iterator_init_op,
        predictor: ASIModel, n_epochs_max, task_specific_metrics,
        progress_filepath, step_data_filepath,
        callbacks: List[TrainCallback] = None,
        saver_max_to_keep=5, model_path=None):
    """
    :param train_iterator_init_op: tf-operation which initializes the
        train-iterator
    :param valid_iterator_init_op: tf-operation which initializes the
        valid-iterator
    :param predictor: Instance of ASIModel
    :param n_epochs_max: Maximum number of epochs to train
    :param task_specific_metrics: Dict mapping metric names to functions of
        (frames, predictions, trajectory_lengths)
    :param progress_filepath: Path to progress file
    :param step_data_filepath: Path prefix for per-epoch step-data files
    :param callbacks: TrainCallback objects whose methods will be invoked
        at the appropriate times
    :param saver_max_to_keep: max_to_keep argument for tf.train.Saver
    :param model_path: Path from where to restore the model and where to
        save models
    :return: The tf.Session used for training
    """
    self.model_path = model_path
    if callbacks is None:
        callbacks = []

    global_step = make_count_variable('global_step')
    i_epoch = make_count_variable('i_epoch')
    i_batch = make_count_variable('i_batch')
    increment_global_step = make_count_incrementer(global_step)
    increment_epoch = make_count_incrementer(i_epoch)
    increment_batch = make_count_incrementer(i_batch)
    reset_batch_counter = make_count_resetter(i_batch)

    init_var_op = tf.global_variables_initializer()
    self.saver = tf.train.Saver(max_to_keep=saver_max_to_keep)
    try:
        self.saver.restore(self.sess, model_path)
        self.logger.info('Restored model from path {}'.format(model_path))
        self.logger.info('Resuming training from '
                         'i_epoch={}, '
                         'i_batch={}'.format(
                             self.sess.run(i_epoch),
                             self.sess.run(i_batch),
                         ))
    except tf.errors.NotFoundError:
        self.logger.info('No model to restore at path {}'.format(model_path))
        self.sess.run(init_var_op)
        self.logger.info('Initialized fresh model.')

    with self.sess.as_default():
        self.sess.graph.finalize()
        try:
            while i_epoch.eval() < n_epochs_max:
                self.logger.info('===============================================')
                self.logger.info('Train i_epoch={}'.format(i_epoch.eval()))
                self.logger.info('===============================================')
                training_durations = []
                all_train_metrics = []
                self.sess.run(train_iterator_init_op)

                # Skip batches if i_batch > 0 (i.e. we are resuming from a
                # checkpoint)
                n_skip = i_batch.eval()
                for i_skip in range(n_skip):
                    self.sess.run([predictor.trajectory_lengths])
                if n_skip > 0:
                    self.logger.info('Done skipping to batch {}. '
                                     'Resuming training now.'.format(n_skip))

                while True:
                    try:
                        tic = time.time()
                        predictor.train_on_trajectories(self.sess)
                        toc = time.time()
                        # self.logger.info('Batch duration: {:.4f}'.format(toc - tic))
                        all_train_metrics.append(predictor.train_analytics)
                        training_durations.append(toc - tic)
                        self.sess.run(increment_global_step)
                        self.sess.run(increment_batch)
                        for callback in callbacks:
                            callback.on_batch_end(trainer=self,
                                                  predictor=predictor,
                                                  i_batch=i_batch.eval())
                    except tf.errors.OutOfRangeError:
                        # Epoch over
                        self.logger.info('Epoch finished. '
                                         'i_batch={}'.format(i_batch.eval()))
                        break

                self.sess.run(reset_batch_counter)
                i_finished_epoch = i_epoch.eval()
                for callback in callbacks:
                    callback.on_epoch_end(trainer=self,
                                          predictor=predictor,
                                          i_epoch=i_finished_epoch)
                self.sess.run(increment_epoch)
                self.save_model()

                self.logger.info('Evaluating accuracy on validation data ...')
                batch_sizes = []
                all_valid_metrics = []
                all_valid_jump_timesteps = []
                all_valid_z_step_losses = []

                all_effective_z_hat_lengths = [
                    m['effective_z_hat_lengths'] for m in all_train_metrics
                ]
                new_max_depth = int(round(
                    np.percentile(all_effective_z_hat_lengths, 99.))) + 1
                self.logger.info('new_max_depth: {}'.format(new_max_depth))
                all_valid_metrics.append(
                    {'prediction_max_depth': new_max_depth})
                predictor.prediction_max_depth = new_max_depth

                self.sess.run(valid_iterator_init_op)
                while True:
                    try:
                        validation_step(predictor, batch_sizes,
                                        all_valid_metrics,
                                        all_valid_jump_timesteps,
                                        all_valid_z_step_losses,
                                        task_specific_metrics, self.sess)
                        if len(all_valid_metrics) % 100 == 0:
                            self.logger.info(
                                'validation_step {} completed'.format(
                                    len(all_valid_metrics)))
                    except tf.errors.OutOfRangeError:
                        # Exception indicates that the loop over the
                        # tfrecord is complete
                        break
                self.logger.info('Done evaluating accuracy on validation data.')

                all_train_metrics.append(
                    {'batch_training_time': float(np.mean(training_durations))})

                if i_epoch.eval() % 4 == 0:
                    write_step_data(all_valid_jump_timesteps,
                                    all_valid_z_step_losses,
                                    (step_data_filepath +
                                     '_ep_{}.txt'.format(i_epoch.eval())))
                write_metrics(global_step.eval(), i_finished_epoch,
                              all_train_metrics, all_valid_metrics,
                              training_durations, progress_filepath)
                for callback in callbacks:
                    callback.on_validation_end(trainer=self,
                                               predictor=predictor,
                                               train_metrics=all_train_metrics,
                                               valid_metrics=all_valid_metrics,
                                               i_epoch=i_finished_epoch)
        except KeyboardInterrupt:
            self.logger.info('Training interrupted')
            self.save_model()
            for callback in callbacks:
                callback.on_keyboard_interrupt(self)

    for callback in callbacks:
        callback.on_training_end(self)
    return self.sess
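
# A minimal sketch of a TrainCallback implementation compatible with the
# hooks run() invokes above. Only the hook names and signatures are taken
# from the call sites; the class name LoggingCallback and its bodies are
# hypothetical examples, not part of the codebase.
class LoggingCallback(TrainCallback):
    def on_batch_end(self, trainer, predictor, i_batch):
        pass  # e.g. log or checkpoint every k-th batch

    def on_epoch_end(self, trainer, predictor, i_epoch):
        trainer.logger.info('callback: finished epoch {}'.format(i_epoch))

    def on_validation_end(self, trainer, predictor, train_metrics,
                          valid_metrics, i_epoch):
        pass  # e.g. forward metrics to an external experiment tracker

    def on_keyboard_interrupt(self, trainer):
        pass

    def on_training_end(self, trainer):
        trainer.logger.info('callback: training run ended')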