def main():
    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Canada/Central')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree,
                               args.sem_type, args.noise_scale,
                               args.dataset_type)
    _logger.info('Finished generating dataset')

    model = NoTears(args.n, args.d, args.seed, args.l1_lambda,
                    args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_max, args.h_factor,
                        args.rho_multiply, args.init_iter, args.learning_rate,
                        args.h_tol)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw estimated graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/X.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_estimated_graph.npy'.format(output_dir), W_est)

    # Plot raw estimated graph
    plot_estimated_graph(
        W_est, dataset.W,
        save_name='{}/raw_estimated_graph.png'.format(output_dir))

    _logger.info('Thresholding.')

    # Plot thresholded estimated graph
    W_est[np.abs(W_est) < args.graph_thres] = 0    # Thresholding
    plot_estimated_graph(
        W_est, dataset.W,
        save_name='{}/thresholded_estimated_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(
        args.graph_thres, results_thresholded))
def train_callback(self, epoch, loss, mse, h, W_true, W_est, graph_thres,
                   output_dir):
    # Evaluate the learned W in each iteration after thresholding
    W_thresholded = np.copy(W_est)
    W_thresholded[np.abs(W_thresholded) < graph_thres] = 0
    results_thresholded = count_accuracy(W_true, W_thresholded)

    self._logger.info(
        '[Iter {}] loss {:.3E}, mse {:.3E}, acyclic {:.3E}, shd {}, tpr {:.3f}, fdr {:.3f}, pred_size {}'
        .format(epoch, loss, mse, h, results_thresholded['shd'],
                results_thresholded['tpr'], results_thresholded['fdr'],
                results_thresholded['pred_size']))

    # Save the raw estimated graph in each iteration
    create_dir('{}/raw_estimated_graph'.format(output_dir))
    np.save(
        '{}/raw_estimated_graph/graph_iteration_{}.npy'.format(
            output_dir, epoch), W_est)
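# For reference, a minimal sketch of the kind of structural metrics that
# count_accuracy reports (shd, tpr, fdr, pred_size), assuming binary
# ground-truth and estimated adjacency matrices and ignoring finer points
# such as reversed-edge handling; the repo's actual implementation may differ.
def count_accuracy_sketch(W_true, W_est):
    import numpy as np
    B_true = (W_true != 0).astype(int)
    B_est = (W_est != 0).astype(int)
    tp = np.sum((B_est == 1) & (B_true == 1))  # correctly predicted edges
    fp = np.sum((B_est == 1) & (B_true == 0))  # extra edges
    fn = np.sum((B_est == 0) & (B_true == 1))  # missing edges
    pred_size = int(B_est.sum())
    return {
        'shd': int(fp + fn),               # structural Hamming distance
        'tpr': tp / max(B_true.sum(), 1),  # true positive rate
        'fdr': fp / max(pred_size, 1),     # false discovery rate
        'pred_size': pred_size,
    }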
def main():
    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = RealDataset(args.batch_size)
    _logger.info('Finished generating dataset')

    device = get_device()
    model = VAE(args.z_dim, args.num_hidden, args.input_dim, device)
    trainer = Trainer(args.batch_size, args.num_epochs, args.learning_rate)
    trainer.train_model(model=model, dataset=dataset, output_dir=output_dir,
                        device=device, input_dim=args.input_dim)
    _logger.info('Finished training model')

    # Visualizations
    samples = sample_vae(model, args.z_dim, device)
    plot_samples(samples)
    plot_reconstructions(model, dataset, device)

    _logger.info('All Finished!')
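# A minimal sketch of what sample_vae might look like: draw latent codes from
# the standard-normal prior and push them through the decoder. The decoder
# attribute name (model.decoder) and the default sample count are assumptions
# for illustration, not taken from the actual VAE class.
def sample_vae_sketch(model, z_dim, device, num_samples=16):
    import torch
    with torch.no_grad():
        z = torch.randn(num_samples, z_dim, device=device)  # z ~ N(0, I)
        samples = model.decoder(z)  # decode latent codes into data space
    return samples.cpu()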
def main():
    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree,
                               args.sem_type, args.noise_scale,
                               args.dataset_type, args.x_dim)
    _logger.info('Finished generating dataset')

    model = GAE(args.n, args.d, args.x_dim, args.seed,
                args.num_encoder_layers, args.num_decoder_layers,
                args.hidden_size, args.latent_dim, args.l1_graph_penalty,
                args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_thres, args.h_thres,
                        args.rho_multiply, args.init_iter, args.learning_rate,
                        args.h_tol, args.early_stopping,
                        args.early_stopping_thres)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw recovered graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/observational_data.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_recovered_graph.npy'.format(output_dir), W_est)

    # Plot raw recovered graph
    plot_recovered_graph(
        W_est, dataset.W,
        save_name='{}/raw_recovered_graph.png'.format(output_dir))

    _logger.info('Filter by constant threshold')
    W_est = W_est / np.max(np.abs(W_est))    # Normalize

    # Plot thresholded recovered graph
    W_est[np.abs(W_est) < args.graph_thres] = 0    # Thresholding
    plot_recovered_graph(
        W_est, dataset.W,
        save_name='{}/thresholded_recovered_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(
        args.graph_thres, results_thresholded))
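# The post-processing above (normalize by the largest absolute weight, then
# zero out small entries) as a standalone helper; a sketch mirroring the two
# lines above, convenient when re-thresholding a saved raw graph offline.
def postprocess_recovered_graph(W_est, graph_thres):
    import numpy as np
    W = W_est / np.max(np.abs(W_est))   # scale weights into [-1, 1]
    W[np.abs(W) < graph_thres] = 0      # prune weak edges
    return W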
def synthetic():
    np.set_printoptions(precision=3)

    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.num_X, args.num_Z, args.num_samples,
                               args.max_lag)

    # Save dataset
    dataset.save_dataset(output_dir=output_dir)
    _logger.info('Finished generating dataset')

    # Look at data
    _logger.info('The shape of observed data: {}'.format(dataset.X.shape))
    plot_timeseries(dataset.X[-150:], 'X', display_mode=False,
                    save_name=output_dir + '/timeseries_X.png')
    plot_timeseries(dataset.Z[-150:], 'Z', display_mode=False,
                    save_name=output_dir + '/timeseries_Z.png')

    # Init model
    model = TimeLatent(args.num_X, args.max_lag, args.num_samples,
                       args.device, args.prior_rho_A, args.prior_sigma_W,
                       args.temperature, args.sigma_Z, args.sigma_X)
    trainer = Trainer(args.learning_rate, args.num_iterations,
                      args.num_output)
    trainer.train_model(model=model,
                        X=torch.tensor(dataset.X, dtype=torch.float32,
                                       device=args.device),
                        output_dir=output_dir)
    plot_losses(trainer.train_losses, display_mode=False,
                save_name=output_dir + '/loss.png')

    # Save result
    trainer.log_and_save_intermediate_outputs()
    _logger.info('Finished training model')

    # Calculate performance
    # model.posterior_A.probs has shape (max_lag, num_X + num_Z, num_X + num_Z)
    estimate_A = model.posterior_A.probs[:, :args.num_X, :args.num_X].cpu(
    ).data.numpy()
    # dataset.groudtruth has shape (max_lag, num_X, num_X)
    groudtruth_A = np.array(dataset.groudtruth)
    Score = AUC_score(estimate_A.T, groudtruth_A.T)
    _logger.info('\n fpr:{} \n tpr:{}\n thresholds:{}\n AUC:{}'.format(
        Score['fpr'], Score['tpr'], Score['thresholds'], Score['AUC']))
    plot_ROC_curve(estimate_A.T, groudtruth_A.T, display_mode=False,
                   save_name=output_dir + '/ROC_Curve.png')
    for t in range(0, 11):
        _logger.info('Under threshold:{}'.format(t / 10))
        _logger.info(F1(estimate_A.T, groudtruth_A.T, threshold=t / 10))

    estimate_A_all = model.posterior_A.probs.cpu().data.numpy()

    # Visualizations
    # Note: in our implementation, A_ij = 1 means j -> i, while in
    # plot_recovered_graph A_ij = 1 means i -> j, so transpose A
    for k in range(args.max_lag):
        plot_recovered_graph(estimate_A[k].T, groudtruth_A[k].T,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir + '/A_lag_{}.png'.format(k))
        plot_recovered_graph(estimate_A_all[k].T, dataset.A[k].T,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir +
                             '/All_lag_{}.png'.format(k))

    _logger.info('All Finished!')
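# A minimal sketch of an AUC_score-style helper built on scikit-learn,
# assuming estimate_A holds edge probabilities and groundtruth_A holds binary
# edge labels; the repo's own AUC_score may differ in details.
def AUC_score_sketch(estimate_A, groundtruth_A):
    import numpy as np
    from sklearn.metrics import roc_curve, auc
    y_true = np.asarray(groundtruth_A).ravel()   # binary edge labels
    y_score = np.asarray(estimate_A).ravel()     # predicted edge probabilities
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    return {'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds,
            'AUC': auc(fpr, tpr)}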
def real():
    np.set_printoptions(precision=3)

    # Get arguments parsed
    args = get_args()

    # Setup for logging
    output_dir = 'output/real_{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger('real')

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = RealDataset()

    # Look at data
    _logger.info('The shape of observed data: {}'.format(dataset.stock.shape))
    plot_timeseries(dataset.stock[-150:], 'stock', display_mode=False,
                    save_name=output_dir + '/timeseries_stock.png')

    # Set parameters
    num_samples, num_X = dataset.stock.shape
    temperature = 2.0
    max_lag = 1
    prior_rho_A = 0.7
    prior_sigma_W = 0.05
    sigma_Z = 1.0
    sigma_X = 0.05
    num_iterations = 3000

    # Log the parameters
    _logger.info(
        "num_X:{},max_lag:{},num_samples:{},args.device:{},prior_rho_A:{},"
        "prior_sigma_W:{},temperature:{},sigma_Z:{},sigma_X:{},"
        "num_iterations:{}".format(num_X, max_lag, num_samples, args.device,
                                   prior_rho_A, prior_sigma_W, temperature,
                                   sigma_Z, sigma_X, num_iterations))

    # Init model
    model = TimeLatent(num_X=num_X, max_lag=max_lag, num_samples=num_samples,
                       device=args.device, prior_rho_A=prior_rho_A,
                       prior_sigma_W=prior_sigma_W, temperature=temperature,
                       sigma_Z=sigma_Z, sigma_X=sigma_X)
    trainer = Trainer(learning_rate=args.learning_rate,
                      num_iterations=num_iterations,
                      num_output=args.num_output)
    trainer.train_model(model=model,
                        X=torch.tensor(dataset.stock, dtype=torch.float32,
                                       device=args.device),
                        output_dir=output_dir)
    plot_losses(trainer.train_losses, display_mode=False,
                save_name=output_dir + '/loss.png')

    # Save result
    trainer.log_and_save_intermediate_outputs()
    _logger.info('Finished training model')

    estimate_A = model.posterior_A.probs.cpu().data.numpy()

    # Visualizations
    # Note: in our implementation, A_ij = 1 means j -> i, while in
    # plot_recovered_graph A_ij = 1 means i -> j, so transpose A
    for k in range(max_lag):
        plot_recovered_graph(estimate_A[k].T, W=None,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir + '/lag_{}.png'.format(k))

    _logger.info('All Finished!')
def save(self, model_dir):
    create_dir(model_dir)
    self.saver.save(self.sess, '{}/model'.format(model_dir))
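# A natural counterpart to save(): a sketch of restoring the checkpoint with
# the same tf.train.Saver, assuming the class exposes self.saver and
# self.sess as above. This restore() is illustrative, not part of the
# original class.
def restore(self, model_dir):
    self.saver.restore(self.sess, '{}/model'.format(model_dir))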
def main():
    # Setup for output directory and logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)
    _logger.info('Python version is {}'.format(platform.python_version()))
    _logger.info('Current commit of code: ___')

    # Get running configuration
    config, _ = get_config()
    config.save_model_path = '{}/model'.format(output_dir)
    # config.restore_model_path = '{}/model'.format(output_dir)
    config.summary_dir = '{}/summary'.format(output_dir)
    config.plot_dir = '{}/plot'.format(output_dir)
    config.graph_dir = '{}/graph'.format(output_dir)

    # Create directories
    create_dir(config.summary_dir)
    create_dir(config.plot_dir)
    create_dir(config.graph_dir)

    # Reproducibility
    set_seed(config.seed)

    # Log the configuration parameters
    _logger.info('Configuration parameters: {}'.format(
        vars(config)))  # Use vars to convert config to dict for logging

    if config.read_data:
        file_path = '{}/data.npy'.format(config.data_path)
        solution_path = '{}/DAG.npy'.format(config.data_path)
        training_set = DataGenerator_read_data(file_path, solution_path,
                                               config.normalize,
                                               config.transpose)
    else:
        raise ValueError('Only support importing data from existing files')

    # Set penalty weights
    score_type = config.score_type
    reg_type = config.reg_type

    if config.lambda_flag_default:
        sl, su, strue = BIC_lambdas(training_set.inputdata, None, None,
                                    training_set.true_graph.T, reg_type,
                                    score_type)
        lambda1 = 0
        lambda1_upper = 5
        lambda1_update_add = 1
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = 10
        lambda_iter_num = config.lambda_iter_num

        # Test initialized score
        _logger.info('Original sl: {}, su: {}, strue: {}'.format(
            sl, su, strue))
        _logger.info('Transformed sl: {}, su: {}, lambda2: {}, true: {}'.format(
            sl, su, lambda2, (strue - sl) / (su - sl) * lambda1_upper))
    else:
        # Test choices for the case with manually provided bounds
        # (not fully tested)
        sl = config.score_lower
        su = config.score_upper
        if config.score_bd_tight:
            lambda1 = 2
            lambda1_upper = 2
        else:
            lambda1 = 0
            lambda1_upper = 5
            lambda1_update_add = 1
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = config.lambda2_update
        lambda_iter_num = config.lambda_iter_num

    # Actor
    actor = Actor(config)
    callreward = get_Reward(actor.batch_size, config.max_length,
                            actor.input_dimension, training_set.inputdata,
                            sl, su, lambda1_upper, score_type, reg_type,
                            config.l1_graph_reg, False)
    _logger.info(
        'Finished creating training dataset, actor model and reward class')

    # Saver to save & restore all the variables
    variables_to_save = [
        v for v in tf.global_variables() if 'Adam' not in v.name
    ]
    saver = tf.train.Saver(var_list=variables_to_save,
                           keep_checkpoint_every_n_hours=1.0)

    _logger.info('Starting session...')
    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # Test tensor shape
        _logger.info('Shape of actor.input: {}'.format(
            sess.run(tf.shape(actor.input_))))
        # _logger.info('training_set.true_graph: {}'.format(training_set.true_graph))
        # _logger.info('training_set.b: {}'.format(training_set.b))

        # Initialize useful variables
        rewards_avg_baseline = []
        rewards_batches = []
        reward_max_per_batch = []
        lambda1s = []
        lambda2s = []
        graphss = []
        probsss = []
        max_rewards = []
        max_reward = float('-inf')
        image_count = 0
        accuracy_res = []
        accuracy_res_pruned = []
        max_reward_score_cyc = (lambda1_upper + 1, 0)

        # Summary writer
        writer = tf.summary.FileWriter(config.summary_dir, sess.graph)

        _logger.info('Starting training.')
        for i in range(1, config.nb_epoch + 1):
            if config.verbose:
                _logger.info('Start training for {}-th epoch'.format(i))

            input_batch = training_set.train_batch(actor.batch_size,
                                                   actor.max_length,
                                                   actor.input_dimension)
            graphs_feed = sess.run(actor.graphs,
                                   feed_dict={actor.input_: input_batch})
            # BTBT: compute the reward of each graph in the generated batch
            reward_feed = callreward.cal_rewards(graphs_feed, lambda1,
                                                 lambda2)

            # Max reward, max reward per batch
            # BTBT: after each batch, recompute the reward of the
            # max_reward_score_cyc found so far using the updated
            # lambda1/lambda2, and take it as the new max_reward
            max_reward = -callreward.update_scores([max_reward_score_cyc],
                                                   lambda1, lambda2)[0]
            max_reward_batch = float('inf')
            max_reward_batch_score_cyc = (0, 0)

            for reward_, score_, cyc_ in reward_feed:
                if reward_ < max_reward_batch:
                    max_reward_batch = reward_
                    max_reward_batch_score_cyc = (score_, cyc_)

            max_reward_batch = -max_reward_batch

            # BTBT: if any graph in this batch has a reward greater than
            # max_reward, that reward becomes the new max_reward
            if max_reward < max_reward_batch:
                max_reward = max_reward_batch
                max_reward_score_cyc = max_reward_batch_score_cyc

            # For average reward per batch
            reward_batch_score_cyc = np.mean(reward_feed[:, 1:], axis=0)

            if config.verbose:
                _logger.info(
                    'Finish calculating reward for current batch of graph')

            # Get feed dict
            feed = {
                actor.input_: input_batch,
                actor.reward_: -reward_feed[:, 0],
                actor.graphs_: graphs_feed
            }
            summary, base_op, score_test, probs, graph_batch, \
                reward_batch, reward_avg_baseline, train_step1, \
                train_step2 = sess.run([
                    actor.merged, actor.base_op, actor.test_scores,
                    actor.log_softmax, actor.graph_batch, actor.reward_batch,
                    actor.avg_baseline, actor.train_step1, actor.train_step2],
                    feed_dict=feed)

            if config.verbose:
                _logger.info(
                    'Finish updating actor and critic network using reward calculated'
                )

            lambda1s.append(lambda1)
            lambda2s.append(lambda2)
            rewards_avg_baseline.append(reward_avg_baseline)
            rewards_batches.append(reward_batch_score_cyc)
            reward_max_per_batch.append(max_reward_batch_score_cyc)
            graphss.append(graph_batch)
            probsss.append(probs)
            max_rewards.append(max_reward_score_cyc)

            # Logging
            if i == 1 or i % 500 == 0:
                if i >= 500:
                    writer.add_summary(summary, i)

                _logger.info(
                    '[iter {}] reward_batch: {}, max_reward: {}, max_reward_batch: {}'
                    .format(i, reward_batch, max_reward, max_reward_batch))
                # Other logger info; uncomment if you want to check
                # _logger.info('graph_batch_avg: {}'.format(graph_batch))
                # _logger.info('graph true: {}'.format(training_set.true_graph))
                # _logger.info('graph weights true: {}'.format(training_set.b))
                # _logger.info('=====================================')

                plt.figure(1)
                plt.plot(rewards_batches, label='reward per batch')
                plt.plot(max_rewards, label='max reward')
                plt.legend()
                plt.savefig('{}/reward_batch_average.png'.format(
                    config.plot_dir))
                plt.close()

                image_count += 1
                # This draws the average graph per batch.
                # It can be modified to draw the graph (with or w/o pruning)
                # that has the best reward.
                fig = plt.figure(2)
                fig.suptitle('Iteration: {}'.format(i))
                ax = fig.add_subplot(1, 2, 1)
                ax.set_title('recovered_graph')
                ax.imshow(np.around(graph_batch.T).astype(int),
                          cmap=plt.cm.gray)
                ax = fig.add_subplot(1, 2, 2)
                ax.set_title('ground truth')
                ax.imshow(training_set.true_graph, cmap=plt.cm.gray)
                plt.savefig('{}/recovered_graph_iteration_{}.png'.format(
                    config.plot_dir, image_count))
                plt.close()

            # Update lambda1, lambda2
            if (i + 1) % lambda_iter_num == 0:
                ls_kv = callreward.update_all_scores(lambda1, lambda2)
                # np.save('{}/solvd_dict_epoch_{}.npy'.format(config.graph_dir, i), np.array(ls_kv))
                max_rewards_re = callreward.update_scores(
                    max_rewards, lambda1, lambda2)
                rewards_batches_re = callreward.update_scores(
                    rewards_batches, lambda1, lambda2)
                reward_max_per_batch_re = callreward.update_scores(
                    reward_max_per_batch, lambda1, lambda2)

                # Save somewhat more detailed logging info
                np.save('{}/solvd_dict.npy'.format(config.graph_dir),
                        np.array(ls_kv))
                pd.DataFrame(np.array(max_rewards_re)).to_csv(
                    '{}/max_rewards.csv'.format(output_dir))
                pd.DataFrame(rewards_batches_re).to_csv(
                    '{}/rewards_batch.csv'.format(output_dir))
                pd.DataFrame(reward_max_per_batch_re).to_csv(
                    '{}/reward_max_batch.csv'.format(output_dir))
                pd.DataFrame(lambda1s).to_csv(
                    '{}/lambda1s.csv'.format(output_dir))
                pd.DataFrame(lambda2s).to_csv(
                    '{}/lambda2s.csv'.format(output_dir))

                graph_int, score_min, cyc_min = np.int32(
                    ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]

                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1 + lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2 * lambda2_update_mul, lambda2_upper)
                _logger.info(
                    '[iter {}] lambda1 {}, upper {}, lambda2 {}, upper {}, score_min {}, cyc_min {}'
                    .format(i + 1, lambda1, lambda1_upper, lambda2,
                            lambda2_upper, score_min, cyc_min))

                graph_batch = convert_graph_int_to_adj_mat(graph_int)
                # BTBT ??? How is graph pruning done?
                if reg_type == 'LR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef(graph_batch,
                                              training_set.inputdata))
                elif reg_type == 'QR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef_2nd(graph_batch,
                                                  training_set.inputdata))
                elif reg_type == 'GPR':
                    # The R code of CAM pruning operates on graphs where
                    # (i, j) = 1 indicates an edge from the i-th node to the
                    # j-th node, so we transpose the input graph and transpose
                    # the output graph back
                    graph_batch_pruned = np.transpose(
                        pruning_cam(training_set.inputdata,
                                    np.array(graph_batch).T))

                # Estimate accuracy
                acc_est = count_accuracy(training_set.true_graph,
                                         graph_batch.T)
                acc_est2 = count_accuracy(training_set.true_graph,
                                          graph_batch_pruned.T)

                fdr, tpr, fpr, shd, nnz = acc_est['fdr'], acc_est['tpr'], \
                    acc_est['fpr'], acc_est['shd'], acc_est['pred_size']
                fdr2, tpr2, fpr2, shd2, nnz2 = acc_est2['fdr'], \
                    acc_est2['tpr'], acc_est2['fpr'], acc_est2['shd'], \
                    acc_est2['pred_size']

                accuracy_res.append((fdr, tpr, fpr, shd, nnz))
                accuracy_res_pruned.append((fdr2, tpr2, fpr2, shd2, nnz2))

                np.save('{}/accuracy_res.npy'.format(output_dir),
                        np.array(accuracy_res))
                np.save('{}/accuracy_res2.npy'.format(output_dir),
                        np.array(accuracy_res_pruned))

                _logger.info(
                    'before pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr, tpr, fpr, shd, nnz))
                _logger.info(
                    'after pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr2, tpr2, fpr2, shd2, nnz2))

            # Save the variables to disk
            if i % max(1, int(config.nb_epoch / 5)) == 0 and i != 0:
                curr_model_path = saver.save(
                    sess,
                    '{}/tmp.ckpt'.format(config.save_model_path),
                    global_step=i)
                _logger.info('Model saved in file: {}'.format(
                    curr_model_path))

        _logger.info('Training COMPLETED !')
        saver.save(sess, '{}/actor.ckpt'.format(config.save_model_path))
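# The graphs found during training are keyed by integers; a sketch of how
# convert_graph_int_to_adj_mat could decode them, assuming each entry of
# graph_int encodes one row of the d x d adjacency matrix as the integer
# value of that row's binary representation. The actual encoding used by the
# repo may differ.
def convert_graph_int_to_adj_mat_sketch(graph_int):
    import numpy as np
    d = len(graph_int)
    return np.array([[int(bit) for bit in np.binary_repr(row_int, width=d)]
                     for row_int in graph_int], dtype=np.int32)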
def train(self, model, X, W, graph_thres, max_iter, iter_step, output_dir):
    """
    The model object should contain the following members:
    - sess
    - train_op
    - loss
    - mse_loss
    - h
    - W_prime
    - X
    - rho
    - alpha
    - lr
    """
    # Create directory to save the raw recovered graph in each iteration
    create_dir('{}/raw_recovered_graph'.format(output_dir))

    model.sess.run(tf.global_variables_initializer())
    rho, alpha, h, h_new = self.init_rho, 0.0, np.inf, np.inf
    prev_W_est, prev_mse = None, float('inf')

    self._logger.info('Started training for {} iterations'.format(max_iter))
    for i in range(1, max_iter + 1):
        while rho < self.rho_thres:
            self._logger.info('rho {:.3E}, alpha {:.3E}'.format(rho, alpha))
            loss_new, mse_new, h_new, W_new = self.train_step(
                model, iter_step, X, rho, alpha)
            if h_new > self.h_thres * h:
                rho *= self.rho_multiply
            else:
                break

        if self.early_stopping:
            if mse_new / prev_mse > self.early_stopping_thres \
                    and h_new <= 1e-7:
                # The MSE increased too much; revert to the previous graph
                # and stop early. This early stopping is only applied when
                # h_new is sufficiently small (i.e., at most 1e-7)
                return prev_W_est
            else:
                prev_W_est = W_new
                prev_mse = mse_new

        # Intermediate outputs
        self.log_and_save_intermediate_outputs(i, W, W_new, graph_thres,
                                               loss_new, mse_new, h_new,
                                               output_dir)

        W_est, h = W_new, h_new
        alpha += rho * h

        if h <= self.h_tol and i > self.init_iter:
            self._logger.info('Early stopping at {}-th iteration'.format(i))
            break

    # Save model
    model_dir = '{}/model/'.format(output_dir)
    model.save(model_dir)
    self._logger.info('Model saved to {}'.format(model_dir))

    return W_est
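# For reference, the acyclicity measure h that this trainer drives toward
# zero is, in NOTEARS-style formulations, h(W) = tr(e^(W o W)) - d, where o
# is the elementwise (Hadamard) product; h(W) = 0 exactly when W corresponds
# to a DAG. A sketch with scipy; the model computes its own h internally.
def notears_h_sketch(W):
    import numpy as np
    from scipy.linalg import expm
    d = W.shape[0]
    return np.trace(expm(W * W)) - d  # zero iff W has no directed cycles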