    def save_time_samples(self, x, t, e, name, cens=False):
        observed = e == 1
        feed_dict = {
            self.x: x,
            self.impute_mask: get_missing_mask(x, self.imputation_values),
            self.t: t,
            self.t_lab: t[observed],
            self.e: e,
            self.risk_set: risk_set(t),
            self.batch_size_tensor: len(t),
            self.is_training: False
        }

        mean, log_var = self.session.run([self.t_mu, self.t_log_var],
                                         feed_dict=feed_dict)
        predicted_time = sample_log_normal(log_var=log_var,
                                           mean=mean,
                                           sample_size=self.sample_size)
        print("predicted_time_samples:{}".format(predicted_time.shape))
        np.save('matrix/{}_{}_samples_predicted_time'.format('Test', name),
                predicted_time)
        plot_predicted_distribution(predicted=predicted_time,
                                    empirical=t,
                                    data='Test_' + name,
                                    cens=cens)
        return
    def batch_feed_dict(self, e, i, j, t, x, outcomes):
        batch_x = x[i:j, :]
        batch_t = t[i:j]
        batch_risk = risk_set(batch_t)
        batch_impute_mask = get_missing_mask(batch_x, self.imputation_values)
        batch_e = e[i:j]
        idx_observed = batch_e == 1
        feed_dict = {
            self.x: batch_x,
            self.x_lab: batch_x[idx_observed],
            self.x_unlab: batch_x[np.logical_not(idx_observed)],
            self.impute_mask: batch_impute_mask,
            self.t: batch_t,
            self.t_lab: batch_t[idx_observed],
            self.t_unlab: batch_t[np.logical_not(idx_observed)],
            self.e: batch_e,
            self.risk_set: batch_risk,
            self.batch_size_tensor: len(batch_t),
            self.is_training: False
        }
        # TODO replace with abstract methods

        updated_feed_dict = self.outcomes_function(idx=i,
                                                   j=j,
                                                   feed_dict=feed_dict,
                                                   outcomes=outcomes)
        return updated_feed_dict
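The helpers risk_set and get_missing_mask are referenced throughout these examples but never shown. A minimal sketch of what they plausibly compute, assuming NumPy inputs, the standard at-risk semantics of the Cox-style ranking/partial-likelihood term, and that imputation_values holds one placeholder per feature column; this is not the author's exact implementation:

import numpy as np

def risk_set(data_t):
    # Hypothetical sketch: R[i, j] = 1 if subject j is still at risk at
    # subject i's event time (t_j >= t_i), the form needed by a
    # partial-likelihood ranking loss.
    data_t = np.asarray(data_t)
    return (data_t[np.newaxis, :] >= data_t[:, np.newaxis]).astype(np.float32)

def get_missing_mask(x, imputation_values):
    # Hypothetical sketch: 1.0 where a feature is observed, 0.0 where it
    # equals the per-column imputation placeholder. Assumes
    # imputation_values broadcasts against the columns of x.
    x = np.asarray(x)
    placeholder = np.asarray(imputation_values).reshape(1, -1)
    return (x != placeholder).astype(np.float32)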
Example 3
def get_dict(x, t, e):
    # Nested helper: self is captured from the enclosing method's scope.
    # This variant also feeds the noise prior placeholder (noise_alpha).
    observed_idx = e == 1
    feed_dict = {
        self.x: x,
        self.impute_mask: get_missing_mask(x, self.imputation_values),
        self.t: t,
        self.t_lab: t[observed_idx],
        self.e: e,
        self.batch_size_tensor: len(t),
        self.is_training: False,
        self.noise_alpha: np.ones(shape=self.noise_dim)
    }
    return {'feed_dict': feed_dict, 'outcomes': {}}
def get_dict(x, t, e):
    # Variant that additionally splits inputs into labeled (observed
    # events) and unlabeled (censored) subsets; self is again captured
    # from the enclosing method's scope.
    observed_idx = e == 1
    feed_dict = {
        self.x: x,
        self.x_lab: x[observed_idx],
        self.x_unlab: x[np.logical_not(observed_idx)],
        self.impute_mask: get_missing_mask(x, self.imputation_values),
        self.t: t,
        self.t_lab: t[observed_idx],
        self.t_unlab: t[np.logical_not(observed_idx)],
        self.e: e,
        self.batch_size_tensor: len(t),
        self.is_training: False
    }
    return {'feed_dict': feed_dict, 'outcomes': {}}
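save_time_samples in Example 2 also depends on sample_log_normal, which is not shown. Under the AFT assumption that log-time is Gaussian with the predicted location t_mu and log-variance t_log_var, a plausible NumPy sketch (not necessarily the author's implementation) is:

import numpy as np

def sample_log_normal(log_var, mean, sample_size):
    # Hypothetical sketch: draw sample_size log-normal variates per
    # subject from the predicted mean and log-variance of log-time.
    std = np.exp(0.5 * np.asarray(log_var))
    eps = np.random.standard_normal((sample_size,) + np.shape(mean))
    return np.exp(np.asarray(mean) + std * eps)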
Example 5
def generate_time_samples(self, e, x):
    # Event times are generated rather than observed here, so the
    # time-dependent placeholders (t, t_lab, risk_set) are deliberately
    # left out of the feed dict.
    feed_dict = {
        self.x: x,
        self.impute_mask: get_missing_mask(x, self.imputation_values),
        self.e: e,
        self.batch_size_tensor: len(x),
        self.is_training: False,
        self.noise_alpha: np.ones(shape=self.noise_dim)
    }
    # Draw sample_size Monte Carlo samples of the predicted time.
    predicted_time = []
    for _ in range(self.sample_size):
        gen_time = self.session.run(self.predicted_time,
                                    feed_dict=feed_dict)
        predicted_time.append(gen_time)
    return np.array(predicted_time)
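generate_time_samples returns an array of shape (sample_size, batch_size). To turn the Monte Carlo draws into one point prediction per subject (e.g. for a concordance index), a reasonable summary, not prescribed by the source, is the per-subject median:

import numpy as np

# predicted_time has shape (sample_size, batch_size); the median over the
# sample axis gives a robust point prediction per subject.
point_prediction = np.median(predicted_time, axis=0)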
    def train_neural_network(self):
        train_print = "Training Deep Regularized AFT Model:"
        params_print = "Parameters: l2_reg:{}, learning_rate:{}," \
                       " momentum: beta1={} beta2={}, batch_size:{}, batch_norm:{}," \
                       " hidden_dim:{}, latent_dim:{}, num_of_batches:{}, keep_prob:{}" \
            .format(self.l2_reg, self.learning_rate, self.beta1, self.beta2, self.batch_size,
                    self.batch_norm, self.hidden_dim, self.latent_dim, self.num_batches, self.keep_prob)
        print(train_print)
        print(params_print)
        logging.debug(train_print)
        logging.debug(params_print)
        self.session.run(tf.global_variables_initializer())

        best_ci = 0
        best_validation_epoch = 0
        last_improvement = 0

        start_time = time.time()
        epochs = 0
        show_all_variables()
        j = 0

        for i in range(self.num_iterations):
            # Batch Training
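            # A per-run timeout keeps the dequeue below from blocking
            # forever if the input queue has been closed or drained.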
            run_options = tf.RunOptions(timeout_in_ms=4000)
            x_batch, t_batch, e_batch = self.session.run(
                [self.x_batch, self.t_batch, self.e_batch],
                options=run_options)
            risk_batch = risk_set(data_t=t_batch)
            batch_impute_mask = get_missing_mask(x_batch,
                                                 self.imputation_values)
            batch_size = len(t_batch)
            idx_observed = e_batch == 1
            # TODO simplify batch processing
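            # Observed subjects (e == 1) act as "labeled" data with known
            # event times; censored subjects are routed to the "unlabeled"
            # placeholders.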
            feed_dict_train = {
                self.x: x_batch,
                self.x_lab: x_batch[idx_observed],
                self.x_unlab: x_batch[np.logical_not(idx_observed)],
                self.impute_mask: batch_impute_mask,
                self.t: t_batch,
                self.t_lab: t_batch[idx_observed],
                self.t_unlab: t_batch[np.logical_not(idx_observed)],
                self.e: e_batch,
                self.risk_set: risk_batch,
                self.batch_size_tensor: batch_size,
                self.is_training: True
            }
            (summary, train_time, train_cost, train_ranking, train_rae,
             train_reg, train_lik, train_recon, train_obs_lik,
             train_censo_lik, _) = self.session.run(
                [self.merged, self.predicted_time, self.cost,
                 self.ranking_partial_lik, self.total_rae, self.reg_loss,
                 self.neg_log_lik, self.total_t_recon_loss,
                 self.observed_neg_lik, self.censored_neg_lik,
                 self.optimizer],
                feed_dict=feed_dict_train)
            train_ci = concordance_index(
                event_times=t_batch,
                predicted_event_times=train_time.reshape(t_batch.shape),
                event_observed=e_batch)
            # tf.verify_tensor_all_finite only builds a graph op that is
            # never run here; check the fetched NumPy value directly.
            assert np.all(np.isfinite(train_cost)), \
                "Training Cost has NaN or Infinite values"
            if j >= self.num_examples:
                epochs += 1
                is_epoch = True
                # idx = 0
                j = 0
            else:
                # idx = j
                j += self.batch_size
                is_epoch = False

            if i % 100 == 0:
                train_print = "it:{}, trainCI:{}, train_ranking:{}, train_RAE:{},  train_lik:{}, train_obs_lik:{}, " \
                              "train_cens_lik:{}, train_reg:{}".format(i, train_ci, train_ranking, train_rae, train_lik,
                                                                       train_obs_lik, train_censo_lik, train_reg)
                print(train_print)
                logging.debug(train_print)

            if is_epoch or (i == (self.num_iterations - 1)):
                improved_str = ''
                # Record training metrics and compute the validation CI
                self.train_ci.append(train_ci)
                self.train_cost.append(train_cost)
                self.train_t_rae.append(train_rae)
                self.train_log_lik.append(train_lik)
                self.train_ranking.append(train_ranking)
                self.train_recon.append(train_recon)

                self.train_writer.add_summary(summary, i)
                valid_ci, valid_cost, valid_rae, valid_ranking, valid_lik, valid_reg, valid_log_var, valid_recon = self.predict_concordance_index(
                    x=self.valid_x, e=self.valid_e, t=self.valid_t)
                self.valid_cost.append(valid_cost)
                self.valid_ci.append(valid_ci)
                self.valid_t_rae.append(valid_rae)
                self.valid_log_lik.append(valid_lik)
                self.valid_ranking.append(valid_ranking)
                self.valid_recon.append(valid_recon)
                assert np.all(np.isfinite(valid_cost)), \
                    "Validation Cost has NaN or Infinite values"

                if valid_ci > best_ci:
                    self.saver.save(sess=self.session,
                                    save_path=self.save_path)
                    best_validation_epoch = epochs
                    best_ci = valid_ci
                    print("valid_ci:{}".format(valid_ci))
                    last_improvement = i
                    improved_str = '*'
                    # The saver call above persists the best-performing
                    # variables of the TensorFlow graph to file.
                # Report epoch-level training and validation metrics
                optimization_print = "Iteration: {} epochs:{}, Training: RAE:{}, Loss: {}," \
                                     " Ranking:{}, Reg:{}, Lik:{}, T_Recon:{}, CI:{}" \
                                     " Validation RAE:{} Loss:{}, Ranking:{}, Reg:{}, Lik:{}, T_Recon:{}, CI:{}, {}" \
                    .format(i + 1, epochs, train_rae, train_cost, train_ranking, train_reg, train_lik,
                            train_recon,
                            train_ci, valid_rae, valid_cost, valid_ranking, valid_reg, valid_lik, valid_recon,
                            valid_ci, improved_str)

                print(optimization_print)
                logging.debug(optimization_print)
                if i - last_improvement > self.require_improvement or math.isnan(
                        valid_cost) or epochs >= self.max_epochs:
                    print(
                        "No improvement found in a while, stopping optimization."
                    )
                    # Break out from the for-loop.
                    break
        # Ending time.

        end_time = time.time()
        time_dif = end_time - start_time
        time_dif_print = "Time usage: " + str(
            timedelta(seconds=int(round(time_dif))))
        print(time_dif_print)
        logging.debug(time_dif_print)
        # shutdown everything to avoid zombies
        self.session.run(self.queue.close(cancel_pending_enqueues=True))
        self.coord.request_stop()
        self.coord.join(self.threads)
        return best_validation_epoch, epochs
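train_neural_network scores predictions with concordance_index; the keyword arguments appear to match older releases of the lifelines package. For reference, a minimal O(n^2) sketch of the censoring-aware C-index it computes, counting only comparable pairs; this is a hypothetical reference implementation, not the library's own code:

import numpy as np

def concordance_index_sketch(event_times, predicted_event_times, event_observed):
    # A pair (i, j) is comparable when the earlier time is an observed
    # event; it is concordant when the model predicts a shorter time for
    # the subject that failed earlier. Prediction ties count 0.5.
    t = np.asarray(event_times, dtype=float)
    p = np.asarray(predicted_event_times, dtype=float)
    e = np.asarray(event_observed, dtype=bool)
    concordant, comparable = 0.0, 0
    for i in range(len(t)):
        if not e[i]:
            continue  # censored subjects cannot anchor a comparison
        for j in range(len(t)):
            if t[j] > t[i]:  # j outlived i's observed event
                comparable += 1
                if p[i] < p[j]:
                    concordant += 1.0
                elif p[i] == p[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else 0.0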