Esempio n. 1
0
def main():
    """Run one NoTears experiment on a synthetic dataset.

    Parses the arguments, prepares a timestamped output directory with
    logging, generates the dataset, trains the model through ``ALTrainer``,
    then saves and plots the estimated graph before and after thresholding.
    """
    args = get_args()

    # One directory per run, named after the current Canada/Central time;
    # dropping the last three digits of %f leaves millisecond precision.
    now = datetime.now(timezone('Canada/Central'))
    run_stamp = now.strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3]
    output_dir = 'output/{}'.format(run_stamp)
    create_dir(output_dir)

    log_path = '{}/training.log'.format(output_dir)
    LogHelper.setup(log_path=log_path, level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Keep a copy of the run configuration next to the logs
    config_path = '{}/config.yaml'.format(output_dir)
    save_yaml_config(args, path=config_path)

    # Seed before any data is generated so runs are repeatable
    set_seed(args.seed)

    dataset = SyntheticDataset(
        args.n, args.d, args.graph_type, args.degree, args.sem_type,
        args.noise_scale, args.dataset_type)
    _logger.info('Finished generating dataset')

    model = NoTears(args.n, args.d, args.seed, args.l1_lambda, args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(
        args.init_rho, args.rho_max, args.h_factor, args.rho_multiply,
        args.init_iter, args.learning_rate, args.h_tol)
    W_est = trainer.train(
        model, dataset.X, dataset.W, args.graph_thres,
        args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Persist ground truth, observational data and the raw estimate, in that order
    artifacts = (
        ('{}/true_graph.npy', dataset.W),
        ('{}/X.npy', dataset.X),
        ('{}/final_raw_estimated_graph.npy', W_est),
    )
    for path_template, array in artifacts:
        np.save(path_template.format(output_dir), array)

    # Picture of the estimate before any entries are removed
    raw_plot_path = '{}/raw_estimated_graph.png'.format(output_dir)
    plot_estimated_graph(W_est, dataset.W, save_name=raw_plot_path)

    _logger.info('Thresholding.')
    # Zero out (in place) every entry whose magnitude is below the threshold
    below_threshold = np.abs(W_est) < args.graph_thres
    W_est[below_threshold] = 0

    thresholded_plot_path = '{}/thresholded_estimated_graph.png'.format(output_dir)
    plot_estimated_graph(W_est, dataset.W, save_name=thresholded_plot_path)

    results_thresholded = count_accuracy(dataset.W, W_est)
    summary_line = 'Results after thresholding by {}: {}'.format(
        args.graph_thres, results_thresholded)
    _logger.info(summary_line)
Esempio n. 2
0
    def train_callback(self, epoch, loss, mse, h, W_true, W_est, graph_thres,
                       output_dir):
        """Log the metrics of one training iteration and save the raw graph.

        A thresholded copy of ``W_est`` (entries with magnitude below
        ``graph_thres`` set to zero) is scored against ``W_true`` with
        ``count_accuracy``; ``W_est`` itself is left untouched and written to
        ``<output_dir>/raw_estimated_graph/graph_iteration_<epoch>.npy``.
        """
        # Work on a copy so the caller's estimate is not modified
        pruned = np.array(W_est, copy=True)
        too_small = np.abs(pruned) < graph_thres
        pruned[too_small] = 0
        metrics = count_accuracy(W_true, pruned)

        message = (
            '[Iter {}] loss {:.3E}, mse {:.3E}, acyclic {:.3E}, shd {}, tpr {:.3f}, fdr {:.3f}, pred_size {}'
            .format(epoch, loss, mse, h, metrics['shd'], metrics['tpr'],
                    metrics['fdr'], metrics['pred_size']))
        self._logger.info(message)

        # The snapshot directory is (re)created on every call before saving
        create_dir('{}/raw_estimated_graph'.format(output_dir))
        snapshot_path = '{}/raw_estimated_graph/graph_iteration_{}.npy'.format(
            output_dir, epoch)
        np.save(snapshot_path, W_est)
Esempio n. 3
0
    def log_and_save_intermediate_outputs(self, i, W_true, W_est, graph_thres,
                                          loss, mse, h, output_dir):
        """Log the metrics of iteration ``i`` and save the raw recovered graph.

        A copy of ``W_est`` is scaled by its largest absolute entry, entries
        whose magnitude is below ``graph_thres`` are zeroed, and the result is
        scored against ``W_true`` with ``count_accuracy``. ``W_est`` itself is
        not modified and is written unchanged to
        ``<output_dir>/raw_recovered_graph/graph_iteration_<i>.npy``.

        Args:
            i: Iteration number used in the log line and the file name.
            W_true: Ground-truth graph passed to ``count_accuracy``.
            W_est: Current raw estimate of the graph.
            graph_thres: Threshold applied to the normalized estimate.
            loss, mse, h: Scalars reported in the log line.
            output_dir: Directory that contains ``raw_recovered_graph``.
        """
        # Evaluate the learned W in each iteration after thresholding
        W_thresholded = np.copy(W_est)
        max_abs = np.max(np.abs(W_thresholded))
        # An all-zero estimate would turn into NaNs when divided by its
        # (zero) maximum, and NaNs survive the threshold below; skip the
        # scaling in that case so the evaluated graph stays all zeros.
        if max_abs > 0:
            W_thresholded = W_thresholded / max_abs
        W_thresholded[np.abs(W_thresholded) < graph_thres] = 0
        results = count_accuracy(W_true, W_thresholded)

        # Logging
        self._logger.info(
            '[Iter {}] loss {:.3E}, mse {:.3E}, acyclic {:.3E}, shd {}, tpr {:.3f}, fdr {:.3f}, pred_size {}'
            .format(i, loss, mse, h, results['shd'], results['tpr'],
                    results['fdr'], results['pred_size']))

        # Save the raw recovered graph in each iteration
        # NOTE(review): the raw_recovered_graph directory is not created
        # here; presumably the caller creates it beforehand - confirm.
        np.save(
            '{}/raw_recovered_graph/graph_iteration_{}.npy'.format(
                output_dir, i), W_est)
Esempio n. 4
0
def main():
    """Run one GAE experiment on a synthetic dataset.

    Parses the arguments, prepares a timestamped output directory with
    logging, generates the dataset, trains the ``GAE`` model through
    ``ALTrainer``, then saves and plots the recovered graph both raw and
    after normalization plus constant thresholding.
    """
    # Get arguments parsed
    args = get_args()

    # Setup for logging; the directory name is the current Asia/Hong_Kong
    # time with %f trimmed from microseconds to milliseconds.
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purpose
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree,
                               args.sem_type, args.noise_scale,
                               args.dataset_type, args.x_dim)
    _logger.info('Finished generating dataset')

    model = GAE(args.n, args.d, args.x_dim, args.seed, args.num_encoder_layers,
                args.num_decoder_layers, args.hidden_size, args.latent_dim,
                args.l1_graph_penalty, args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_thres, args.h_thres,
                        args.rho_multiply, args.init_iter, args.learning_rate,
                        args.h_tol, args.early_stopping,
                        args.early_stopping_thres)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw recovered graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/observational_data.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_recovered_graph.npy'.format(output_dir), W_est)

    # Plot raw recovered graph
    plot_recovered_graph(
        W_est,
        dataset.W,
        save_name='{}/raw_recovered_graph.png'.format(output_dir))

    _logger.info('Filter by constant threshold')
    # Normalize by the largest absolute entry. An all-zero estimate has a
    # zero maximum and dividing by it would fill the matrix with NaNs that
    # survive thresholding, so in that case only a copy is taken. Either way
    # W_est is rebound to a new array, so the in-place thresholding below
    # never touches the array returned by the trainer.
    max_abs = np.max(np.abs(W_est))
    if max_abs > 0:
        W_est = W_est / max_abs
    else:
        W_est = np.copy(W_est)

    # Plot thresholded recovered graph
    W_est[np.abs(W_est) < args.graph_thres] = 0  # Thresholding
    plot_recovered_graph(
        W_est,
        dataset.W,
        save_name='{}/thresholded_recovered_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(
        args.graph_thres, results_thresholded))
Esempio n. 5
0
def main():
    """Train the actor on data read from disk and log/save the results.

    Creates a timestamped output directory, configures logging, loads the
    running configuration, builds the dataset, actor and reward objects and
    runs ``config.nb_epoch`` training epochs inside a TensorFlow session.
    Every ``config.lambda_iter_num`` epochs the penalty weights
    ``lambda1``/``lambda2`` are updated, the best graph found so far is
    pruned and scored, and intermediate results are written to the output
    directory. Checkpoints are written during and after training.

    Raises:
        ValueError: If ``config.read_data`` is false, or if pruning is
            reached with a ``config.reg_type`` other than 'LR', 'QR', 'GPR'.
    """
    # Setup for output directory and logging; the directory name is the
    # current Asia/Hong_Kong time with %f trimmed to milliseconds.
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)
    _logger.info('Python version is {}'.format(platform.python_version()))
    _logger.info('Current commit of code: ___')

    # Get running configuration and point its paths into the run directory
    config, _ = get_config()
    config.save_model_path = '{}/model'.format(output_dir)
    # config.restore_model_path = '{}/model'.format(output_dir)
    config.summary_dir = '{}/summary'.format(output_dir)
    config.plot_dir = '{}/plot'.format(output_dir)
    config.graph_dir = '{}/graph'.format(output_dir)

    # Create every directory that is written to below, including the
    # checkpoint directory used by the saver.
    create_dir(config.save_model_path)
    create_dir(config.summary_dir)
    create_dir(config.plot_dir)
    create_dir(config.graph_dir)

    # Reproducibility
    set_seed(config.seed)

    # Log the configuration parameters
    _logger.info('Configuration parameters: {}'.format(
        vars(config)))  # Use vars to convert config to dict for logging

    if config.read_data:
        file_path = '{}/data.npy'.format(config.data_path)
        solution_path = '{}/DAG.npy'.format(config.data_path)
        training_set = DataGenerator_read_data(file_path, solution_path,
                                               config.normalize,
                                               config.transpose)
    else:
        raise ValueError("Only support importing data from existing files")

    # set penalty weights
    score_type = config.score_type
    reg_type = config.reg_type

    if config.lambda_flag_default:

        sl, su, strue = BIC_lambdas(training_set.inputdata, None, None,
                                    training_set.true_graph.T, reg_type,
                                    score_type)

        lambda1 = 0
        lambda1_upper = 5
        lambda1_update_add = 1
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = 10
        lambda_iter_num = config.lambda_iter_num

        # test initialized score
        _logger.info('Original sl: {}, su: {}, strue: {}'.format(
            sl, su, strue))
        _logger.info('Transfomed sl: {}, su: {}, lambda2: {}, true: {}'.format(
            sl, su, lambda2, (strue - sl) / (su - sl) * lambda1_upper))

    else:
        # test choices for the case with manually provided bounds
        # not fully tested

        sl = config.score_lower
        su = config.score_upper
        # The periodic lambda update in the training loop always reads
        # lambda1_update_add, so it must be defined for the tight-bound case
        # as well; otherwise that update fails with a NameError.
        lambda1_update_add = 1
        if config.score_bd_tight:
            lambda1 = 2
            lambda1_upper = 2
        else:
            lambda1 = 0
            lambda1_upper = 5
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = config.lambda2_update
        lambda_iter_num = config.lambda_iter_num

    # actor
    actor = Actor(config)

    callreward = get_Reward(actor.batch_size, config.max_length,
                            actor.input_dimension, training_set.inputdata, sl,
                            su, lambda1_upper, score_type, reg_type,
                            config.l1_graph_reg, False)

    _logger.info(
        'Finished creating training dataset, actor model and reward class')

    # Saver to save & restore all the variables except the Adam slots.
    variables_to_save = [
        v for v in tf.global_variables() if 'Adam' not in v.name
    ]
    saver = tf.train.Saver(var_list=variables_to_save,
                           keep_checkpoint_every_n_hours=1.0)

    _logger.info('Starting session...')
    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # Test tensor shape
        _logger.info('Shape of actor.input: {}'.format(
            sess.run(tf.shape(actor.input_))))
        # _logger.info('training_set.true_graph: {}'.format(training_set.true_graph))
        # _logger.info('training_set.b: {}'.format(training_set.b))

        # Per-epoch histories that are written out or re-scored later.
        # Only histories that are actually read are kept, so memory does not
        # grow with data nobody uses.
        rewards_batches = []
        reward_max_per_batch = []

        lambda1s = []
        lambda2s = []

        max_rewards = []
        max_reward = float('-inf')
        image_count = 0

        accuracy_res = []
        accuracy_res_pruned = []

        # (score, cyc) pair of the best graph seen so far
        max_reward_score_cyc = (lambda1_upper + 1, 0)

        # Summary writer
        writer = tf.summary.FileWriter(config.summary_dir, sess.graph)

        _logger.info('Starting training.')

        for i in (range(1, config.nb_epoch + 1)):

            if config.verbose:
                _logger.info('Start training for {}-th epoch'.format(i))

            input_batch = training_set.train_batch(actor.batch_size,
                                                   actor.max_length,
                                                   actor.input_dimension)
            graphs_feed = sess.run(actor.graphs,
                                   feed_dict={actor.input_: input_batch})
            # One (reward, score, cyc) row per graph of the generated batch
            reward_feed = callreward.cal_rewards(graphs_feed, lambda1,
                                                 lambda2)

            # max reward, max reward per batch
            # Re-score the best (score, cyc) pair found so far with the
            # current lambda1/lambda2 and use that as the running max_reward.
            max_reward = -callreward.update_scores(
                [max_reward_score_cyc], lambda1, lambda2)[0]
            max_reward_batch = float('inf')
            max_reward_batch_score_cyc = (0, 0)

            # The values from cal_rewards are negated before being used as
            # rewards, so the smallest entry of the batch is the best one.
            for reward_, score_, cyc_ in reward_feed:
                if reward_ < max_reward_batch:
                    max_reward_batch = reward_
                    max_reward_batch_score_cyc = (score_, cyc_)

            max_reward_batch = -max_reward_batch

            # If the best graph of this batch beats the running maximum,
            # it becomes the new maximum.
            if max_reward < max_reward_batch:
                max_reward = max_reward_batch
                max_reward_score_cyc = max_reward_batch_score_cyc

            # for average reward per batch
            reward_batch_score_cyc = np.mean(reward_feed[:, 1:], axis=0)

            if config.verbose:
                _logger.info(
                    'Finish calculating reward for current batch of graph')

            # Get feed dict
            feed = {
                actor.input_: input_batch,
                actor.reward_: -reward_feed[:, 0],
                actor.graphs_: graphs_feed
            }

            # The two train steps are fetched so the networks get updated;
            # several of the other fetched values are not used afterwards.
            summary, base_op, score_test, probs, graph_batch, \
                reward_batch, reward_avg_baseline, train_step1, train_step2 = sess.run([
                    actor.merged,
                    actor.base_op,
                    actor.test_scores,
                    actor.log_softmax,
                    actor.graph_batch,
                    actor.reward_batch,
                    actor.avg_baseline,
                    actor.train_step1,
                    actor.train_step2], feed_dict=feed)

            if config.verbose:
                _logger.info(
                    'Finish updating actor and critic network using reward calculated'
                )

            lambda1s.append(lambda1)
            lambda2s.append(lambda2)

            rewards_batches.append(reward_batch_score_cyc)
            reward_max_per_batch.append(max_reward_batch_score_cyc)

            max_rewards.append(max_reward_score_cyc)

            # logging
            if i == 1 or i % 500 == 0:
                if i >= 500:
                    writer.add_summary(summary, i)

                _logger.info(
                    '[iter {}] reward_batch: {}, max_reward: {}, max_reward_batch: {}'
                    .format(i, reward_batch, max_reward, max_reward_batch))
                # other logger info; uncomment if you want to check
                # _logger.info('graph_batch_avg: {}'.format(graph_batch))
                # _logger.info('graph true: {}'.format(training_set.true_graph))
                # _logger.info('graph weights true: {}'.format(training_set.b))
                # _logger.info('=====================================')

                plt.figure(1)
                plt.plot(rewards_batches, label='reward per batch')
                plt.plot(max_rewards, label='max reward')
                plt.legend()
                plt.savefig('{}/reward_batch_average.png'.format(
                    config.plot_dir))
                plt.close()

                image_count += 1
                # this draw the average graph per batch.
                # can be modified to draw the graph (with or w/o pruning) that has the best reward
                fig = plt.figure(2)
                fig.suptitle('Iteration: {}'.format(i))
                ax = fig.add_subplot(1, 2, 1)
                ax.set_title('recovered_graph')
                ax.imshow(np.around(graph_batch.T).astype(int),
                          cmap=plt.cm.gray)
                ax = fig.add_subplot(1, 2, 2)
                ax.set_title('ground truth')
                ax.imshow(training_set.true_graph, cmap=plt.cm.gray)
                plt.savefig('{}/recovered_graph_iteration_{}.png'.format(
                    config.plot_dir, image_count))
                plt.close()

            # update lambda1, lamda2
            if (i + 1) % lambda_iter_num == 0:
                ls_kv = callreward.update_all_scores(lambda1, lambda2)
                # np.save('{}/solvd_dict_epoch_{}.npy'.format(config.graph_dir, i), np.array(ls_kv))
                max_rewards_re = callreward.update_scores(
                    max_rewards, lambda1, lambda2)
                rewards_batches_re = callreward.update_scores(
                    rewards_batches, lambda1, lambda2)
                reward_max_per_batch_re = callreward.update_scores(
                    reward_max_per_batch, lambda1, lambda2)

                # saved somewhat more detailed logging info
                np.save('{}/solvd_dict.npy'.format(config.graph_dir),
                        np.array(ls_kv))
                pd.DataFrame(np.array(max_rewards_re)).to_csv(
                    '{}/max_rewards.csv'.format(output_dir))
                pd.DataFrame(rewards_batches_re).to_csv(
                    '{}/rewards_batch.csv'.format(output_dir))
                pd.DataFrame(reward_max_per_batch_re).to_csv(
                    '{}/reward_max_batch.csv'.format(output_dir))
                pd.DataFrame(lambda1s).to_csv(
                    '{}/lambda1s.csv'.format(output_dir))
                pd.DataFrame(lambda2s).to_csv(
                    '{}/lambda2s.csv'.format(output_dir))

                # First entry of ls_kv: graph encoding plus its score/cyc values
                graph_int, score_min, cyc_min = np.int32(
                    ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]

                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1 + lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2 * lambda2_update_mul, lambda2_upper)
                _logger.info(
                    '[iter {}] lambda1 {}, upper {}, lambda2 {}, upper {}, score_min {}, cyc_min {}'
                    .format(i + 1, lambda1, lambda1_upper, lambda2,
                            lambda2_upper, score_min, cyc_min))

                graph_batch = convert_graph_int_to_adj_mat(graph_int)
                # The pruning routine is selected by the regression type
                if reg_type == 'LR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef(graph_batch,
                                              training_set.inputdata))
                elif reg_type == 'QR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef_2nd(graph_batch,
                                                  training_set.inputdata))
                elif reg_type == 'GPR':
                    # The R codes of CAM pruning operates the graph form that (i,j)=1 indicates i-th node-> j-th node
                    # so we need to do a tranpose on the input graph and another tranpose on the output graph
                    graph_batch_pruned = np.transpose(
                        pruning_cam(training_set.inputdata,
                                    np.array(graph_batch).T))
                else:
                    # Without this branch an unknown reg_type would surface
                    # as an unbound graph_batch_pruned a few lines below.
                    raise ValueError(
                        'Unsupported reg_type for pruning: {}'.format(
                            reg_type))

                # estimate accuracy
                acc_est = count_accuracy(training_set.true_graph,
                                         graph_batch.T)
                acc_est2 = count_accuracy(training_set.true_graph,
                                          graph_batch_pruned.T)

                fdr, tpr, fpr, shd, nnz = acc_est['fdr'], acc_est['tpr'], acc_est['fpr'], acc_est['shd'], \
                                          acc_est['pred_size']
                fdr2, tpr2, fpr2, shd2, nnz2 = acc_est2['fdr'], acc_est2['tpr'], acc_est2['fpr'], acc_est2['shd'], \
                                               acc_est2['pred_size']

                accuracy_res.append((fdr, tpr, fpr, shd, nnz))
                accuracy_res_pruned.append((fdr2, tpr2, fpr2, shd2, nnz2))

                np.save('{}/accuracy_res.npy'.format(output_dir),
                        np.array(accuracy_res))
                np.save('{}/accuracy_res2.npy'.format(output_dir),
                        np.array(accuracy_res_pruned))

                _logger.info(
                    'before pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr, tpr, fpr, shd, nnz))
                _logger.info(
                    'after  pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr2, tpr2, fpr2, shd2, nnz2))

            # Save the variables to disk roughly five times per run
            # (i starts at 1, so no extra guard against epoch 0 is needed)
            if i % max(1, int(config.nb_epoch / 5)) == 0:
                curr_model_path = saver.save(sess,
                                             '{}/tmp.ckpt'.format(
                                                 config.save_model_path),
                                             global_step=i)
                _logger.info('Model saved in file: {}'.format(curr_model_path))

        _logger.info('Training COMPLETED !')
        saver.save(sess, '{}/actor.ckpt'.format(config.save_model_path))
        # Flush pending summaries and release the event file
        writer.close()
Esempio n. 6
0
def rlbic(d, npx, npy, lambda_iter_num=1000, nb_epoch=20000):
    """Train the actor on the given data and return a transposed graph.

    Resets the default TensorFlow graph, overwrites ``max_length``,
    ``lambda_iter_num`` and ``nb_epoch`` on the module-level ``config``,
    then runs ``nb_epoch`` training epochs with the 'BIC' score and 'LR'
    regression type. Every ``lambda_iter_num`` epochs the penalty weights
    are updated and the accuracy of the best graph so far is printed.

    Args:
        d: Value stored in ``config.max_length``.
        npx: First argument passed to ``DataGenerator``
            (presumably the data matrix - TODO confirm).
        npy: Second argument passed to ``DataGenerator``
            (presumably the true graph - TODO confirm).
        lambda_iter_num: Period, in epochs, of the lambda update.
        nb_epoch: Number of training epochs; must be at least 1.

    Returns:
        ``graph_batch.T`` as it stands after the last epoch.
        NOTE(review): that is the ``actor.graph_batch`` value fetched in the
        last epoch, unless the last epoch also triggered a lambda update, in
        which case it is the matrix from ``convert_graph_int_to_adj_mat``.
        With the default arguments the former applies. Confirm which one
        callers expect.

    Raises:
        ValueError: If ``nb_epoch`` is smaller than 1, because no graph
            would exist to return.
    """
    if nb_epoch < 1:
        raise ValueError(
            'nb_epoch must be at least 1, got {}'.format(nb_epoch))

    # reset graph
    tf.reset_default_graph()
    # FIXME set the mx_length at the beginning. This might be used even in RLBIC
    #
    # but I'm not sure if the RLBIC code is already initialized with above settings
    config.max_length = d
    config.lambda_iter_num = lambda_iter_num
    config.nb_epoch = nb_epoch

    training_set = DataGenerator(npx, npy, False, True)
    score_type = 'BIC'
    # Regression type; it also selects the pruning routine further down
    reg_type = 'LR'
    sl, su, strue = BIC_lambdas(training_set.inputdata, None, None,
                                training_set.true_graph.T, reg_type,
                                score_type)

    lambda1 = 0
    lambda1_upper = 5
    lambda1_update_add = 1
    lambda2 = 1 / (10**(np.round(config.max_length / 3)))
    lambda2_upper = 0.01
    lambda2_update_mul = 10
    lambda_iter_num = config.lambda_iter_num

    actor = Actor(config)
    callreward = get_Reward(actor.batch_size, config.max_length,
                            actor.input_dimension, training_set.inputdata, sl,
                            su, lambda1_upper, score_type, reg_type,
                            config.l1_graph_reg, False)

    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # Only the histories that are passed to update_scores are kept;
        # per-epoch values that nothing reads are no longer accumulated.
        rewards_batches = []
        reward_max_per_batch = []
        max_rewards = []
        max_reward = float('-inf')

        # (score, cyc) pair of the best graph seen so far
        max_reward_score_cyc = (lambda1_upper + 1, 0)

        print('Starting training.')
        for i in tqdm(range(1, config.nb_epoch + 1)):
            # print('Start training for {}-th epoch'.format(i))
            input_batch = training_set.train_batch(actor.batch_size,
                                                   actor.max_length,
                                                   actor.input_dimension)
            graphs_feed = sess.run(actor.graphs,
                                   feed_dict={actor.input_: input_batch})
            # One (reward, score, cyc) row per graph of the generated batch
            reward_feed = callreward.cal_rewards(graphs_feed, lambda1, lambda2)

            # max reward, max reward per batch
            # Re-score the best pair so far with the current lambdas
            max_reward = -callreward.update_scores([max_reward_score_cyc],
                                                   lambda1, lambda2)[0]
            max_reward_batch = float('inf')
            max_reward_batch_score_cyc = (0, 0)

            # Values are negated before use as rewards, so the smallest
            # entry of the batch is the best one.
            for reward_, score_, cyc_ in reward_feed:
                if reward_ < max_reward_batch:
                    max_reward_batch = reward_
                    max_reward_batch_score_cyc = (score_, cyc_)

            max_reward_batch = -max_reward_batch

            if max_reward < max_reward_batch:
                max_reward = max_reward_batch
                max_reward_score_cyc = max_reward_batch_score_cyc

            # for average reward per batch
            reward_batch_score_cyc = np.mean(reward_feed[:, 1:], axis=0)
            # print('Finish calculating reward for current batch of graph')

            # Get feed dict
            feed = {
                actor.input_: input_batch,
                actor.reward_: -reward_feed[:, 0],
                actor.graphs_: graphs_feed
            }

            # The two train steps are fetched so the networks get updated;
            # several of the other fetched values are not used afterwards.
            summary, base_op, score_test, probs, graph_batch, \
                reward_batch, reward_avg_baseline, train_step1, train_step2 = sess.run([actor.merged, actor.base_op,
                actor.test_scores, actor.log_softmax, actor.graph_batch, actor.reward_batch, actor.avg_baseline, actor.train_step1,
                actor.train_step2], feed_dict=feed)

            # print('Finish updating actor and critic network using reward calculated')
            rewards_batches.append(reward_batch_score_cyc)
            reward_max_per_batch.append(max_reward_batch_score_cyc)
            max_rewards.append(max_reward_score_cyc)

            if i == 1 or i % 500 == 0:
                print(
                    '[iter {}] reward_batch: {}, max_reward: {}, max_reward_batch: {}'
                    .format(i, reward_batch, max_reward, max_reward_batch))
            # update lambda1, lamda2
            if (i + 1) % lambda_iter_num == 0:
                ls_kv = callreward.update_all_scores(lambda1, lambda2)
                # np.save('{}/solvd_dict_epoch_{}.npy'.format(config.graph_dir, i), np.array(ls_kv))
                # The re-scored histories are not used here; the calls are
                # kept in case update_scores has side effects - TODO confirm.
                callreward.update_scores(max_rewards, lambda1, lambda2)
                callreward.update_scores(rewards_batches, lambda1, lambda2)
                callreward.update_scores(reward_max_per_batch, lambda1,
                                         lambda2)

                # First entry of ls_kv: graph encoding plus its score/cyc values
                graph_int, score_min, cyc_min = np.int32(
                    ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]

                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1 + lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2 * lambda2_update_mul, lambda2_upper)

                graph_batch = convert_graph_int_to_adj_mat(graph_int)

                if reg_type == 'LR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef(graph_batch,
                                              training_set.inputdata))
                elif reg_type == 'QR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef_2nd(graph_batch,
                                                  training_set.inputdata))
                elif reg_type == 'GPR':
                    # The R codes of CAM pruning operates the graph form that (i,j)=1 indicates i-th node-> j-th node
                    # so we need to do a tranpose on the input graph and another tranpose on the output graph
                    graph_batch_pruned = np.transpose(
                        pruning_cam(training_set.inputdata,
                                    np.array(graph_batch).T))

                # estimate accuracy
                # FIXME graph_batch or graph_batch_pruned
                # FIXME .T should be adj mat?
                acc_est = count_accuracy(training_set.true_graph,
                                         graph_batch.T)
                # NOTE(review): the pruned graph is scored but its metrics
                # are neither printed nor returned.
                count_accuracy(training_set.true_graph, graph_batch_pruned.T)

                # TODO return the adj matrix
                fdr, tpr, fpr, shd = acc_est['fdr'], acc_est['tpr'], \
                    acc_est['fpr'], acc_est['shd']

                print('fdr:', fdr, 'tpr:', tpr, 'fpr:', fpr, 'shd:', shd)
        return graph_batch.T