def build_model_components(self):
  """Builds all SNLDS components exercised by the tests."""
  self.batch_size = 11
  self.seq_len = 17
  self.obs_dim = 3
  self.hidden_dim = 8
  self.num_categ = 5

  self.config_emission = self.get_default_distribution_config()
  self.config_inference = self.get_default_distribution_config()
  self.config_z_initial = self.get_default_distribution_config()
  self.config_z_transition = self.get_default_distribution_config()

  # Continuous state transition p(z[t] | z[t-1], s[t]), with one network
  # per discrete category.
  self.network_z_transition = [
      utils.build_dense_network(
          [3 * self.hidden_dim, self.hidden_dim], ["relu", None])
      for _ in range(self.num_categ)]
  self.z_trans_dist = model_cavi_snlds.ContinuousStateTransition(
      self.network_z_transition,
      distribution_dim=self.hidden_dim,
      num_categories=self.num_categ,
      **self.config_z_transition)

  # Discrete state transition p(s[t] | s[t-1], x[t-1]).
  num_categ_squared = self.num_categ * self.num_categ
  self.network_s_transition = utils.build_dense_network(
      [4 * num_categ_squared, num_categ_squared], ["relu", None])
  self.s_trans = model_cavi_snlds.DiscreteStateTransition(
      transition_network=self.network_s_transition,
      num_categories=self.num_categ)

  # Emission distribution p(x[t] | z[t]).
  self.network_emission = utils.build_dense_network(
      [4 * self.obs_dim, self.obs_dim], ["relu", None])
  self.x_dist = model_cavi_snlds.GaussianDistributionFromMean(
      emission_mean_network=self.network_emission,
      observation_dim=self.obs_dim,
      name="GaussianDistributionFromMean",
      **self.config_emission)

  # Posterior q(z[t] | h[t]), parameterized by an RNN inference network.
  self.posterior_rnn = utils.build_rnn_cell(
      rnn_type="lstm", rnn_hidden_dim=32)
  self.network_posterior_mlp = utils.build_dense_network(
      [self.hidden_dim], [None])
  self.posterior_distribution = model_cavi_snlds.GaussianDistributionFromMean(
      emission_mean_network=self.network_posterior_mlp,
      observation_dim=self.hidden_dim,
      name="PosteriorDistribution",
      **self.config_inference)
  self.network_input_embedding = lambda x: x
  self.inference_network = model_cavi_snlds.RnnInferenceNetwork(
      posterior_rnn=self.posterior_rnn,
      posterior_dist=self.posterior_distribution,
      latent_dim=self.hidden_dim,
      embedding_network=self.network_input_embedding)

  # Initial state distribution p(z[0]).
  self.init_z0_distribution = (
      model_cavi_snlds.construct_initial_state_distribution(
          self.hidden_dim, self.num_categ, use_triangular_cov=True))
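# A minimal smoke-test sketch (not part of the original suite). It assumes
# `utils.build_dense_network` returns a Keras model that applies its dense
# layers position-wise, so the emission network should map hidden states of
# shape [batch, time, hidden_dim] to means of shape [batch, time, obs_dim].
# The method name is hypothetical.
def test_emission_network_output_shape(self):
  self.build_model_components()
  z = tf.zeros([self.batch_size, self.seq_len, self.hidden_dim])
  x_mean = self.network_emission(z)
  self.assertEqual(
      x_mean.shape, (self.batch_size, self.seq_len, self.obs_dim))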
def main(argv):
  del argv  # unused
  tf.random.set_seed(FLAGS.seed)
  timestamp = datetime.datetime.strftime(datetime.datetime.today(),
                                         "%y%m%d_%H%M%S")
  logdir = FLAGS.logdir.format(timestamp=timestamp)
  model_dir = FLAGS.model_dir.format(timestamp=timestamp)
  if not os.path.isdir(model_dir):
    os.makedirs(model_dir)

  ##############################################
  # populate the flags
  ##############################################
  train_data_config = config_utils.get_data_config(
      batch_size=FLAGS.batch_size)
  test_data_config = config_utils.get_data_config(batch_size=1)

  # regularization and annealing config
  cross_entropy_config = config_utils.get_cross_entropy_config(
      decay_rate=FLAGS.xent_rate,
      decay_steps=FLAGS.xent_steps,
      initial_value=FLAGS.xent_init,
      kickin_steps=FLAGS.xent_kickin_steps,
      use_entropy_annealing=FLAGS.cross_entropy_annealing)
  learning_rate_config = config_utils.get_learning_rate_config(
      flat_learning_rate=FLAGS.flat_learning_rate,
      inverse_annealing_lr=FLAGS.use_inverse_annealing_lr,
      decay_steps=FLAGS.num_steps,
      learning_rate=FLAGS.learning_rate,
      warmup_steps=1000)
  temperature_config = config_utils.get_temperature_config(
      decay_rate=FLAGS.annealing_rate,
      decay_steps=FLAGS.annealing_steps,
      initial_temperature=FLAGS.t_init,
      minimal_temperature=FLAGS.t_min,
      kickin_steps=FLAGS.annealing_kickin_steps,
      use_temperature_annealing=FLAGS.temperature_annealing)

  # Build Dataset and Model
  train_ds = datasets.create_lorenz_attractor_by_generator(
      batch_size=train_data_config.batch_size, random_seed=FLAGS.seed)
  test_ds = datasets.create_lorenz_attractor_by_generator(
      batch_size=test_data_config.batch_size, random_seed=FLAGS.seed)

  # configuring emission distribution p(x[t] | z[t])
  config_emission = config_utils.get_distribution_config(
      triangular_cov=False, trainable_cov=False)
  emission_network = utils.build_dense_network(
      [8, 32, 3], ["relu", "relu", None])

  # configuring q(z[t] | h[t] = f_RNN(h[t-1], z[t-1], h[t]^b))
  config_inference = config_utils.get_distribution_config(triangular_cov=True)
  # The `network_posterior_rnn` is an RNN cell,
  # `h[t] = f_RNN(h[t-1], z[t-1], input[t])`,
  # which recursively takes the previous-step RNN state `h`, the
  # previous-step sampled dynamical state `z[t-1]`, and the conditioning
  # input `input[t]`.
  posterior_rnn = utils.build_rnn_cell(
      rnn_type=FLAGS.rnntype, rnn_hidden_dim=FLAGS.rnndim)
  # The `posterior_mlp` is a dense network emitting the mean tensor for
  # the distribution over hidden states, q(z[t] | h[t]).
  posterior_mlp = utils.build_dense_network(
      [32, FLAGS.hidden_dim], ["relu", None])

  # configuring p(z[0])
  config_z_initial = config_utils.get_distribution_config(triangular_cov=True)

  # configuring p(z[t] | z[t-1], s[t])
  config_z_transition = config_utils.get_distribution_config(
      triangular_cov=True,
      trainable_cov=True,
      sigma_scale=0.1,
      raw_sigma_bias=1.e-5,
      sigma_min=1.e-5)
  z_transition_networks = [
      utils.build_dense_network(
          [256, FLAGS.hidden_dim], ["relu", None])
      for _ in range(FLAGS.num_categories)]

  # `network_s_transition` is a network returning the transition probability
  # `log p(s[t] | s[t-1], x[t-1])`.
  num_categ_squared = FLAGS.num_categories * FLAGS.num_categories
  network_s_transition = utils.build_dense_network(
      [4 * num_categ_squared, num_categ_squared], ["relu", None])

  snlds_model = model_cavi_snlds.create_model(
      num_categ=FLAGS.num_categories,
      hidden_dim=FLAGS.hidden_dim,
      observation_dim=3,  # Lorenz attractor has input o[t] = [x, y, z].
      config_emission=config_emission,
      config_inference=config_inference,
      config_z_initial=config_z_initial,
      config_z_transition=config_z_transition,
      network_emission=emission_network,
      network_input_embedding=lambda x: x,
      network_posterior_mlp=posterior_mlp,
      network_posterior_rnn=posterior_rnn,
      network_s_transition=network_s_transition,
      networks_z_transition=z_transition_networks,
      name="snlds")
  snlds_model.build(input_shape=(FLAGS.batch_size, 200, 3))

  # learning rate decay
  def _get_learning_rate(global_step):
    """Construct Learning Rate Schedule."""
    if learning_rate_config.flat_learning_rate:
      lr_schedule = learning_rate_config.learning_rate
    elif learning_rate_config.inverse_annealing_lr:
      lr_schedule = utils.inverse_annealing_learning_rate(
          global_step, target_lr=learning_rate_config.learning_rate)
    else:
      lr_schedule = utils.learning_rate_schedule(
          global_step, learning_rate_config)
    return lr_schedule

  # The learning rate for the optimizer is applied on the fly in the
  # training loop.
  optimizer = tf.keras.optimizers.Adam()

  # temperature annealing
  def _get_temperature(step):
    """Construct Temperature Annealing Schedule."""
    if temperature_config.use_temperature_annealing:
      temperature_schedule = utils.schedule_exponential_decay(
          step, temperature_config, temperature_config.minimal_temperature)
    else:
      temperature_schedule = temperature_config.initial_temperature
    return temperature_schedule

  # cross entropy penalty decay
  def _get_cross_entropy_coef(step):
    """Construct Cross Entropy Coefficient Schedule."""
    if cross_entropy_config.use_entropy_annealing:
      cross_entropy_schedule = utils.schedule_exponential_decay(
          step, cross_entropy_config)
    else:
      cross_entropy_schedule = 0.
    return cross_entropy_schedule

  tensorboard_file_writer = tf.summary.create_file_writer(
      logdir, flush_millis=100)

  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=snlds_model)
  ckpt_manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=5)
  latest_checkpoint = tf.train.latest_checkpoint(model_dir)
  if latest_checkpoint:
    logging.info("Loading checkpoint from %s.", latest_checkpoint)
    ckpt.restore(latest_checkpoint)
  else:
    logging.info("Start training from scratch.")

  train_iter = train_ds.as_numpy_iterator()
  test_iter = test_ds.as_numpy_iterator()
  while optimizer.iterations < FLAGS.num_steps:
    learning_rate = _get_learning_rate(optimizer.iterations)
    temperature = _get_temperature(optimizer.iterations)
    cross_entropy_coef = _get_cross_entropy_coef(optimizer.iterations)

    train_metrics = train_step(train_iter.next(), snlds_model, optimizer,
                               FLAGS.num_samples, FLAGS.objective,
                               learning_rate, temperature, cross_entropy_coef)

    if (optimizer.iterations.numpy() % FLAGS.log_steps) == 0:
      step = optimizer.iterations.numpy()
      test_metrics = eval_step(test_iter.next(), snlds_model,
                               FLAGS.num_samples, temperature)
      test_log_likelihood = tf.reduce_mean(test_metrics[FLAGS.objective])
      train_objective = tf.reduce_mean(train_metrics["objective"])
      logging.info("log step: %d, train loss %f, test loss %f.",
                   step, train_objective, test_log_likelihood)

      summary_items = {
          "params/learning_rate": learning_rate,
          "params/temperature": temperature,
          "params/cross_entropy_coef": cross_entropy_coef,
          "elbo/training": tf.reduce_mean(train_metrics[FLAGS.objective]),
          "elbo/testing": test_log_likelihood,
          "xent/training": tf.reduce_mean(train_metrics["cross_entropy"]),
          "xent/testing": tf.reduce_mean(test_metrics["cross_entropy"]),
      }
      with tensorboard_file_writer.as_default():
        for k, v in summary_items.items():
          tf.summary.scalar(k, v, step=step)

        original_inputs = train_metrics["inputs"][0]
        reconstructed_inputs = train_metrics["reconstructions"][0]
        most_likely_states = tf.math.argmax(
            train_metrics["posterior_llk"], axis=-1, output_type=tf.int32)[0]
        hidden_states = train_metrics["sampled_z"][0]
        discrete_states_lk = tf.exp(train_metrics["posterior_llk"][0])

        # Show Lorenz attractor reconstruction side by side with the original.
        matplotlib_fig = tensorboard_utils.show_lorenz_attractor_3d(
            fig_size=(10, 5),
            inputs=original_inputs,
            reconstructed_inputs=reconstructed_inputs,
            fig_title="input_reconstruction")
        fig_numpy_array = tensorboard_utils.plot_to_image(matplotlib_fig)
        tf.summary.image("Reconstruction", fig_numpy_array, step=step)

        # Show discrete state segmentation of the input along each dimension.
        matplotlib_fig = tensorboard_utils.show_lorenz_segmentation(
            fig_size=(10, 6),
            inputs=original_inputs,
            segmentation=most_likely_states)
        fig_numpy_array = tensorboard_utils.plot_to_image(matplotlib_fig)
        tf.summary.image("Segmentation", fig_numpy_array, step=step)

        # Show z[t] and segmentation.
        matplotlib_fig = tensorboard_utils.show_hidden_states(
            fig_size=(12, 3),
            zt=hidden_states,
            segmentation=most_likely_states)
        fig_numpy_array = tensorboard_utils.plot_to_image(matplotlib_fig)
        tf.summary.image("Hidden_State_zt", fig_numpy_array, step=step)

        # Show s[t] posterior likelihood.
        matplotlib_fig = tensorboard_utils.show_discrete_states(
            fig_size=(12, 3),
            discrete_states_lk=discrete_states_lk,
            segmentation=most_likely_states)
        fig_numpy_array = tensorboard_utils.plot_to_image(matplotlib_fig)
        tf.summary.image("Discrete_State_st", fig_numpy_array, step=step)

      ckpt_save_path = ckpt_manager.save()
      logging.info("Saving checkpoint for step %d at %s.",
                   step, ckpt_save_path)
def test_build_dense_network(self):
  returned_nets = utils.build_dense_network(
      [3, 8, 3], ["relu", "relu", None])
  self.assertEqual(len(returned_nets.layers), 3)
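# `utils.build_dense_network` is not shown here. Consistent with the test
# above (three layer sizes, three activations, `.layers` of length 3), a
# minimal sketch of such a builder could look like the following; the real
# implementation may differ, and the function name is hypothetical.
def build_dense_network_sketch(layer_sizes, layer_activations):
  # One Dense layer per (units, activation) pair, stacked sequentially.
  return tf.keras.Sequential([
      tf.keras.layers.Dense(units, activation=activation)
      for units, activation in zip(layer_sizes, layer_activations)])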