Example #1
    def test_predict_no_errors(self):
        """
        Predictor with delta_t_bounds (1, 1) - should be equivalent to the version
        without dynamic time matching.
        Assertions to ensure that behavior stays the same (these are just the outputs
        of the predictor from 2018-03-22.
        """
        tf.reset_default_graph()
        rng = np.random.RandomState(1234)
        x_shape = (5, 3)
        x_1 = rng.uniform(0, 1, (14, ) + x_shape)

        train_step_counter = make_count_variable('train_step_counter', 0)
        predictor = ASIModel(
            x_shape,
            f_optimizer=tf.train.AdamOptimizer(),
            f=make_example_f(list(x_shape)),
            delta_t_bounds=(1, 1),
            exploration_schedule=None,  # unused
            schd_sampling_schedule=None,  # unused
            parallel_iterations=10,
            train_step_counter=train_step_counter,
        )
        init_vars = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init_vars)
            z_hat = predictor.predict_n_steps(x_1, n=4, sess=sess)

        print('z_hat.shape', z_hat.shape)
        assert z_hat.shape == (14, 4, 5, 3)

        print('z_hat.mean()', z_hat.mean())
        assert np.isclose(z_hat.mean(), 0.4999843)
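For reference, predict_n_steps evidently rolls the one-step model f forward autoregressively and stacks the predictions along a new time axis, which is why a batch of shape (14, 5, 3) comes back as (14, 4, 5, 3) for n=4. A minimal NumPy sketch of that shape logic, assuming f maps a batch of states to a batch of states (the real method runs the rollout inside the TensorFlow graph):

import numpy as np

def predict_n_steps_reference(f, x_1, n):
    """x_1 has shape (batch,) + x_shape; returns (batch, n) + x_shape."""
    z, outputs = x_1, []
    for _ in range(n):
        z = f(z)  # one autoregressive step
        outputs.append(z)
    return np.stack(outputs, axis=1)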
Example #2
def validation_step(predictor: ASIModel, batch_sizes, all_valid_metrics,
                    all_valid_jump_timesteps, all_valid_z_step_losses,
                    task_specific_metrics,
                    sess):
    (frames_valid,
     trajectory_lengths) = sess.run([predictor.x,
                                     predictor.trajectory_lengths])

    batch_sizes.append(len(frames_valid))

    valid_metrics = predictor.analyze_batch(sess,
                                            x_padded=frames_valid,
                                            trajectory_lengths=trajectory_lengths,
                                            fetch_jump_steps=True,
                                            fetch_z_step_losses=True,
                                            fetch_z_hat=True,
                                            )

    # Collect the per-step outputs separately; they are deleted from
    # valid_metrics further below.
    all_valid_jump_timesteps.append(valid_metrics['jump_timesteps'])
    all_valid_z_step_losses.append(valid_metrics['z_step_losses'])

    # ### Get predicted batch ###
    predicted_batch = []
    for trajectory in frames_valid:
        predicted_frames = predictor.predict_n_steps(trajectory[0][np.newaxis],
                                                     predictor.prediction_max_depth,
                                                     sess)[0]
        predicted_batch.append(predicted_frames)
    predicted_batch = np.asarray(predicted_batch)
    # ###########################

    for key, metric_fn in task_specific_metrics.items():
        valid_metrics[key + '_matched'] = metric_fn(frames_valid,
                                                    valid_metrics['z_hat'],
                                                    trajectory_lengths)

        valid_metrics[key] = metric_fn(frames_valid,
                                       predicted_batch,
                                       trajectory_lengths)

    del valid_metrics['z_hat']
    del valid_metrics['jump_timesteps']
    del valid_metrics['z_step_losses']

    all_valid_metrics.append(valid_metrics)
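Each entry of task_specific_metrics is called as metric_fn(frames, predictions, trajectory_lengths): once on the teacher-matched z_hat from analyze_batch (stored under key + '_matched') and once on the open-loop predicted_batch. A hypothetical metric under that signature (the body is an illustrative assumption, not code from this project):

import numpy as np

def masked_mse(frames, predictions, trajectory_lengths):
    # Hypothetical: mean squared error over the valid (unpadded) prefix of
    # each trajectory, averaged over the batch.
    errors = []
    for x, z_hat, length in zip(frames, predictions, trajectory_lengths):
        n = min(length, len(z_hat))
        errors.append(np.mean((x[:n] - z_hat[:n]) ** 2))
    return float(np.mean(errors))

task_specific_metrics = {'mse': masked_mse}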
Example #3
def make_iterative_predictor(cmd_args, x_shape, x, trajectory_lengths):
    train_step_counter = make_count_variable('train_step_counter',
                                             init_value=0)

    f_learning_rate = tf.train.exponential_decay(
        cmd_args.f_init_learning_rate,
        global_step=train_step_counter,
        decay_steps=cmd_args.f_learning_rate_decay_steps,
        decay_rate=cmd_args.f_learning_rate_decay_rate,
        staircase=True)

    if cmd_args.optimizer == 'adam':
        f_optimizer = tf.train.AdamOptimizer(learning_rate=f_learning_rate)
    elif cmd_args.optimizer == 'sgd10':
        f_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=10 * f_learning_rate)
    else:
        raise ValueError('Unknown optimizer: {}'.format(cmd_args.optimizer))

    f = get_f(cmd_args.data_format, cmd_args.f_architecture, x_shape, cmd_args)

    if cmd_args.z_loss_fn == 'log_loss':
        z_loss_fn = asi.helpers.z_loss_fns.log_loss
    else:
        raise ValueError('Unknown z_loss_fn {}'.format(cmd_args.z_loss_fn))

    if cmd_args.data_format == 'channels_first':
        input_shape = (x_shape[2], x_shape[0], x_shape[1])
    elif cmd_args.data_format == 'channels_last':
        input_shape = x_shape
    else:
        raise ValueError('Illegal data format')

    additional_metrics = {
        'f_learning_rate': f_learning_rate,
    }
    schd_sampling_schedule = annealing_schedules.get_reciprocal(
        decay=1. / (0.2 * cmd_args.schd_sampling_steps))
    exploration_schedule = annealing_schedules.get_linear(
        cmd_args.exploration_steps, 1.0, 0)
    latent_predictor = ASIModel(
        x_shape=input_shape,
        f_optimizer=f_optimizer,
        f=f,
        delta_t_bounds=(cmd_args.delta_t_lower_bound,
                        cmd_args.delta_t_upper_bound),
        exploration_schedule=exploration_schedule,
        schd_sampling_schedule=schd_sampling_schedule,
        additional_metrics=additional_metrics,
        x=x,
        trajectory_lengths=trajectory_lengths,
        parallel_iterations=10,
        train_step_counter=train_step_counter,
        z_loss_fn=z_loss_fn,
    )
    return latent_predictor
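The two annealing schedules are not defined in this snippet. Judging only by their names and call sites (an assumption, not confirmed here), get_reciprocal(decay) plausibly returns a reciprocal-decay curve for scheduled sampling, and get_linear(n_steps, start, end) a linear ramp from 1.0 down to 0 over exploration_steps. A sketch under those assumptions:

# Hypothetical shapes of the two schedules; the real implementations live
# in annealing_schedules and may differ.
def get_reciprocal(decay):
    return lambda step: 1.0 / (1.0 + decay * step)

def get_linear(n_steps, start, end):
    def schedule(step):
        frac = min(step / float(n_steps), 1.0)
        return start + frac * (end - start)
    return schedule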
Example #4
    def test_predict_simple_models(self):
        tf.reset_default_graph()
        x_shape = (2, 3)
        z_shape = x_shape
        n_timesteps = 4

        # f adds one
        z_1 = Input(shape=z_shape)
        f = Model(z_1, Lambda(lambda x: x + 1)(z_1), name='f')

        train_step_counter = make_count_variable('train_step_counter', 0)
        predictor = ASIModel(
            x_shape,
            f_optimizer=tf.train.AdamOptimizer(),
            f=f,
            delta_t_bounds=(1, 1),
            exploration_schedule=annealing_schedules.constant_zero,
            schd_sampling_schedule=annealing_schedules.constant_zero,
            parallel_iterations=10,
            train_step_counter=train_step_counter,
            forgiving=True)

        x_bias = 10.0
        x_1 = np.stack((np.ones(x_shape, dtype=np.float32),
                        x_bias + np.ones(x_shape, dtype=np.float32)),
                       axis=0)
        print('x_1.shape', x_1.shape)
        init_vars = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_vars)
            _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)
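            # Expected rollouts of the +1 dynamics: 1 -> 2, 3, 4, 5 and
            # 11 -> 12, 13, 14, 15, constant across each (2, 3) frame.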
            desired_z_hat_a = (np.arange(2, 2 + n_timesteps)[:, None, None] +
                               np.zeros((n_timesteps, ) + x_shape))
            desired_z_hat_b = (x_bias +
                               np.arange(2, 2 + n_timesteps)[:, None, None] +
                               np.zeros((n_timesteps, ) + x_shape))

            assert _z_hat.tolist() == [
                desired_z_hat_a.tolist(),
                desired_z_hat_b.tolist()
            ]
Example #5
    def test_train_multiple_trajectories_no_errors(self):
        tf.reset_default_graph()
        rng = np.random.RandomState(1234)

        x_shape = (2, 3)

        trajectory_lengths = [4, 3, 5, 2]
        x_list = [
            rng.uniform(0, 1, (trajectory_length, ) + x_shape)
            for trajectory_length in trajectory_lengths
        ]

        train_step_counter = make_count_variable('train_step_counter', 0)
        predictor = ASIModel(
            x_shape,
            f_optimizer=tf.train.AdamOptimizer(),
            f=make_example_f(list(x_shape)),
            delta_t_bounds=(1, 1),
            exploration_schedule=annealing_schedules.constant_zero,
            schd_sampling_schedule=annealing_schedules.constant_zero,
            parallel_iterations=10,
            train_step_counter=train_step_counter,
            z_loss_fn=z_loss_fns.symmetric_log_loss)
        init_vars = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_vars)
            predictor.train_on_trajectory_list(sess, x_list)
            _loss = get_batch_loss(predictor)
            print('_loss 1', _loss)
            assert np.isclose(_loss, 1.7272503)

            predictor.train_on_trajectory_list(sess, x_list)
            _loss = get_batch_loss(predictor)
            print('_loss 2', _loss)
            assert np.isclose(_loss, 1.7265959)
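train_on_trajectory_list accepts trajectories of different lengths (4, 3, 5, 2 above). Given the x_padded/trajectory_lengths pair that analyze_batch takes in Example #2, it presumably zero-pads the list to a common length before training; a sketch of that presumed step (an assumption about the method's internals):

import numpy as np

max_len = max(len(traj) for traj in x_list)
x_padded = np.zeros((len(x_list), max_len) + x_shape, dtype=np.float32)
for i, traj in enumerate(x_list):
    x_padded[i, :len(traj)] = traj
# ...followed by something like:
# predictor.train_on_trajectories(sess, x_padded,
#                                 trajectory_lengths=[len(t) for t in x_list])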
Example #6
    def test_train_f(self):
        """
        The real dynamics is: x -> x+1
        Parametrize model f with one parameter theta, to encode the function x -> x + theta
        Check if the loss can be reduced
        """
        tf.reset_default_graph()
        x_shape = (1, )
        z_shape = x_shape
        n_timesteps = 5

        # f adds theta
        z_1 = Input(shape=z_shape)
        with tf.variable_scope('test_train_f'):
            theta = tf.get_variable('theta_2',
                                    shape=(),
                                    dtype=tf.float32,
                                    initializer=tf.initializers.constant(0.0))
        f_layer = Lambda(
            lambda x: x + theta,
            trainable=True,
        )
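        # Keras Lambda layers do not track externally created variables, so
        # theta is registered manually as a trainable weight below.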
        f_layer.trainable_weights = [theta]
        f = Model(z_1, f_layer(z_1), name='f')
        print('f.trainable_weights', f.trainable_weights)

        learning_rate = 0.01
        train_step_counter = make_count_variable('train_step_counter', 0)
        predictor = ASIModel(
            x_shape,
            f=f,
            delta_t_bounds=(1, 1),
            exploration_schedule=annealing_schedules.constant_zero,
            schd_sampling_schedule=annealing_schedules.constant_zero,
            parallel_iterations=10,
            train_step_counter=train_step_counter,
            z_loss_fn=lambda a, b: tf.square(a - b),
            f_optimizer=tf.train.GradientDescentOptimizer(learning_rate),
            forgiving=True)

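        # Observed trajectory 0, 1, 2, 3, 4: the broadcast below yields
        # shape (1, n_timesteps, 1), i.e. a batch of one scalar trajectory.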
        x_observed = (
            np.zeros((n_timesteps, ) + x_shape, dtype=np.float32) +
            np.arange(n_timesteps).reshape((1, -1) + (1, ) * len(x_shape)))
        x_1 = x_observed[:, 0]
        batch_size = x_observed.shape[0]

        init_vars = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_vars)
            _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)
            print('_z_hat:', _z_hat)
            desired_z_hat = np.zeros((batch_size, n_timesteps, 1))
            assert _z_hat.tolist() == desired_z_hat.tolist()

            predictor.train_on_trajectories(sess,
                                            x_observed,
                                            trajectory_lengths=[n_timesteps])
            assert predictor.train_analytics['z_loss'] == 7.5

            print('predictor.train_analytics', predictor.train_analytics)

            new_theta = sess.run(theta)
            print('new_theta', new_theta)
            assert np.isclose(new_theta, 0.15)
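The asserted numbers check out by hand, assuming an open-loop rollout z_hat_t = t * theta against the targets x_t = t for t = 1..4 under the squared-error z_loss above:

import numpy as np

t = np.arange(1, 5)
theta = 0.0
loss = np.mean((t * theta - t) ** 2)     # (1 + 4 + 9 + 16) / 4 = 7.5
grad = np.mean(2 * t * (t * theta - t))  # d(loss)/d(theta) at 0 = -15.0
theta -= 0.01 * grad                     # one SGD step -> 0.15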
Example #7
    def run(self,
            train_iterator_init_op,
            valid_iterator_init_op,
            predictor: ASIModel,
            n_epochs_max,
            task_specific_metrics,
            progress_filepath,
            step_data_filepath,
            callbacks: List[TrainCallback] = None,
            saver_max_to_keep=5,
            model_path=None):
        """
        :param train_iterator_init_op: tf-operation which initializes the train-iterator
        :param valid_iterator_init_op: tf-operation which initializes the valid-iterator
        :param predictor: Instance of ASIModel
        :param progress_filepath: Path to progress file
        :param callbacks: TrainCallback objects whose methods will be invoked at the
                          appropriate times
        :param saver_max_to_keep: max_to_keep argument for tf.Saver
        :param model_path: Path from where to restore model and where to save models
        :return:
        """

        self.model_path = model_path

        if callbacks is None:
            callbacks = []

        global_step = make_count_variable('global_step')
        i_epoch = make_count_variable('i_epoch')
        i_batch = make_count_variable('i_batch')

        increment_global_step = make_count_incrementer(global_step)
        increment_epoch = make_count_incrementer(i_epoch)
        increment_batch = make_count_incrementer(i_batch)

        reset_batch_counter = make_count_resetter(i_batch)

        init_var_op = tf.global_variables_initializer()

        self.saver = tf.train.Saver(max_to_keep=saver_max_to_keep)
        try:
            self.saver.restore(self.sess, model_path)
            self.logger.info('Restored model from path {}'.format(model_path))
            self.logger.info('Resuming asi from '
                             'i_epoch={}, i_batch={}'.format(
                                 self.sess.run(i_epoch),
                                 self.sess.run(i_batch)))
        except tf.errors.NotFoundError:
            self.logger.info('No model to restore at path {}'.format(model_path))
            self.sess.run(init_var_op)
            self.logger.info('Initialized fresh model.')

        with self.sess.as_default():
            self.sess.graph.finalize()

            try:

                while i_epoch.eval() < n_epochs_max:

                    self.logger.info('===============================================')
                    self.logger.info('Train i_epoch={}'.format(i_epoch.eval()))
                    self.logger.info('===============================================')
                    training_durations = []
                    all_train_metrics = []

                    self.sess.run(train_iterator_init_op)

                    # Skip already-consumed batches if i_batch > 0
                    # (i.e. we are resuming from a checkpoint).
                    n_skip = i_batch.eval()
                    for _ in range(n_skip):
                        self.sess.run([predictor.trajectory_lengths])
                    if n_skip > 0:
                        self.logger.info('Done skipping to batch {}. '
                                         'Resuming asi now.'.format(n_skip))

                    while True:
                        try:
                            tic = time.time()
                            predictor.train_on_trajectories(self.sess)
                            toc = time.time()

                            # self.logger.info('Batch duration: {:.4f}'.format(toc - tic))

                            all_train_metrics.append(predictor.train_analytics)
                            training_durations.append(toc - tic)
                            self.sess.run(increment_global_step)
                            self.sess.run(increment_batch)

                            for callback in callbacks:
                                callback.on_batch_end(trainer=self,
                                                      predictor=predictor,
                                                      i_batch=i_batch.eval())
                        except tf.errors.OutOfRangeError:  # Epoch over
                            self.logger.info('Epoch finished. '
                                             'i_batch={}'.format(i_batch.eval()))
                            break

                    self.sess.run(reset_batch_counter)

                    i_finished_epoch = i_epoch.eval()
                    for callback in callbacks:
                        callback.on_epoch_end(trainer=self,
                                              predictor=predictor,
                                              i_epoch=i_finished_epoch,
                                              )

                    self.sess.run(increment_epoch)

                    self.save_model()

                    self.logger.info('Evaluating accuracy on validation data ...')
                    batch_sizes = []
                    all_valid_metrics = []
                    all_valid_jump_timesteps = []
                    all_valid_z_step_losses = []

                    all_effective_z_hat_lengths = [m['effective_z_hat_lengths']
                                                   for m in all_train_metrics]
                    new_max_depth = int(round(
                        np.percentile(all_effective_z_hat_lengths, 99.))) + 1

                    self.logger.info('new_max_depth: {}'.format(new_max_depth))
                    all_valid_metrics.append(
                        {'prediction_max_depth': new_max_depth})
                    predictor.prediction_max_depth = new_max_depth

                    self.sess.run(valid_iterator_init_op)

                    while True:
                        try:
                            validation_step(predictor, batch_sizes, all_valid_metrics,
                                            all_valid_jump_timesteps,
                                            all_valid_z_step_losses,
                                            task_specific_metrics,
                                            self.sess)

                            if len(all_valid_metrics) % 100 == 0:
                                self.logger.info('validation_step {} completed'.format(
                                    len(all_valid_metrics)))

                        except tf.errors.OutOfRangeError:
                            # Exception indicates that loop over tfrecord is completed
                            break

                    self.logger.info('Done evaluating accuracy on validation data.')

                    all_train_metrics.append(
                        {'batch_training_time': float(np.mean(training_durations))})

                    if i_epoch.eval() % 4 == 0:
                        write_step_data(all_valid_jump_timesteps,
                                        all_valid_z_step_losses,
                                        (step_data_filepath
                                         + '_ep_{}.txt'.format(i_epoch.eval())))

                    write_metrics(global_step.eval(),
                                  i_finished_epoch,
                                  all_train_metrics,
                                  all_valid_metrics,
                                  training_durations,
                                  progress_filepath)

                    for callback in callbacks:
                        callback.on_validation_end(trainer=self,
                                                   predictor=predictor,
                                                   train_metrics=all_train_metrics,
                                                   valid_metrics=all_valid_metrics,
                                                   i_epoch=i_finished_epoch)

            except KeyboardInterrupt:
                self.logger.info('Training interrupted')
                self.save_model()
                for callback in callbacks:
                    callback.on_keyboard_interrupt(self)

        for callback in callbacks:
            callback.on_training_end(self)

        return self.sess
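The loop above invokes a fixed set of callback hooks: on_batch_end, on_epoch_end, on_validation_end, on_keyboard_interrupt, and on_training_end, with the keyword arguments visible in the calls. A minimal subclass sketch, assuming TrainCallback is a plain base class defining these hooks as no-ops (its actual definition is not shown here):

class EpochLogger(TrainCallback):
    # Hook names and keyword arguments mirror the calls made inside run().
    def on_batch_end(self, trainer, predictor, i_batch):
        pass

    def on_epoch_end(self, trainer, predictor, i_epoch):
        trainer.logger.info('Finished epoch {}.'.format(i_epoch))

    def on_validation_end(self, trainer, predictor, train_metrics,
                          valid_metrics, i_epoch):
        pass

    def on_keyboard_interrupt(self, trainer):
        pass

    def on_training_end(self, trainer):
        pass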