Example No. 1
def main():
    # Parse arguments
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(datetime.now(timezone('Canada/Central')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir), level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree, args.sem_type,
                               args.noise_scale, args.dataset_type)
    _logger.info('Finished generating dataset')

    model = NoTears(args.n, args.d, args.seed, args.l1_lambda, args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_max, args.h_factor, args.rho_multiply,
                        args.init_iter, args.learning_rate, args.h_tol)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw estimated graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/X.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_estimated_graph.npy'.format(output_dir), W_est)

    # Plot raw estimated graph
    plot_estimated_graph(W_est, dataset.W,
                         save_name='{}/raw_estimated_graph.png'.format(output_dir))

    _logger.info('Thresholding.')
    # Plot thresholded estimated graph
    W_est[np.abs(W_est) < args.graph_thres] = 0    # Thresholding
    plot_estimated_graph(W_est, dataset.W,
                         save_name='{}/thresholded_estimated_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(args.graph_thres, results_thresholded))
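As a self-contained illustration of the thresholding step above, here is a minimal sketch on a toy matrix (the values and threshold are hypothetical, chosen only for illustration):

import numpy as np

W_toy = np.array([[0.0, 0.8, -0.05],
                  [0.02, 0.0, 0.6],
                  [-0.3, 0.01, 0.0]])
graph_thres = 0.1
W_toy[np.abs(W_toy) < graph_thres] = 0  # zero out weak edges in place
print(W_toy)  # only entries with |w| >= 0.1 survive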
Example No. 2
    def train_callback(self, epoch, loss, mse, h, W_true, W_est, graph_thres,
                       output_dir):
        # Evaluate the learned W in each iteration after thresholding
        W_thresholded = np.copy(W_est)
        W_thresholded[np.abs(W_thresholded) < graph_thres] = 0
        results_thresholded = count_accuracy(W_true, W_thresholded)

        self._logger.info(
            '[Iter {}] loss {:.3E}, mse {:.3E}, acyclic {:.3E}, shd {}, tpr {:.3f}, fdr {:.3f}, pred_size {}'
            .format(epoch, loss, mse, h, results_thresholded['shd'],
                    results_thresholded['tpr'], results_thresholded['fdr'],
                    results_thresholded['pred_size']))

        # Save the raw estimated graph in each iteration
        create_dir('{}/raw_estimated_graph'.format(output_dir))
        np.save(
            '{}/raw_estimated_graph/graph_iteration_{}.npy'.format(
                output_dir, epoch), W_est)
Example No. 3
def main():
    # Parse arguments
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = RealDataset(args.batch_size)
    _logger.info('Finished loading dataset')

    device = get_device()
    model = VAE(args.z_dim, args.num_hidden, args.input_dim, device)

    trainer = Trainer(args.batch_size, args.num_epochs, args.learning_rate)

    trainer.train_model(model=model,
                        dataset=dataset,
                        output_dir=output_dir,
                        device=device,
                        input_dim=args.input_dim)

    _logger.info('Finished training model')

    # Visualizations
    samples = sample_vae(model, args.z_dim, device)
    plot_samples(samples)

    plot_reconstructions(model, dataset, device)

    _logger.info('All Finished!')
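For context, a minimal sketch of what a sample_vae helper might look like, drawing latents from the standard-normal prior and decoding them; the decoder attribute is an assumption made for illustration, not the repo's confirmed API:

import torch

def sample_vae_sketch(model, z_dim, device, num_samples=16):
    # Hypothetical: assumes the VAE exposes a `decoder` module
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, z_dim, device=device)
        return model.decoder(z)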
Example No. 4
def main():
    # Parse arguments
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.n, args.d, args.graph_type, args.degree,
                               args.sem_type, args.noise_scale,
                               args.dataset_type, args.x_dim)
    _logger.info('Finished generating dataset')

    model = GAE(args.n, args.d, args.x_dim, args.seed, args.num_encoder_layers,
                args.num_decoder_layers, args.hidden_size, args.latent_dim,
                args.l1_graph_penalty, args.use_float64)
    model.print_summary(print_func=model.logger.info)

    trainer = ALTrainer(args.init_rho, args.rho_thres, args.h_thres,
                        args.rho_multiply, args.init_iter, args.learning_rate,
                        args.h_tol, args.early_stopping,
                        args.early_stopping_thres)
    W_est = trainer.train(model, dataset.X, dataset.W, args.graph_thres,
                          args.max_iter, args.iter_step, output_dir)
    _logger.info('Finished training model')

    # Save raw recovered graph, ground truth and observational data after training
    np.save('{}/true_graph.npy'.format(output_dir), dataset.W)
    np.save('{}/observational_data.npy'.format(output_dir), dataset.X)
    np.save('{}/final_raw_recovered_graph.npy'.format(output_dir), W_est)

    # Plot raw recovered graph
    plot_recovered_graph(
        W_est,
        dataset.W,
        save_name='{}/raw_recovered_graph.png'.format(output_dir))

    _logger.info('Filter by constant threshold')
    W_est = W_est / np.max(np.abs(W_est))  # Normalize

    # Plot thresholded recovered graph
    W_est[np.abs(W_est) < args.graph_thres] = 0  # Thresholding
    plot_recovered_graph(
        W_est,
        dataset.W,
        save_name='{}/thresholded_recovered_graph.png'.format(output_dir))
    results_thresholded = count_accuracy(dataset.W, W_est)
    _logger.info('Results after thresholding by {}: {}'.format(
        args.graph_thres, results_thresholded))
Example No. 5
def synthetic():

    np.set_printoptions(precision=3)

    # Parse arguments
    args = get_args()

    # Setup for logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = SyntheticDataset(args.num_X, args.num_Z, args.num_samples,
                               args.max_lag)
    # Save dataset
    dataset.save_dataset(output_dir=output_dir)
    _logger.info('Finished generating dataset')

    # Look at data
    _logger.info('The shape of observed data: {}'.format(dataset.X.shape))
    plot_timeseries(dataset.X[-150:],
                    'X',
                    display_mode=False,
                    save_name=output_dir + '/timeseries_X.png')
    plot_timeseries(dataset.Z[-150:],
                    'Z',
                    display_mode=False,
                    save_name=output_dir + '/timeseries_Z.png')

    # Init model
    model = TimeLatent(args.num_X, args.max_lag, args.num_samples, args.device,
                       args.prior_rho_A, args.prior_sigma_W, args.temperature,
                       args.sigma_Z, args.sigma_X)

    trainer = Trainer(args.learning_rate, args.num_iterations, args.num_output)

    trainer.train_model(model=model,
                        X=torch.tensor(dataset.X,
                                       dtype=torch.float32,
                                       device=args.device),
                        output_dir=output_dir)

    plot_losses(trainer.train_losses,
                display_mode=False,
                save_name=output_dir + '/loss.png')

    # Save result
    trainer.log_and_save_intermediate_outputs()

    _logger.info('Finished training model')

    # Calculate performance

    # model.posterior_A.probs has shape (max_lag, num_X + num_Z, num_X + num_Z)
    estimate_A = model.posterior_A.probs[:, :args.num_X, :args.num_X].cpu().data.numpy()
    # dataset.groudtruth has shape (max_lag, num_X, num_X)
    groudtruth_A = np.array(dataset.groudtruth)

    Score = AUC_score(estimate_A.T, groudtruth_A.T)
    _logger.info(
        '\n        fpr:{} \n        tpr:{}\n thresholds:{}\n AUC:{}'.format(
            Score['fpr'], Score['tpr'], Score['thresholds'], Score['AUC']))

    plot_ROC_curve(estimate_A.T,
                   groudtruth_A.T,
                   display_mode=False,
                   save_name=output_dir + '/ROC_Curve.png')

    for t in range(0, 11):
        _logger.info('Under threshold:{}'.format(t / 10))
        _logger.info(F1(estimate_A.T, groudtruth_A.T, threshold=t / 10))

    estimate_A_all = model.posterior_A.probs.cpu().data.numpy()

    # Visualizations
    for k in range(args.max_lag):
        # Note: in our implementation, A_ij = 1 means j -> i, but plot_recovered_graph
        # expects A_ij = 1 to mean i -> j, so we transpose A
        plot_recovered_graph(estimate_A[k].T,
                             groudtruth_A[k].T,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir + '/A_lag_{}.png'.format(k))
        plot_recovered_graph(estimate_A_all[k].T,
                             dataset.A[k].T,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir +
                             '/All_lag_{}.png'.format(k))

    _logger.info('All Finished!')
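A tiny illustration of the transpose convention noted in the loop above (hypothetical values):

import numpy as np

A = np.zeros((2, 2))
A[1, 0] = 1  # internal convention: A[i, j] = 1 means j -> i, i.e. the edge 0 -> 1
print(A.T)   # plotting convention: entry (i, j) = 1 means i -> j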
Example No. 6
def real():

    np.set_printoptions(precision=3)

    # Parse arguments
    args = get_args()

    # Setup for logging
    output_dir = 'output/real_{}'.format(
        datetime.now(
            timezone('Asia/Shanghai')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger('real')

    # Save the configuration for logging purposes
    save_yaml_config(args, path='{}/config.yaml'.format(output_dir))

    # Reproducibility
    set_seed(args.seed)

    # Get dataset
    dataset = RealDataset()

    # Look at data
    _logger.info('The shape of observed data: {}'.format(dataset.stock.shape))
    plot_timeseries(dataset.stock[-150:],
                    'stock',
                    display_mode=False,
                    save_name=output_dir + '/timeseries_stock.png')

    # Set parameters
    num_samples, num_X = dataset.stock.shape
    temperature = 2.0
    max_lag = 1
    prior_rho_A = 0.7
    prior_sigma_W = 0.05
    sigma_Z = 1.0
    sigma_X = 0.05
    num_iterations = 3000

    # Log the parameters
    _logger.info(
        "num_X:{},max_lag:{},num_samples:{},args.device:{},prior_rho_A:{},prior_sigma_W:{},temperature:{},sigma_Z:{},sigma_X:{},num_iterations:{}"
        .format(num_X, max_lag, num_samples, args.device, prior_rho_A,
                prior_sigma_W, temperature, sigma_Z, sigma_X, num_iterations))
    # Init model
    model = TimeLatent(num_X=num_X,
                       max_lag=max_lag,
                       num_samples=num_samples,
                       device=args.device,
                       prior_rho_A=prior_rho_A,
                       prior_sigma_W=prior_sigma_W,
                       temperature=temperature,
                       sigma_Z=sigma_Z,
                       sigma_X=sigma_X)
    trainer = Trainer(learning_rate=args.learning_rate,
                      num_iterations=num_iterations,
                      num_output=args.num_output)

    trainer.train_model(model=model,
                        X=torch.tensor(dataset.stock,
                                       dtype=torch.float32,
                                       device=args.device),
                        output_dir=output_dir)

    plot_losses(trainer.train_losses,
                display_mode=False,
                save_name=output_dir + '/loss.png')

    # Save result
    trainer.log_and_save_intermediate_outputs()

    _logger.info('Finished training model')

    estimate_A = model.posterior_A.probs.cpu().data.numpy()

    # Visualizations
    for k in range(max_lag):
        # Note: in our implementation, A_ij = 1 means j -> i, but plot_recovered_graph
        # expects A_ij = 1 to mean i -> j, so we transpose A
        plot_recovered_graph(estimate_A[k].T,
                             W=None,
                             title='Lag = {}'.format(k + 1),
                             display_mode=False,
                             save_name=output_dir + '/lag_{}.png'.format(k))

    _logger.info('All Finished!')
Example No. 7
    def save(self, model_dir):
        create_dir(model_dir)
        self.saver.save(self.sess, '{}/model'.format(model_dir))
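A matching restore method would presumably mirror this; a sketch assuming the same saver and sess attributes (tf.train.Saver.restore takes the session and the checkpoint prefix that save() wrote):

    def load(self, model_dir):
        # Hypothetical counterpart to save(); restores from the '{}/model' prefix
        self.saver.restore(self.sess, '{}/model'.format(model_dir))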
Example No. 8
def main():
    # Setup for output directory and logging
    output_dir = 'output/{}'.format(
        datetime.now(
            timezone('Asia/Hong_Kong')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3])
    create_dir(output_dir)
    LogHelper.setup(log_path='{}/training.log'.format(output_dir),
                    level_str='INFO')
    _logger = logging.getLogger(__name__)
    _logger.info('Python version is {}'.format(platform.python_version()))
    _logger.info('Current commit of code: ___')

    # Get running configuration
    config, _ = get_config()
    config.save_model_path = '{}/model'.format(output_dir)
    # config.restore_model_path = '{}/model'.format(output_dir)
    config.summary_dir = '{}/summary'.format(output_dir)
    config.plot_dir = '{}/plot'.format(output_dir)
    config.graph_dir = '{}/graph'.format(output_dir)

    # Create directory
    create_dir(config.summary_dir)
    create_dir(config.plot_dir)
    create_dir(config.graph_dir)

    # Reproducibility
    set_seed(config.seed)

    # Log the configuration parameters
    _logger.info('Configuration parameters: {}'.format(
        vars(config)))  # Use vars to convert config to dict for logging

    if config.read_data:
        file_path = '{}/data.npy'.format(config.data_path)
        solution_path = '{}/DAG.npy'.format(config.data_path)
        training_set = DataGenerator_read_data(file_path, solution_path,
                                               config.normalize,
                                               config.transpose)
    else:
        raise ValueError("Only support importing data from existing files")

    # set penalty weights
    score_type = config.score_type
    reg_type = config.reg_type

    if config.lambda_flag_default:

        sl, su, strue = BIC_lambdas(training_set.inputdata, None, None,
                                    training_set.true_graph.T, reg_type,
                                    score_type)

        lambda1 = 0
        lambda1_upper = 5
        lambda1_update_add = 1
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = 10
        lambda_iter_num = config.lambda_iter_num

        # test initialized score
        _logger.info('Original sl: {}, su: {}, strue: {}'.format(
            sl, su, strue))
        _logger.info('Transformed sl: {}, su: {}, lambda2: {}, true: {}'.format(
            sl, su, lambda2, (strue - sl) / (su - sl) * lambda1_upper))

    else:
        # test choices for the case with manually provided bounds
        # not fully tested

        sl = config.score_lower
        su = config.score_upper
        if config.score_bd_tight:
            lambda1 = 2
            lambda1_upper = 2
        else:
            lambda1 = 0
            lambda1_upper = 5
            lambda1_update_add = 1
        lambda2 = 1 / (10**(np.round(config.max_length / 3)))
        lambda2_upper = 0.01
        lambda2_update_mul = config.lambda2_update
        lambda_iter_num = config.lambda_iter_num

    # actor
    actor = Actor(config)

    callreward = get_Reward(actor.batch_size, config.max_length,
                            actor.input_dimension, training_set.inputdata, sl,
                            su, lambda1_upper, score_type, reg_type,
                            config.l1_graph_reg, False)

    _logger.info(
        'Finished creating training dataset, actor model and reward class')

    # Saver to save & restore all the variables.
    variables_to_save = [
        v for v in tf.global_variables() if 'Adam' not in v.name
    ]
    saver = tf.train.Saver(var_list=variables_to_save,
                           keep_checkpoint_every_n_hours=1.0)

    _logger.info('Starting session...')
    sess_config = tf.ConfigProto(log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Run initialize op
        sess.run(tf.global_variables_initializer())

        # Test tensor shape
        _logger.info('Shape of actor.input: {}'.format(
            sess.run(tf.shape(actor.input_))))
        # _logger.info('training_set.true_graph: {}'.format(training_set.true_graph))
        # _logger.info('training_set.b: {}'.format(training_set.b))

        # Initialize useful variables
        rewards_avg_baseline = []
        rewards_batches = []
        reward_max_per_batch = []

        lambda1s = []
        lambda2s = []

        graphss = []
        probsss = []
        max_rewards = []
        max_reward = float('-inf')
        image_count = 0

        accuracy_res = []
        accuracy_res_pruned = []

        max_reward_score_cyc = (lambda1_upper + 1, 0)

        # Summary writer
        writer = tf.summary.FileWriter(config.summary_dir, sess.graph)

        _logger.info('Starting training.')

        for i in range(1, config.nb_epoch + 1):

            if config.verbose:
                _logger.info('Start training for {}-th epoch'.format(i))

            input_batch = training_set.train_batch(actor.batch_size,
                                                   actor.max_length,
                                                   actor.input_dimension)
            graphs_feed = sess.run(actor.graphs,
                                   feed_dict={actor.input_: input_batch})
            # BTBT: compute the reward of each graph in the generated batch of graphs
            reward_feed = callreward.cal_rewards(graphs_feed, lambda1, lambda2)

            # max reward, max reward per batch
            # BTBT: after each batch, recompute the reward of the best (score, cyc) pair
            # found so far using the updated lambda1/lambda2; this becomes the new max_reward
            max_reward = -callreward.update_scores(
                [max_reward_score_cyc], lambda1, lambda2)[0]
            max_reward_batch = float('inf')
            max_reward_batch_score_cyc = (0, 0)

            for reward_, score_, cyc_ in reward_feed:
                if reward_ < max_reward_batch:
                    max_reward_batch = reward_
                    max_reward_batch_score_cyc = (score_, cyc_)

            max_reward_batch = -max_reward_batch

            # BTBT: if the best graph in this batch beats the running max_reward, adopt it
            if max_reward < max_reward_batch:
                max_reward = max_reward_batch
                max_reward_score_cyc = max_reward_batch_score_cyc

            # for average reward per batch
            reward_batch_score_cyc = np.mean(reward_feed[:, 1:], axis=0)

            if config.verbose:
                _logger.info(
                    'Finish calculating reward for current batch of graph')

            # Get feed dict
            feed = {
                actor.input_: input_batch,
                actor.reward_: -reward_feed[:, 0],
                actor.graphs_: graphs_feed
            }

            summary, base_op, score_test, probs, graph_batch, \
                reward_batch, reward_avg_baseline, train_step1, train_step2 = sess.run([
                    actor.merged,
                    actor.base_op,
                    actor.test_scores,
                    actor.log_softmax,
                    actor.graph_batch,
                    actor.reward_batch,
                    actor.avg_baseline,
                    actor.train_step1,
                    actor.train_step2], feed_dict=feed)

            if config.verbose:
                _logger.info(
                    'Finish updating actor and critic network using reward calculated'
                )

            lambda1s.append(lambda1)
            lambda2s.append(lambda2)

            rewards_avg_baseline.append(reward_avg_baseline)
            rewards_batches.append(reward_batch_score_cyc)
            reward_max_per_batch.append(max_reward_batch_score_cyc)

            graphss.append(graph_batch)
            probsss.append(probs)
            max_rewards.append(max_reward_score_cyc)

            # logging
            if i == 1 or i % 500 == 0:
                if i >= 500:
                    writer.add_summary(summary, i)

                _logger.info(
                    '[iter {}] reward_batch: {}, max_reward: {}, max_reward_batch: {}'
                    .format(i, reward_batch, max_reward, max_reward_batch))
                # other logger info; uncomment if you want to check
                # _logger.info('graph_batch_avg: {}'.format(graph_batch))
                # _logger.info('graph true: {}'.format(training_set.true_graph))
                # _logger.info('graph weights true: {}'.format(training_set.b))
                # _logger.info('=====================================')

                plt.figure(1)
                plt.plot(rewards_batches, label='reward per batch')
                plt.plot(max_rewards, label='max reward')
                plt.legend()
                plt.savefig('{}/reward_batch_average.png'.format(
                    config.plot_dir))
                plt.close()

                image_count += 1
                # This draws the average graph per batch; it can be modified to draw
                # the graph (with or without pruning) that has the best reward
                fig = plt.figure(2)
                fig.suptitle('Iteration: {}'.format(i))
                ax = fig.add_subplot(1, 2, 1)
                ax.set_title('recovered_graph')
                ax.imshow(np.around(graph_batch.T).astype(int),
                          cmap=plt.cm.gray)
                ax = fig.add_subplot(1, 2, 2)
                ax.set_title('ground truth')
                ax.imshow(training_set.true_graph, cmap=plt.cm.gray)
                plt.savefig('{}/recovered_graph_iteration_{}.png'.format(
                    config.plot_dir, image_count))
                plt.close()

            # Update lambda1, lambda2
            if (i + 1) % lambda_iter_num == 0:
                ls_kv = callreward.update_all_scores(lambda1, lambda2)
                # np.save('{}/solvd_dict_epoch_{}.npy'.format(config.graph_dir, i), np.array(ls_kv))
                max_rewards_re = callreward.update_scores(
                    max_rewards, lambda1, lambda2)
                rewards_batches_re = callreward.update_scores(
                    rewards_batches, lambda1, lambda2)
                reward_max_per_batch_re = callreward.update_scores(
                    reward_max_per_batch, lambda1, lambda2)

                # Save somewhat more detailed logging info
                np.save('{}/solvd_dict.npy'.format(config.graph_dir),
                        np.array(ls_kv))
                pd.DataFrame(np.array(max_rewards_re)).to_csv(
                    '{}/max_rewards.csv'.format(output_dir))
                pd.DataFrame(rewards_batches_re).to_csv(
                    '{}/rewards_batch.csv'.format(output_dir))
                pd.DataFrame(reward_max_per_batch_re).to_csv(
                    '{}/reward_max_batch.csv'.format(output_dir))
                pd.DataFrame(lambda1s).to_csv(
                    '{}/lambda1s.csv'.format(output_dir))
                pd.DataFrame(lambda2s).to_csv(
                    '{}/lambda2s.csv'.format(output_dir))

                graph_int, score_min, cyc_min = np.int32(
                    ls_kv[0][0]), ls_kv[0][1][1], ls_kv[0][1][-1]

                if cyc_min < 1e-5:
                    lambda1_upper = score_min
                lambda1 = min(lambda1 + lambda1_update_add, lambda1_upper)
                lambda2 = min(lambda2 * lambda2_update_mul, lambda2_upper)
                _logger.info(
                    '[iter {}] lambda1 {}, upper {}, lambda2 {}, upper {}, score_min {}, cyc_min {}'
                    .format(i + 1, lambda1, lambda1_upper, lambda2,
                            lambda2_upper, score_min, cyc_min))

                graph_batch = convert_graph_int_to_adj_mat(graph_int)
                # BTBT ??? how is the graph pruning done?
                if reg_type == 'LR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef(graph_batch,
                                              training_set.inputdata))
                elif reg_type == 'QR':
                    graph_batch_pruned = np.array(
                        graph_prunned_by_coef_2nd(graph_batch,
                                                  training_set.inputdata))
                elif reg_type == 'GPR':
                    # The R code for CAM pruning operates on graphs where (i, j) = 1
                    # indicates an edge from the i-th node to the j-th node, so we
                    # transpose the input graph and transpose the output back
                    graph_batch_pruned = np.transpose(
                        pruning_cam(training_set.inputdata,
                                    np.array(graph_batch).T))

                # estimate accuracy
                acc_est = count_accuracy(training_set.true_graph,
                                         graph_batch.T)
                acc_est2 = count_accuracy(training_set.true_graph,
                                          graph_batch_pruned.T)

                fdr, tpr, fpr, shd, nnz = acc_est['fdr'], acc_est['tpr'], acc_est['fpr'], acc_est['shd'], \
                                          acc_est['pred_size']
                fdr2, tpr2, fpr2, shd2, nnz2 = acc_est2['fdr'], acc_est2['tpr'], acc_est2['fpr'], acc_est2['shd'], \
                                               acc_est2['pred_size']

                accuracy_res.append((fdr, tpr, fpr, shd, nnz))
                accuracy_res_pruned.append((fdr2, tpr2, fpr2, shd2, nnz2))

                np.save('{}/accuracy_res.npy'.format(output_dir),
                        np.array(accuracy_res))
                np.save('{}/accuracy_res2.npy'.format(output_dir),
                        np.array(accuracy_res_pruned))

                _logger.info(
                    'before pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr, tpr, fpr, shd, nnz))
                _logger.info(
                    'after  pruning: fdr {}, tpr {}, fpr {}, shd {}, nnz {}'.
                    format(fdr2, tpr2, fpr2, shd2, nnz2))

            # Save the variables to disk
            if i % max(1, int(config.nb_epoch / 5)) == 0 and i != 0:
                curr_model_path = saver.save(sess,
                                             '{}/tmp.ckpt'.format(
                                                 config.save_model_path),
                                             global_step=i)
                _logger.info('Model saved in file: {}'.format(curr_model_path))

        _logger.info('Training COMPLETED !')
        saver.save(sess, '{}/actor.ckpt'.format(config.save_model_path))
Example No. 9
    def train(self, model, X, W, graph_thres, max_iter, iter_step, output_dir):
        """
        model object should contain the several class member:
        - sess
        - train_op
        - loss
        - mse_loss
        - h
        - W_prime
        - X
        - rho
        - alpha
        - lr
        """
        # Create directory to save the raw recovered graph in each iteration
        create_dir('{}/raw_recovered_graph'.format(output_dir))

        model.sess.run(tf.global_variables_initializer())
        rho, alpha, h, h_new = self.init_rho, 0.0, np.inf, np.inf
        prev_W_est, prev_mse = None, float('inf')

        self._logger.info(
            'Started training for {} iterations'.format(max_iter))
        for i in range(1, max_iter + 1):
            while rho < self.rho_thres:
                self._logger.info('rho {:.3E}, alpha {:.3E}'.format(
                    rho, alpha))
                loss_new, mse_new, h_new, W_new = self.train_step(
                    model, iter_step, X, rho, alpha)
                if h_new > self.h_thres * h:
                    rho *= self.rho_multiply
                else:
                    break

            if self.early_stopping:
                if mse_new / prev_mse > self.early_stopping_thres and h_new <= 1e-7:
                    # MSE increased too much: revert to the previous graph and stop early.
                    # Only trigger this early stopping when h_new is sufficiently small
                    # (i.e., no larger than 1e-7)
                    return prev_W_est
                else:
                    prev_W_est = W_new
                    prev_mse = mse_new

            # Intermediate outputs
            self.log_and_save_intermediate_outputs(i, W, W_new, graph_thres,
                                                   loss_new, mse_new, h_new,
                                                   output_dir)

            W_est, h = W_new, h_new
            alpha += rho * h
            if h <= self.h_tol and i > self.init_iter:
                self._logger.info(
                    'Early stopping at {}-th iteration'.format(i))
                break

        # Save model
        model_dir = '{}/model/'.format(output_dir)
        model.save(model_dir)
        self._logger.info('Model saved to {}'.format(model_dir))

        return W_est
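To make the documented interface concrete, here is a minimal TF1-style stub that exposes every member the trainer touches. The acyclicity term follows the NOTEARS form h(W) = tr(exp(W * W)) - d, which matches this family of models, but the exact graph internals here are assumptions, not the repo's actual model:

import tensorflow as tf

class ModelStub:
    # Hypothetical sketch of the member layout the docstring above describes
    def __init__(self, d):
        self.X = tf.placeholder(tf.float32, shape=[None, d])
        self.rho = tf.placeholder(tf.float32)
        self.alpha = tf.placeholder(tf.float32)
        self.lr = tf.placeholder(tf.float32)
        self.W_prime = tf.Variable(tf.zeros([d, d]))
        self.mse_loss = tf.reduce_mean(
            tf.square(self.X - tf.matmul(self.X, self.W_prime)))
        # NOTEARS acyclicity constraint: h(W) = tr(exp(W * W)) - d
        self.h = tf.linalg.trace(tf.linalg.expm(self.W_prime * self.W_prime)) - d
        # Augmented Lagrangian objective driven by rho and alpha
        self.loss = self.mse_loss + self.alpha * self.h \
            + 0.5 * self.rho * tf.square(self.h)
        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
        self.sess = tf.Session()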