Example #1
    def test_train_f(self):
        """
        The true dynamics are x -> x + 1.
        Parametrize the model f with a single parameter theta, encoding the
        function x -> x + theta.
        Check that the loss can be reduced.
        """
        tf.reset_default_graph()
        x_shape = (1, )
        z_shape = x_shape
        n_timesteps = 5

        # f adds theta
        z_1 = Input(shape=z_shape)
        with tf.variable_scope('test_train_f'):
            theta = tf.get_variable('theta_2',
                                    shape=(),
                                    dtype=tf.float32,
                                    initializer=tf.initializers.constant(0.0))
        f_layer = Lambda(
            lambda x: x + theta,
            trainable=True,
        )
        # Lambda layers have no weights of their own, so theta is attached
        # manually to make it appear in f.trainable_weights.
        f_layer.trainable_weights = [theta]
        f = Model(z_1, f_layer(z_1), name='f')
        print('f.trainable_weights', f.trainable_weights)

        learning_rate = 0.01
        train_step_counter = make_count_variable('train_step_counter', 0)
        predictor = ASIModel(
            x_shape,
            f=f,
            delta_t_bounds=(1, 1),
            exploration_schedule=annealing_schedules.constant_zero,
            schd_sampling_schedule=annealing_schedules.constant_zero,
            parallel_iterations=10,
            train_step_counter=train_step_counter,
            z_loss_fn=lambda a, b: tf.square(a - b),
            f_optimizer=tf.train.GradientDescentOptimizer(learning_rate),
            forgiving=True)

        # Observed trajectory x_t = t, broadcast to shape
        # (batch_size=1, n_timesteps) + x_shape.
        x_observed = (np.zeros((n_timesteps, ) + x_shape, dtype=np.float32)
                      + np.arange(n_timesteps).reshape(
                          (1, -1) + (1, ) * len(x_shape)))
        x_1 = x_observed[:, 0]
        batch_size = x_observed.shape[0]

        init_vars = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_vars)
            _z_hat = predictor.predict_n_steps(x_1, n=n_timesteps, sess=sess)
            print('_z_hat:', _z_hat)
            desired_z_hat = np.zeros((batch_size, n_timesteps, 1))
            assert _z_hat.tolist() == desired_z_hat.tolist()

            predictor.train_on_trajectories(sess,
                                            x_observed,
                                            trajectory_lengths=[n_timesteps])
            assert predictor.train_analytics['z_loss'] == 7.5

            print('predictor.train_analytics', predictor.train_analytics)

            new_theta = sess.run(theta)
            print('new_theta', new_theta)
            assert np.isclose(new_theta, 0.15)
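
Why the two asserted numbers are 7.5 and 0.15: with z_1 = 0 and
z_{t+1} = z_t + theta, the rollout yields z_{1+k} = k * theta for k = 1..4
against the targets x_{1+k} = k, so the loss is L(theta) = 7.5 * (theta - 1)^2
and its gradient at theta = 0 is -15. The short check below reproduces both
asserts; it assumes z_loss is the mean squared error over the four predicted
timesteps, which is the convention consistent with the test.

import numpy as np

theta = 0.0
learning_rate = 0.01
k = np.arange(1, 5, dtype=np.float64)     # steps ahead of z_1: 1, 2, 3, 4

z_loss = np.mean((k * theta - k) ** 2)    # L(theta) = 7.5 * (theta - 1)^2
assert z_loss == 7.5                      # (1 + 4 + 9 + 16) / 4

grad = np.mean(2 * k * (k * theta - k))   # dL/dtheta = 15 * (theta - 1) = -15
new_theta = theta - learning_rate * grad  # one gradient-descent step
assert np.isclose(new_theta, 0.15)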
Example #2
    def run(self,
            train_iterator_init_op,
            valid_iterator_init_op,
            predictor: ASIModel,
            n_epochs_max,
            task_specific_metrics,
            progress_filepath,
            step_data_filepath,
            callbacks: List[TrainCallback] = None,
            saver_max_to_keep=5,
            model_path=None):
        """
        :param train_iterator_init_op: tf operation which initializes the train iterator
        :param valid_iterator_init_op: tf operation which initializes the valid iterator
        :param predictor: Instance of ASIModel
        :param n_epochs_max: maximum number of epochs to train for
        :param task_specific_metrics: task-specific metrics passed through to
                                      validation_step
        :param progress_filepath: path to the progress file
        :param step_data_filepath: path prefix for the per-epoch step-data files
        :param callbacks: TrainCallback objects whose methods will be invoked at the
                          appropriate times
        :param saver_max_to_keep: max_to_keep argument for tf.train.Saver
        :param model_path: path from which to restore the model and to which to save it
        :return: the tf.Session used for training
        """

        self.model_path = model_path

        if callbacks is None:
            callbacks = []

        global_step = make_count_variable('global_step')
        i_epoch = make_count_variable('i_epoch')
        i_batch = make_count_variable('i_batch')

        increment_global_step = make_count_incrementer(global_step)
        increment_epoch = make_count_incrementer(i_epoch)
        increment_batch = make_count_incrementer(i_batch)

        reset_batch_counter = make_count_resetter(i_batch)

        init_var_op = tf.global_variables_initializer()

        self.saver = tf.train.Saver(max_to_keep=saver_max_to_keep)
        try:
            self.saver.restore(self.sess, model_path)
            self.logger.info('Restored model from path {}'.format(model_path))
            self.logger.info('Resuming training from i_epoch={}, i_batch={}'.format(
                self.sess.run(i_epoch), self.sess.run(i_batch)))
        except (tf.errors.NotFoundError, ValueError):
            # ValueError is raised when model_path is None (the default);
            # NotFoundError when there is no checkpoint at the given path.
            self.logger.info('No model to restore at path {}'.format(model_path))
            self.sess.run(init_var_op)
            self.logger.info('Initialized fresh model.')

        with self.sess.as_default():
            self.sess.graph.finalize()

            try:

                while i_epoch.eval() < n_epochs_max:

                    self.logger.info('===============================================')
                    self.logger.info('Train i_epoch={}'.format(i_epoch.eval()))
                    self.logger.info('===============================================')
                    training_durations = []
                    all_train_metrics = []

                    self.sess.run(train_iterator_init_op)

                    # Skip the batches already consumed before the checkpoint
                    # was saved (i.e. we are resuming mid-epoch).
                    n_skip = i_batch.eval()
                    for _ in range(n_skip):
                        self.sess.run([predictor.trajectory_lengths])
                    if n_skip > 0:
                        self.logger.info('Done skipping to batch {}. '
                                         'Resuming training now.'.format(n_skip))

                    while True:
                        try:
                            tic = time.time()
                            predictor.train_on_trajectories(self.sess)
                            toc = time.time()

                            # self.logger.info('Batch duration: {:.4f}'.format(toc - tic))

                            all_train_metrics.append(predictor.train_analytics)
                            training_durations.append(toc - tic)
                            self.sess.run(increment_global_step)
                            self.sess.run(increment_batch)

                            for callback in callbacks:
                                callback.on_batch_end(trainer=self,
                                                      predictor=predictor,
                                                      i_batch=i_batch.eval())
                        except tf.errors.OutOfRangeError:  # Epoch over
                            self.logger.info('Epoch finished. '
                                             'i_batch={}'.format(i_batch.eval()))
                            break

                    self.sess.run(reset_batch_counter)

                    i_finished_epoch = i_epoch.eval()
                    for callback in callbacks:
                        callback.on_epoch_end(trainer=self,
                                              predictor=predictor,
                                              i_epoch=i_finished_epoch,
                                              )

                    self.sess.run(increment_epoch)

                    self.save_model()

                    self.logger.info('Evaluating accuracy on validation data ...')
                    batch_sizes = []
                    all_valid_metrics = []
                    all_valid_jump_timesteps = []
                    all_valid_z_step_losses = []

                    all_effective_z_hat_lengths = [m['effective_z_hat_lengths']
                                                   for m in all_train_metrics]
                    new_max_depth = int(round(
                        np.percentile(all_effective_z_hat_lengths, 99.))) + 1

                    self.logger.info('new_max_depth: {}'.format(new_max_depth))
                    all_valid_metrics.append(
                        {'prediction_max_depth': new_max_depth})
                    predictor.prediction_max_depth = new_max_depth

                    self.sess.run(valid_iterator_init_op)

                    while True:
                        try:
                            validation_step(predictor, batch_sizes, all_valid_metrics,
                                            all_valid_jump_timesteps,
                                            all_valid_z_step_losses,
                                            task_specific_metrics,
                                            self.sess)

                            if len(all_valid_metrics) % 100 == 0:
                                self.logger.info('validation_step {} completed'.format(
                                    len(all_valid_metrics)))

                        except tf.errors.OutOfRangeError:
                            # Exception indicates that loop over tfrecord is completed
                            break

                    self.logger.info('Done evaluating accuracy on validation data.')

                    all_train_metrics.append(
                        {'batch_training_time': float(np.mean(training_durations))})

                    if i_epoch.eval() % 4 == 0:
                        write_step_data(all_valid_jump_timesteps,
                                        all_valid_z_step_losses,
                                        (step_data_filepath
                                         + '_ep_{}.txt'.format(i_epoch.eval())))

                    write_metrics(global_step.eval(),
                                  i_finished_epoch,
                                  all_train_metrics,
                                  all_valid_metrics,
                                  training_durations,
                                  progress_filepath)

                    for callback in callbacks:
                        callback.on_validation_end(trainer=self,
                                                   predictor=predictor,
                                                   train_metrics=all_train_metrics,
                                                   valid_metrics=all_valid_metrics,
                                                   i_epoch=i_finished_epoch)

            except KeyboardInterrupt:
                self.logger.info('Training interrupted')
                self.save_model()
                for callback in callbacks:
                    callback.on_keyboard_interrupt(self)

        for callback in callbacks:
            callback.on_training_end(self)

        return self.sess
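
Both examples call make_count_variable, make_count_incrementer and
make_count_resetter without defining them. A minimal sketch of what they
plausibly look like, assuming plain non-trainable TF1 integer variables (the
actual repository may differ in details):

import tensorflow as tf

def make_count_variable(name, initial_value=0):
    # Non-trainable scalar: saved and restored with checkpoints, but never
    # touched by the optimizer.
    return tf.get_variable(name,
                           shape=(),
                           dtype=tf.int64,
                           initializer=tf.initializers.constant(initial_value),
                           trainable=False)

def make_count_incrementer(count_var):
    # Op that adds 1 to the counter each time it is run.
    return tf.assign_add(count_var, 1)

def make_count_resetter(count_var):
    # Op that sets the counter back to zero.
    return tf.assign(count_var, 0)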
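
Wiring run() up might look like the following; the trainer and predictor
objects, file names and batch size are hypothetical, but the two init-op
parameters imply the standard TF1 reinitializable-iterator pattern:

import tensorflow as tf

train_dataset = tf.data.TFRecordDataset('train.tfrecord').batch(32)
valid_dataset = tf.data.TFRecordDataset('valid.tfrecord').batch(32)

# One reinitializable iterator feeds the model; the two init ops switch it
# between the training and validation datasets.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
train_init_op = iterator.make_initializer(train_dataset)
valid_init_op = iterator.make_initializer(valid_dataset)

sess = trainer.run(train_iterator_init_op=train_init_op,
                   valid_iterator_init_op=valid_init_op,
                   predictor=predictor,
                   n_epochs_max=50,
                   task_specific_metrics={},
                   progress_filepath='progress.txt',
                   step_data_filepath='step_data',
                   model_path='checkpoints/model.ckpt')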