예제 #1
0
    def train(self, n_rows, n_cols, rows, cols, vals, n_factors, d_pairwise, hidden_layer_sizes,
              n_iterations, batch_size, holdout_ratio, learning_rate, reg_param, l2_param,
              root_savedir, root_logdir,
              no_train_metric=False, seed=None):
        """
        Training routine.

        Builds the TF graph via ``self.construct_graph()``, then on every minibatch
        takes one Adam step on the latent inputs (``U``, ``V``, ``Up``, ``Vp``) followed
        by one Adam step on the remaining trainable (nnet) variables. Every 20
        iterations it prints one progress line (training loss and/or held-out MSE)
        and writes TensorBoard summaries of every trainable variable. The final
        model is saved to ``<root_savedir>/model.ckpt``.

        :param n_rows: Number of rows
        :param n_cols: Number of cols
        :param rows: Row index of each observed entry
        :param cols: Column index of each observed entry
        :param vals: Value of each observed entry
        :param n_factors: Number of non-bilinear terms
        :param d_pairwise: Number of bilinear terms
        :param hidden_layer_sizes: Hidden layer sizes of the nnet (stored on ``self``,
            presumably consumed by ``construct_graph``)
        :param n_iterations: Number of minibatch iterations to run
        :param batch_size: Minibatch size passed to ``BatchGenerator``
        :param holdout_ratio: Holdout ratio passed to ``BatchGenerator``; when None,
            no held-out MSE is computed
        :param learning_rate: Adam learning rate (shared by both optimizers)
        :param reg_param: Frobenius norm regularization terms for the features
        :param l2_param: L2 regularization parameter for the nnet weights
        :param root_savedir: Directory for the checkpoint; created if missing
        :param root_logdir: TensorBoard log directory
        :param no_train_metric: If True, skip evaluating the loss on the full training set
        :param seed: Seed forwarded to ``BatchGenerator``
        :return: None
        """
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.n_factors = n_factors
        self.d_pairwise = d_pairwise
        self.hidden_layer_sizes = hidden_layer_sizes
        self.reg_param = reg_param
        self.l2_param = l2_param

        if not os.path.exists(root_savedir):
            os.makedirs(root_savedir)

        ###  Data handling  ###

        # One (row, col, val) triple per observed entry -> array of shape (n_obs, 3).
        # Per the original author, only positive ("on") entries are passed here.
        pairs = np.vstack([rows, cols, vals]).T
        batch_generator = BatchGenerator(pairs, batch_size, holdout_ratio=holdout_ratio, seed=seed)

        ###  Construct the TF graph  ###

        self.construct_graph()

        all_vars = tf.trainable_variables()
        latent_vars = [self.U, self.V, self.Up, self.Vp]  # the inputs to the nnets
        # Compare by identity: `x in latent_vars` would go through Tensor.__eq__, which is
        # elementwise (and ambiguous in a boolean context) when TF tensor equality is enabled.
        latent_ids = {id(var_) for var_ in latent_vars}
        nnet_vars = [x for x in all_vars if id(x) not in latent_ids]  # the nnet variables

        print("\nlatent vars:", latent_vars)
        print("\nnnet vars:", nnet_vars)

        # Two independent optimizers so the inputs and the nnet can be updated alternately.
        train_lvars = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss, var_list=latent_vars)
        train_nnet = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss, var_list=nnet_vars)

        ###  Training  ###

        if not no_train_metric:
            train_loss = tf.placeholder(dtype=tf.float32, shape=[], name='train_loss')
            train_loss_summary = tf.summary.scalar('train_loss', train_loss)

        if holdout_ratio is not None:
            test_mse = tf.placeholder(dtype=tf.float32, shape=[], name='test_mse')
            test_mse_summary = tf.summary.scalar('test_mse', test_mse)

        # create tensorboard summary objects: scalars for 0-d variables, histograms otherwise
        scalar_summaries = [tf.summary.scalar(var_.name, var_) for var_ in all_vars if len(var_.shape) == 0]
        array_summaries = [tf.summary.histogram(var_.name, var_) for var_ in all_vars if len(var_.shape) > 0]
        variable_summaries = scalar_summaries + array_summaries

        writer = tf.summary.FileWriter(root_logdir)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        try:
            with tf.Session() as sess:

                init.run()

                if not no_train_metric:
                    train_dict = {self.row: batch_generator.train[:, 0],
                                  self.col: batch_generator.train[:, 1],
                                  self.val: batch_generator.train[:, 2]}

                if holdout_ratio is not None:
                    test_dict = {self.row: batch_generator.test[:, 0],
                                 self.col: batch_generator.test[:, 1],
                                 self.val: batch_generator.test[:, 2]}

                for iteration in range(n_iterations):

                    batch = batch_generator.next_batch()
                    batch_dict = {self.row: batch[:, 0],
                                  self.col: batch[:, 1],
                                  self.val: batch[:, 2]}

                    # alternate between optimizing inputs and nnet vars
                    sess.run(train_lvars, feed_dict=batch_dict)
                    sess.run(train_nnet, feed_dict=batch_dict)

                    if iteration % 20 == 0:

                        # Every fragment is printed with end="" and the line is terminated
                        # exactly once below, so reports stay one per line whichever
                        # metrics are enabled.
                        print(iteration, end="")

                        if not no_train_metric:
                            train_loss_ = sess.run(self.loss, feed_dict=train_dict)
                            train_loss_summary_str = sess.run(train_loss_summary, feed_dict={train_loss: train_loss_})
                            writer.add_summary(train_loss_summary_str, iteration)
                            print("\ttrain loss: %.4f" % train_loss_, end="")

                        if holdout_ratio is not None:
                            test_sse_ = sess.run(self.sse, feed_dict=test_dict)
                            test_mse_ = test_sse_ / len(batch_generator.test)
                            test_mse_summary_str = sess.run(test_mse_summary, feed_dict={test_mse: test_mse_})
                            writer.add_summary(test_mse_summary_str, iteration)
                            print("\ttest mse: %.4f" % test_mse_, end="")

                        print()

                        for summary_ in sess.run(variable_summaries):
                            writer.add_summary(summary_, iteration)

                # save the model
                saver.save(sess, os.path.join(root_savedir, "model.ckpt"))
        finally:
            # close the file writer even if training fails, so buffered events are flushed
            writer.close()
예제 #2
0
def _write_result(output_dir, result):
    """Write ``result`` as pretty-printed JSON to ``<output_dir>/result.json``.

    Opening the file in 'w' mode truncates any previous version, so no separate
    delete step is needed.

    :param output_dir: Directory that receives ``result.json``.
    :param result: JSON-serializable dict describing the run.
    """
    result_path = os.path.join(output_dir, 'result.json')
    with open(result_path, 'w') as f:
        json.dump(result, f, indent=2, sort_keys=True)


def main():
    """Train a character-level RNN language model (``CharRNNLM``) end to end.

    Reads the configuration from ``config_train()``, splits ``args.data_file`` into
    contiguous train/valid/test text, builds or loads the vocabulary, trains for
    ``args.num_epochs`` epochs while keeping the checkpoint with the lowest
    validation perplexity (halving the learning rate whenever an epoch does not
    improve on it), and finally evaluates that best checkpoint on the test split.

    Progress is written to ``<output_dir>/result.json`` after every epoch and once
    more on exit via ``finally`` (also when an exception is raised). When
    ``args.init_dir`` is non-empty, training resumes from the ``result.json`` and
    ``vocab.json`` stored there, and all outputs are written back into that directory.
    """
    args = config_train()
    resuming = len(args.init_dir) != 0

    # When resuming, outputs go back into the run directory being resumed. This must
    # be decided BEFORE deriving the paths below; otherwise checkpoints and logs point
    # into args.output_dir while result.json is written to init_dir.
    if resuming:
        args.output_dir = args.init_dir

    # Specifying location to store model, best model and tensorboard log.
    args.save_model = os.path.join(args.output_dir, 'save_model/model')
    args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
    args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/')
    args.vocab_file = ''

    # Create necessary directories. A fresh run starts from an empty output_dir;
    # a resumed run keeps its contents and only creates what is missing.
    if not resuming and os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    for path in [args.save_model, args.save_best_model, args.tb_log_dir]:
        dir_name = os.path.dirname(path)
        if not os.path.isdir(dir_name):
            os.makedirs(dir_name)

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO, datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' % args.output_dir)
    print('=' * 60 + '\n')

    if args.debug:
        logging.info('args are:\n%s', args)

    if resuming:
        with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
            result = json.load(f)
        params = result['params']
        args.init_model = result['latest_model']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        # 'encoding' may be absent from result.json; default to UTF-8.
        args.encoding = result.get('encoding', 'utf-8')
        args.vocab_file = os.path.join(args.init_dir, 'vocab.json')
    else:
        params = {'batch_size': args.batch_size,
                  'num_unrollings': args.num_unrollings,
                  'hidden_size': args.hidden_size,
                  'max_grad_norm': args.max_grad_norm,
                  'embedding_size': args.embedding_size,
                  'num_layers': args.num_layers,
                  'learning_rate': args.learning_rate,
                  'model': args.model,
                  'dropout': args.dropout,
                  'input_dropout': args.input_dropout}
        best_model = ''
        # No validation score yet: the first evaluated epoch always becomes the best.
        best_valid_ppl = float('inf')
    logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4))

    # Read and split data.
    logging.info('Reading data from: %s', args.data_file)
    with codecs.open(args.data_file, 'r', encoding=args.encoding) as f:
        text = f.read()

    if args.test:
        # Truncate to the first 50000 characters.
        text = text[:50000]
    logging.info('Number of characters: %s', len(text))

    if args.debug:
        logging.info('First %d characters: %s', 10, text[:10])

    logging.info('Creating train, valid, test split')
    # Contiguous split [train | valid | test]; the test split takes whatever remains.
    train_size = int(args.train_frac * len(text))
    valid_size = int(args.valid_frac * len(text))
    train_text = text[:train_size]
    valid_text = text[train_size:train_size + valid_size]
    test_text = text[train_size + valid_size:]

    vocab_loader = VocabularyLoader()
    if len(args.vocab_file) != 0:
        vocab_loader.load_vocab(args.vocab_file, args.encoding)
    else:
        logging.info('Creating vocabulary')
        vocab_loader.create_vocab(text)
        vocab_file = os.path.join(args.output_dir, 'vocab.json')
        vocab_loader.save_vocab(vocab_file, args.encoding)
        logging.info('Vocabulary is saved in %s', vocab_file)
        args.vocab_file = vocab_file

    params['vocab_size'] = vocab_loader.vocab_size
    logging.info('Vocab size: %d', vocab_loader.vocab_size)

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']

    train_batches = BatchGenerator(vocab_loader.vocab_index_dict, train_text, batch_size, num_unrollings)
    valid_batches = BatchGenerator(vocab_loader.vocab_index_dict, valid_text, batch_size, num_unrollings)
    test_batches = BatchGenerator(vocab_loader.vocab_index_dict, test_text, batch_size, num_unrollings)

    if args.debug:
        logging.info('Test batch generators')
        x, y = train_batches.next_batch()
        logging.info((str(x[0]), str(batche2string(x[0], vocab_loader.index_vocab_dict))))
        logging.info((str(y[0]), str(batche2string(y[0], vocab_loader.index_vocab_dict))))

    # Create graphs. Variables created by the training model are reused by the
    # validation and evaluation models.
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope('training'):
            train_model = CharRNNLM(is_training=True, infer=False, **params)
        tf.get_variable_scope().reuse_variables()
        with tf.name_scope('validation'):
            valid_model = CharRNNLM(is_training=False, infer=False, **params)
        with tf.name_scope('evaluation'):
            test_model = CharRNNLM(is_training=False, infer=False, **params)
            saver = tf.train.Saver(name='model_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Start training\n')

    result = {'params': params,
              'vocab_file': args.vocab_file,
              'encoding': args.encoding}
    # Latest checkpoint path; stays at the restored model (or '') if no epoch runs.
    saved_path = args.init_model

    try:
        with tf.Session(graph=graph) as session:
            # Version 8 changed the api of summary writer to use
            # graph instead of graph_def.
            graph_info = session.graph if TF_VERSION >= 8 else session.graph_def

            train_writer = tf.train.SummaryWriter(args.tb_log_dir + 'train/', graph_info)
            valid_writer = tf.train.SummaryWriter(args.tb_log_dir + 'valid/', graph_info)
            try:
                # load a saved model or start from random initialization.
                if len(args.init_model) != 0:
                    saver.restore(session, args.init_model)
                else:
                    tf.initialize_all_variables().run()

                learning_rate = args.learning_rate
                for epoch in range(args.num_epochs):
                    logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
                    logging.info('Training on training set')
                    # training step
                    ppl, train_summary_str, global_step = train_model.run_epoch(
                        session, train_batches, is_training=True,
                        learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
                    # record the summary
                    train_writer.add_summary(train_summary_str, global_step)
                    train_writer.flush()
                    # save model
                    saved_path = saver.save(session, args.save_model,
                                            global_step=train_model.global_step)

                    logging.info('Latest model saved in %s\n', saved_path)
                    logging.info('Evaluate on validation set')

                    valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                        session, valid_batches, is_training=False,
                        learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)

                    # save and update best model; otherwise halve the learning rate
                    if not best_model or valid_ppl < best_valid_ppl:
                        best_model = best_model_saver.save(session, args.save_best_model,
                                                           global_step=train_model.global_step)
                        best_valid_ppl = valid_ppl
                    else:
                        learning_rate /= 2.0
                        logging.info('Decay the learning rate: %s', learning_rate)

                    valid_writer.add_summary(valid_summary_str, global_step)
                    valid_writer.flush()

                    logging.info('Best model is saved in %s', best_model)
                    logging.info('Best validation ppl is %f\n', best_valid_ppl)

                    result['latest_model'] = saved_path
                    result['best_model'] = best_model
                    # Convert to float because numpy.float is not json serializable.
                    result['best_valid_ppl'] = float(best_valid_ppl)
                    _write_result(args.output_dir, result)

                logging.info('Latest model is saved in %s', saved_path)
                logging.info('Best model is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)

                if best_model:
                    logging.info('Evaluate the best model on test set')
                    saver.restore(session, best_model)
                    test_ppl, _, _ = test_model.run_epoch(
                        session, test_batches, is_training=False,
                        learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
                    result['test_ppl'] = float(test_ppl)
                else:
                    # e.g. a fresh start with num_epochs == 0: nothing to restore.
                    logging.warning('No best model available; skipping test evaluation.')
            finally:
                train_writer.close()
                valid_writer.close()
    finally:
        _write_result(args.output_dir, result)
예제 #3
0
    def train(self, n_rows, n_cols, rows, cols, vals, n_factors, d_pairwise, hidden_layer_sizes,
              n_iterations, batch_size, holdout_ratio, learning_rate, n_samples,
              root_savedir, root_logdir,
              no_train_metric=False, seed=None):
        """
        Training routine.

        Builds the TF graph via ``self.construct_graph()`` and maximizes ``self.elbo``
        with Adam on minibatches. Every 20 iterations it prints one progress line
        (training ELBO / log-likelihood and/or held-out log-likelihood) and writes
        TensorBoard summaries of every trainable variable. The final model is saved
        to ``<root_savedir>/model.ckpt``.

        :param n_rows: Number of rows
        :param n_cols: Number of cols
        :param rows: Rows for "on" entries
        :param cols: Corresponding columns for "on" entries
        :param vals: Value of each entry
        :param n_factors: Number of non-bilinear terms
        :param d_pairwise: Number of bilinear terms
        :param hidden_layer_sizes: Hidden layer sizes of the nnet (stored on ``self``,
            presumably consumed by ``construct_graph``)
        :param n_iterations: Number of minibatch iterations to run
        :param batch_size: Minibatch size passed to ``BatchGenerator``
        :param holdout_ratio: Holdout ratio passed to ``BatchGenerator``; when None,
            no held-out log-likelihood is computed
        :param learning_rate: Adam learning rate
        :param n_samples: Value fed to ``self.n_samples`` for training updates
            (the full-set evaluations always feed 100)
        :param root_savedir: Directory for the checkpoint; created if missing
        :param root_logdir: TensorBoard log directory
        :param no_train_metric: If True, skip evaluating ELBO / log-likelihood on the full training set
        :param seed: Seed forwarded to ``BatchGenerator``
        :return: None
        """
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.n_factors = n_factors
        self.d_pairwise = d_pairwise
        self.hidden_layer_sizes = hidden_layer_sizes

        if not os.path.exists(root_savedir):
            os.makedirs(root_savedir)

        ###  Data handling  ###

        # One (row, col, val) triple per entry -> array of shape (n_obs, 3).
        pairs = np.vstack([rows, cols, vals]).T
        batch_generator = BatchGenerator(pairs, batch_size, holdout_ratio=holdout_ratio, seed=seed)

        ###  Construct the TF graph  ###

        self.construct_graph()

        train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(-self.elbo)

        ###  Training  ###

        if not no_train_metric:
            train_elbo = tf.placeholder(dtype=tf.float32, shape=[], name='train_elbo')
            train_elbo_summary = tf.summary.scalar('train_elbo', train_elbo)

            train_ll = tf.placeholder(dtype=tf.float32, shape=[], name='train_ll')
            train_ll_summary = tf.summary.scalar('train_ll', train_ll)

        if holdout_ratio is not None:
            test_ll = tf.placeholder(dtype=tf.float32, shape=[], name='test_ll')
            test_ll_summary = tf.summary.scalar('test_ll', test_ll)

        # create tensorboard summary objects: scalars for 0-d variables, histograms otherwise
        all_vars = tf.trainable_variables()
        scalar_summaries = [tf.summary.scalar(var_.name, var_) for var_ in all_vars if len(var_.shape) == 0]
        array_summaries = [tf.summary.histogram(var_.name, var_) for var_ in all_vars if len(var_.shape) > 0]
        variable_summaries = scalar_summaries + array_summaries

        writer = tf.summary.FileWriter(root_logdir)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        # A minibatch's likelihood term is scaled up to the size of the training set.
        n_train = len(batch_generator.train)

        try:
            with tf.Session() as sess:

                init.run()

                if not no_train_metric:
                    train_dict = {self.row: batch_generator.train[:, 0],
                                  self.col: batch_generator.train[:, 1],
                                  self.val: batch_generator.train[:, 2],
                                  self.n_samples: 100,
                                  self.batch_scale: 1.0}

                if holdout_ratio is not None:
                    test_dict = {self.row: batch_generator.test[:, 0],
                                 self.col: batch_generator.test[:, 1],
                                 self.val: batch_generator.test[:, 2],
                                 self.n_samples: 100,
                                 self.batch_scale: 1.0}

                for iteration in range(n_iterations):

                    batch = batch_generator.next_batch()
                    sess.run(train_op, feed_dict={self.row: batch[:, 0],
                                                  self.col: batch[:, 1],
                                                  self.val: batch[:, 2],
                                                  self.n_samples: n_samples,
                                                  self.batch_scale: n_train / len(batch)
                                                  })

                    if iteration % 20 == 0:

                        # Every fragment is printed with end="" and the line is terminated
                        # exactly once below, so reports stay one per line whichever
                        # metrics are enabled.
                        print(iteration, end="")

                        if not no_train_metric:
                            train_ll_, train_elbo_ = sess.run([self.data_loglikel, self.elbo], feed_dict=train_dict)
                            train_ll_summary_str, train_elbo_summary_str = sess.run(
                                [train_ll_summary, train_elbo_summary],
                                feed_dict={train_ll: train_ll_, train_elbo: train_elbo_})
                            writer.add_summary(train_ll_summary_str, iteration)
                            writer.add_summary(train_elbo_summary_str, iteration)
                            print("\tTrain ELBO: %.4f" % train_elbo_, end="")
                            print("\tTrain LL: %.4f" % train_ll_, end="")

                        if holdout_ratio is not None:
                            test_ll_ = sess.run(self.data_loglikel, feed_dict=test_dict)
                            test_ll_summary_str = sess.run(test_ll_summary, feed_dict={test_ll: test_ll_})
                            writer.add_summary(test_ll_summary_str, iteration)
                            print("\tTest LL: %.4f" % test_ll_, end="")

                        print()

                        for summary_ in sess.run(variable_summaries):
                            writer.add_summary(summary_, iteration)

                # save the model
                saver.save(sess, os.path.join(root_savedir, "model.ckpt"))
        finally:
            # close the file writer even if training fails, so buffered events are flushed
            writer.close()
    def train(self,
              N,
              row,
              col,
              T,
              n_features,
              n_pairwise_features,
              hidden_layer_sizes,
              n_iterations,
              batch_size,
              n_samples,
              holdout_ratio_valid,
              learning_rate,
              root_savedir,
              log_interval=10,
              no_train_metric=False,
              seed=None,
              debug=False):
        """
        Training routine.

        Note about the data: the (row, col) tuples of the ON (i.e., one-valued) entries of the graph are to be passed,
        and they should correspond to the upper triangle of the graph. (Recall we do not allow self-links.) Regardless,
        the code will make a symmetric graph out of all passed entries (within the upper triangular or not) and only the
        upper triangle of the resulting matrix will be kept.

        Each iteration takes one Adam step on ``-self.elbo`` for the current minibatch,
        then calls ``self.update_qZ`` for that batch. Every ``log_interval`` iterations,
        TensorBoard summaries of all trainable variables, the training ELBO and
        log-likelihood, and the validation log-likelihood are written and logged. The
        final model is saved to ``<root_savedir>/model.ckpt``; TensorBoard logs go to
        ``<root_savedir>/tf_logs``.

        :param N: Number of nodes in the graph.
        :param row: row indices corresponding to the ON entries (in the upper triangle).
        :param col: col indices corresponding to the ON entries (in the upper triangle).
        :param T: Truncation level for the DP.
        :param n_features: Stored on ``self`` (presumably consumed by ``construct_graph``).
        :param n_pairwise_features: Stored on ``self`` (presumably consumed by ``construct_graph``).
        :param hidden_layer_sizes: Stored on ``self`` (presumably consumed by ``construct_graph``).
        :param n_iterations: Number of training iterations (minibatches).
        :param batch_size: HALF the minibatch size. In particular, we will always add the symmetric entry in the graph
            (i.e., the corresponding entry in the lower triangle) in the minibatch.
        :param n_samples: Value fed to ``self.n_samples`` (and passed to ``update_qZ``) during training;
            the full-set evaluations feed 100.
        :param holdout_ratio_valid: Holdout ratio passed to ``BatchGenerator``; None disables the validation metric.
        :param learning_rate: Adam learning rate.
        :param root_savedir: Output directory for the checkpoint and TensorBoard logs; created if missing.
        :param log_interval: Write summaries and log metrics every ``log_interval`` iterations.
        :param no_train_metric: If True, skip evaluating ELBO / log-likelihood on the full training set.
        :param seed: Seed forwarded to ``BatchGenerator``.
        :param debug: Forwarded to ``update_qZ``.
        :return: None
        """
        self.N = N
        self.T = T
        self.n_features = n_features
        self.n_pairwise_features = n_pairwise_features
        self.hidden_layer_sizes = hidden_layer_sizes

        if not os.path.exists(root_savedir):
            os.makedirs(root_savedir)

        def _sum_qZ_above(qZ):
            """Return an array with T - 1 columns whose column k is sum_{j > k} qZ[:, j]."""
            out = np.zeros([qZ.shape[0], T - 1])
            for k in range(T - 1):
                out[:, k] = np.sum(qZ[:, k + 1:], axis=1)
            return out

        def _symmetric_feed(data):
            """Row/col/val feed entries where every (i, j, v) triple is also added as (j, i, v)."""
            return {
                self.row: np.concatenate([data[:, 0], data[:, 1]]),
                self.col: np.concatenate([data[:, 1], data[:, 0]]),
                self.val: np.concatenate([data[:, 2], data[:, 2]]),
            }

        # Data handling: symmetrize the passed entries, then keep only the strict upper triangle.
        X_sp = sp.csr_matrix((np.ones(len(row)), (row, col)), shape=[N, N])
        X_sp = X_sp + X_sp.transpose()
        X_sp = sp.triu(X_sp, k=1)
        row, col = X_sp.nonzero()

        pairs = get_pairs(N, row, col).astype(int)

        batch_generator = BatchGenerator(pairs,
                                         batch_size,
                                         holdout_ratio=holdout_ratio_valid,
                                         seed=seed)

        # Construct the TF graph.
        self.construct_graph()
        trainable_vars = tf.trainable_variables()
        print("\nTrainable variables:")
        pprint([var_.name for var_ in trainable_vars])

        train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(-self.elbo)

        ###  Create q(Z) variational parameters  ###

        # One row per node, each a Dirichlet(1, ..., 1) draw, i.e. a random point on the
        # simplex (an earlier version used the uniform 1/T initialization instead).
        # NOTE(review): drawn from NumPy's global RNG, so `seed` does not make this reproducible.
        self.qZ_ = np.random.dirichlet(np.ones(T), size=N)  # (N, T)

        # Fed to the TF graph; must be recomputed after every update to qZ.
        sum_qZ_above = _sum_qZ_above(self.qZ_)

        # Training.
        if not no_train_metric:
            train_elbo = tf.placeholder(dtype=tf.float32,
                                        shape=[],
                                        name='train_elbo')
            train_elbo_summary = tf.summary.scalar('train_elbo', train_elbo)

            train_ll = tf.placeholder(dtype=tf.float32,
                                      shape=[],
                                      name='train_ll')
            train_ll_summary = tf.summary.scalar('train_ll', train_ll)

        if holdout_ratio_valid is not None:
            test_ll = tf.placeholder(dtype=tf.float32,
                                     shape=[],
                                     name='test_ll')
            test_ll_summary = tf.summary.scalar('test_ll', test_ll)

        # Grab all trainable variables, to track in Tensorboard.
        scalar_summaries = [
            tf.summary.scalar(tensor_.name, tensor_)
            for tensor_ in trainable_vars if len(tensor_.shape) == 0
        ]
        tensor_summaries = [
            tf.summary.histogram(tensor_.name, tensor_)
            for tensor_ in trainable_vars if len(tensor_.shape) > 0
        ]

        root_logdir = os.path.join(root_savedir, "tf_logs")
        writer = tf.summary.FileWriter(root_logdir)

        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        # A minibatch's likelihood term is scaled up to the size of the TRAINING set;
        # held-out pairs are never seen by the optimizer, so they must not be counted.
        n_train = len(batch_generator.train)

        try:
            with tf.Session() as sess:
                init.run()

                if not no_train_metric:
                    train_dict = _symmetric_feed(batch_generator.train)
                    train_dict[self.batch_scale] = 1.0

                if holdout_ratio_valid is not None:
                    test_dict = _symmetric_feed(batch_generator.test)
                    test_dict[self.batch_scale] = 1.0

                logging.info("Starting training...")
                for iteration in range(n_iterations):

                    batch = batch_generator.next_batch()
                    batch_dict = _symmetric_feed(batch)
                    batch_dict.update({
                        self.qZ: self.qZ_,
                        self.n_samples: n_samples,
                        self.batch_scale: float(n_train) / len(batch),
                        self.sum_qZ_above: sum_qZ_above,
                    })

                    # make a gradient update
                    sess.run(train_op, feed_dict=batch_dict)

                    # analytic update of q(Z) for the current batch
                    self.update_qZ(sess=sess,
                                   batch=batch,
                                   n_samples=n_samples,
                                   debug=debug)

                    # Keep sum_qZ_above in sync with the new q(Z), so both the metrics
                    # logged below and the next gradient step see the updated values.
                    sum_qZ_above = _sum_qZ_above(self.qZ_)

                    if iteration % log_interval == 0:

                        # Add scalar variables to Tensorboard.
                        for summ_str in sess.run(scalar_summaries):
                            writer.add_summary(summ_str, iteration)
                        # Add tensor variables to Tensorboard.
                        for summ_str in sess.run(tensor_summaries):
                            writer.add_summary(summ_str, iteration)

                        if not no_train_metric:
                            train_dict.update({
                                self.qZ: self.qZ_,
                                self.sum_qZ_above: sum_qZ_above,
                                self.n_samples: 100
                            })
                            train_ll_, train_elbo_ = sess.run(
                                [self.data_loglikel, self.elbo],
                                feed_dict=train_dict)
                            train_ll_summary_str, train_elbo_summary_str = sess.run(
                                [train_ll_summary, train_elbo_summary],
                                feed_dict={
                                    train_ll: train_ll_,
                                    train_elbo: train_elbo_
                                })
                            writer.add_summary(train_ll_summary_str, iteration)
                            writer.add_summary(train_elbo_summary_str, iteration)

                        if holdout_ratio_valid is not None:
                            test_dict.update({
                                self.qZ: self.qZ_,
                                self.sum_qZ_above: sum_qZ_above,
                                self.n_samples: 100
                            })
                            test_ll_ = sess.run(self.data_loglikel,
                                                feed_dict=test_dict)
                            test_ll_summary_str = sess.run(
                                test_ll_summary, feed_dict={test_ll: test_ll_})
                            writer.add_summary(test_ll_summary_str, iteration)

                        # Log training overview.
                        log_str = "%-4d" % iteration
                        if not no_train_metric:
                            log_str += "  ELBO: %.4e  Train ll: %.4e" % (
                                train_elbo_, train_ll_)
                        if holdout_ratio_valid is not None:
                            log_str += "  Valid ll: %.4e" % test_ll_
                        logging.info(log_str)

                # save the model
                saver.save(sess, os.path.join(root_savedir, "model.ckpt"))
        finally:
            # close the file writer even if training fails, so buffered events are flushed
            writer.close()