def build_model_components(self):
    self.batch_size = 11
    self.seq_len = 17
    self.obs_dim = 3
    self.hidden_dim = 8
    self.num_categ = 5
    self.config_emission = self.get_default_distribution_config()
    self.config_inference = self.get_default_distribution_config()
    self.config_z_initial = self.get_default_distribution_config()
    self.config_z_transition = self.get_default_distribution_config()

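    # Continuous state transition p(z[t] | z[t-1], s[t]): one network per discrete mode.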
    self.network_z_transition = [
        utils.build_dense_network(
            [3 * self.hidden_dim, self.hidden_dim], ["relu", None])
        for _ in range(self.num_categ)]
    self.z_trans_dist = model_cavi_snlds.ContinuousStateTransition(
        self.network_z_transition,
        distribution_dim=self.hidden_dim,
        num_categories=self.num_categ,
        **self.config_z_transition)

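    # Discrete state transition log p(s[t] | s[t-1], x[t-1]).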
    num_categ_squared = self.num_categ * self.num_categ
    self.network_s_transition = utils.build_dense_network(
        [4 * num_categ_squared, num_categ_squared],
        ["relu", None])
    self.s_trans = model_cavi_snlds.DiscreteStateTransition(
        transition_network=self.network_s_transition,
        num_categories=self.num_categ)

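    # Emission distribution p(x[t] | z[t]).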
    self.network_emission = utils.build_dense_network(
        [4 * self.obs_dim, self.obs_dim],
        ["relu", None])
    self.x_dist = model_cavi_snlds.GaussianDistributionFromMean(
        emission_mean_network=self.network_emission,
        observation_dim=self.obs_dim,
        name="GaussianDistributionFromMean",
        **self.config_emission)

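    # Inference network q(z[t] | h[t]), with h[t] = f_RNN(h[t-1], z[t-1], input[t]).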
    self.posterior_rnn = utils.build_rnn_cell(
        rnn_type="lstm", rnn_hidden_dim=32)
    self.network_posterior_mlp = utils.build_dense_network(
        [self.hidden_dim], [None])
    self.posterior_distribution = model_cavi_snlds.GaussianDistributionFromMean(
        emission_mean_network=self.network_posterior_mlp,
        observation_dim=self.hidden_dim,
        name="PosteriorDistribution",
        **self.config_inference)
    self.network_input_embedding = lambda x: x
    self.inference_network = model_cavi_snlds.RnnInferenceNetwork(
        posterior_rnn=self.posterior_rnn,
        posterior_dist=self.posterior_distribution,
        latent_dim=self.hidden_dim,
        embedding_network=self.network_input_embedding)

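    # Initial continuous state distribution p(z[0]).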
    self.init_z0_distribution = (
        model_cavi_snlds.construct_initial_state_distribution(
            self.hidden_dim,
            self.num_categ,
            use_triangular_cov=True))
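
The snippets on this page lean on utils.build_dense_network throughout. Below is a minimal sketch of what that helper is assumed to do, inferred from its call sites here and from the layer-count test at the end of this page (an assumption, not the project's actual implementation):

import tensorflow as tf

# Hypothetical stand-in for utils.build_dense_network (assumption): one Dense
# layer per entry in `layer_sizes`, paired with the matching activation, so
# build_dense_network([3, 8, 3], ["relu", "relu", None]) yields three layers.
def build_dense_network(layer_sizes, layer_activations):
    assert len(layer_sizes) == len(layer_activations)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(units=size, activation=activation)
        for size, activation in zip(layer_sizes, layer_activations)
    ])
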
def main(argv):
    del argv  # unused

    tf.random.set_seed(FLAGS.seed)

    timestamp = datetime.datetime.today().strftime("%y%m%d_%H%M%S")
    logdir = FLAGS.logdir.format(timestamp=timestamp)
    model_dir = FLAGS.model_dir.format(timestamp=timestamp)
    os.makedirs(model_dir, exist_ok=True)

    ##############################################
    # Populate configs from the flags.
    ##############################################
    train_data_config = config_utils.get_data_config(
        batch_size=FLAGS.batch_size)
    test_data_config = config_utils.get_data_config(batch_size=1)

    # regularization and annealing config
    cross_entropy_config = config_utils.get_cross_entropy_config(
        decay_rate=FLAGS.xent_rate,
        decay_steps=FLAGS.xent_steps,
        initial_value=FLAGS.xent_init,
        kickin_steps=FLAGS.xent_kickin_steps,
        use_entropy_annealing=FLAGS.cross_entropy_annealing)

    learning_rate_config = config_utils.get_learning_rate_config(
        flat_learning_rate=FLAGS.flat_learning_rate,
        inverse_annealing_lr=FLAGS.use_inverse_annealing_lr,
        decay_steps=FLAGS.num_steps,
        learning_rate=FLAGS.learning_rate,
        warmup_steps=1000)

    temperature_config = config_utils.get_temperature_config(
        decay_rate=FLAGS.annealing_rate,
        decay_steps=FLAGS.annealing_steps,
        initial_temperature=FLAGS.t_init,
        minimal_temperature=FLAGS.t_min,
        kickin_steps=FLAGS.annealing_kickin_steps,
        use_temperature_annealing=FLAGS.temperature_annealing)

    # Build Dataset and Model
    train_ds = datasets.create_lorenz_attractor_by_generator(
        batch_size=train_data_config.batch_size, random_seed=FLAGS.seed)
    test_ds = datasets.create_lorenz_attractor_by_generator(
        batch_size=test_data_config.batch_size, random_seed=FLAGS.seed)

    # configuring emission distribution p(x[t] | z[t])
    config_emission = config_utils.get_distribution_config(
        triangular_cov=False, trainable_cov=False)

    emission_network = utils.build_dense_network([8, 32, 3],
                                                 ["relu", "relu", None])

    # configuring the posterior q(z[t] | h[t]),
    # where h[t] = f_RNN(h[t-1], z[t-1], h[t]^b)
    config_inference = config_utils.get_distribution_config(
        triangular_cov=True)
    # `network_posterior_rnn` is an RNN cell,
    # `h[t] = f_RNN(h[t-1], z[t-1], input[t])`,
    # which recursively takes the previous-step RNN state `h[t-1]`, the
    # previous-step sampled dynamical state `z[t-1]`, and the conditioning
    # input `input[t]`.
    posterior_rnn = utils.build_rnn_cell(rnn_type=FLAGS.rnntype,
                                         rnn_hidden_dim=FLAGS.rnndim)

    # `posterior_mlp` is a dense network emitting the mean tensor for the
    # distribution of hidden states, p(z[t] | h[t]).
    posterior_mlp = utils.build_dense_network([32, FLAGS.hidden_dim],
                                              ["relu", None])

    # configuring p(z[0])
    config_z_initial = config_utils.get_distribution_config(
        triangular_cov=True)

    # configuring p(z[t] | z[t-1], s[t])
    config_z_transition = config_utils.get_distribution_config(
        triangular_cov=True,
        trainable_cov=True,
        sigma_scale=0.1,
        raw_sigma_bias=1.e-5,
        sigma_min=1.e-5)

    z_transition_networks = [
        utils.build_dense_network([256, FLAGS.hidden_dim], ["relu", None])
        for _ in range(FLAGS.num_categories)
    ]

    # `network_s_transition` is a network returning the log transition
    # probability `log p(s[t] | s[t-1], x[t-1])`.
    num_categ_squared = FLAGS.num_categories * FLAGS.num_categories
    network_s_transition = utils.build_dense_network(
        [4 * num_categ_squared, num_categ_squared], ["relu", None])

    snlds_model = model_cavi_snlds.create_model(
        num_categ=FLAGS.num_categories,
        hidden_dim=FLAGS.hidden_dim,
        observation_dim=3,  # Lorenz attractor observations o[t] = [x, y, z].
        config_emission=config_emission,
        config_inference=config_inference,
        config_z_initial=config_z_initial,
        config_z_transition=config_z_transition,
        network_emission=emission_network,
        network_input_embedding=lambda x: x,
        network_posterior_mlp=posterior_mlp,
        network_posterior_rnn=posterior_rnn,
        network_s_transition=network_s_transition,
        networks_z_transition=z_transition_networks,
        name="snlds")

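    # Input shape: (batch_size, num_time_steps=200, observation_dim=3).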
    snlds_model.build(input_shape=(FLAGS.batch_size, 200, 3))

    # learning rate decay
    def _get_learning_rate(global_step):
        """Construct Learning Rate Schedule."""
        if learning_rate_config.flat_learning_rate:
            lr_schedule = learning_rate_config.learning_rate
        elif learning_rate_config.inverse_annealing_lr:
            lr_schedule = utils.inverse_annealing_learning_rate(
                global_step, target_lr=learning_rate_config.learning_rate)
        else:
            lr_schedule = utils.learning_rate_schedule(global_step,
                                                       learning_rate_config)
        return lr_schedule

    # The learning rate is recomputed each step and passed to train_step in
    # the training loop below.
    optimizer = tf.keras.optimizers.Adam()

    # temperature annealing
    def _get_temperature(step):
        """Construct Temperature Annealing Schedule."""
        if temperature_config.use_temperature_annealing:
            temperature_schedule = utils.schedule_exponential_decay(
                step, temperature_config,
                temperature_config.minimal_temperature)
        else:
            temperature_schedule = temperature_config.initial_temperature
        return temperature_schedule

    # cross entropy penalty decay
    def _get_cross_entropy_coef(step):
        """Construct Cross Entropy Coefficient Schedule."""
        if cross_entropy_config.use_entropy_annealing:
            cross_entropy_schedule = utils.schedule_exponential_decay(
                step, cross_entropy_config)
        else:
            cross_entropy_schedule = 0.
        return cross_entropy_schedule
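
    # Note: utils.schedule_exponential_decay (used by both annealing schedules
    # above) is assumed to implement a standard exponential decay, roughly
    # initial_value * decay_rate ** (step / decay_steps) once past
    # kickin_steps; see its definition for the exact form.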

    tensorboard_file_writer = tf.summary.create_file_writer(logdir,
                                                            flush_millis=100)

    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=snlds_model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=5)

    latest_checkpoint = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint:
        logging.info("Loading checkpoint from %s.", latest_checkpoint)
        ckpt.restore(latest_checkpoint)
    else:
        logging.info("Start training from scratch.")

    train_iter = train_ds.as_numpy_iterator()
    test_iter = test_ds.as_numpy_iterator()

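    # Training loop: recompute the annealing schedules each step, then update.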
    while optimizer.iterations < FLAGS.num_steps:
        learning_rate = _get_learning_rate(optimizer.iterations)
        temperature = _get_temperature(optimizer.iterations)
        cross_entropy_coef = _get_cross_entropy_coef(optimizer.iterations)
        train_metrics = train_step(train_iter.next(), snlds_model, optimizer,
                                   FLAGS.num_samples, FLAGS.objective,
                                   learning_rate, temperature,
                                   cross_entropy_coef)
        if (optimizer.iterations.numpy() % FLAGS.log_steps) == 0:
            step = optimizer.iterations.numpy()
            test_metrics = eval_step(test_iter.next(), snlds_model,
                                     FLAGS.num_samples, temperature)
            test_log_likelihood = tf.reduce_mean(test_metrics[FLAGS.objective])
            train_objective = tf.reduce_mean(train_metrics["objective"])
            logging.info("log step: %d, train loss %f, test loss %f.", step,
                         train_objective, test_log_likelihood)
            summary_items = {
                "params/learning_rate": learning_rate,
                "params/temperature": temperature,
                "params/cross_entropy_coef": cross_entropy_coef,
                "elbo/training": tf.reduce_mean(
                    train_metrics[FLAGS.objective]),
                "elbo/testing": test_log_likelihood,
                "xent/training": tf.reduce_mean(
                    train_metrics["cross_entropy"]),
                "xent/testing": tf.reduce_mean(test_metrics["cross_entropy"]),
            }
            with tensorboard_file_writer.as_default():
                for k, v in summary_items.items():
                    tf.summary.scalar(k, v, step=step)

                original_inputs = train_metrics["inputs"][0]
                reconstructed_inputs = train_metrics["reconstructions"][0]
                most_likely_states = tf.math.argmax(
                    train_metrics["posterior_llk"],
                    axis=-1,
                    output_type=tf.int32)[0]
                hidden_states = train_metrics["sampled_z"][0]
                discrete_states_lk = tf.exp(train_metrics["posterior_llk"][0])

                # Show the Lorenz attractor reconstruction side by side with the original.
                matplotlib_fig = tensorboard_utils.show_lorenz_attractor_3d(
                    fig_size=(10, 5),
                    inputs=original_inputs,
                    reconstructed_inputs=reconstructed_inputs,
                    fig_title="input_reconstruction")
                fig_numpy_array = tensorboard_utils.plot_to_image(
                    matplotlib_fig)
                tf.summary.image("Reconstruction", fig_numpy_array, step=step)

                # Show discrete state segmentation on input data along each dimension.
                matplotlib_fig = tensorboard_utils.show_lorenz_segmentation(
                    fig_size=(10, 6),
                    inputs=original_inputs,
                    segmentation=most_likely_states)
                fig_numpy_array = tensorboard_utils.plot_to_image(
                    matplotlib_fig)
                tf.summary.image("Segmentation", fig_numpy_array, step=step)

                # Show z[t] and segmentation.
                matplotlib_fig = tensorboard_utils.show_hidden_states(
                    fig_size=(12, 3),
                    zt=hidden_states,
                    segmentation=most_likely_states)
                fig_numpy_array = tensorboard_utils.plot_to_image(
                    matplotlib_fig)
                tf.summary.image("Hidden_State_zt", fig_numpy_array, step=step)

                # Show s[t] posterior likelihood.
                matplotlib_fig = tensorboard_utils.show_discrete_states(
                    fig_size=(12, 3),
                    discrete_states_lk=discrete_states_lk,
                    segmentation=most_likely_states)
                fig_numpy_array = tensorboard_utils.plot_to_image(
                    matplotlib_fig)
                tf.summary.image("Discrete_State_st",
                                 fig_numpy_array,
                                 step=step)

            ckpt_save_path = ckpt_manager.save()
            logging.info("Saving checkpoint for step %d at %s.", step,
                         ckpt_save_path)
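
The FLAGS and logging usage above is absl-style, so the script presumably ends
with the usual absl entry point (an assumption; the entry point is not shown
in this snippet):

# Hypothetical entry point, assuming absl-py (suggested by FLAGS/logging above).
if __name__ == "__main__":
    app.run(main)
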
Example #3
def test_build_dense_network(self):
    returned_nets = utils.build_dense_network(
        [3, 8, 3], ["relu", "relu", None])
    self.assertEqual(len(returned_nets.layers), 3)
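
Assuming the standard TensorFlow test harness, the method above would live in
a tf.test.TestCase subclass and the module would close with:

# Hypothetical harness (assumption): the test method above belongs to a
# tf.test.TestCase subclass.
if __name__ == "__main__":
    tf.test.main()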