Example #1
def get_train_loader(batch_size,
                     mu,
                     n_iters_per_epoch,
                     L,
                     root='dataset',
                     seed=0):
    data_x, label_x, data_u, label_u = load_data_train(L=L,
                                                       dspth=root,
                                                       seed=seed)

    ds_x = Cifar10(data=data_x, labels=label_x, is_train=True)
    sampler_x = RandomSampler(ds_x,
                              replacement=True,
                              num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(ds_x,
                                       batch_sampler=batch_sampler_x,
                                       num_workers=1,
                                       pin_memory=True)
    ds_u = Cifar10(data=data_u, labels=label_u, is_train=True)
    sampler_u = RandomSampler(ds_u,
                              replacement=True,
                              num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(ds_u,
                                       batch_sampler=batch_sampler_u,
                                       num_workers=2,
                                       pin_memory=True)
    return dl_x, dl_u
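
The pattern worth noting above is RandomSampler(..., replacement=True, num_samples=...) wrapped in a BatchSampler: it fixes the number of batches per epoch regardless of dataset size, so the small labeled set and the much larger unlabeled set can be iterated in lockstep. A minimal self-contained sketch of the same idea, using a toy TensorDataset in place of the Cifar10 dataset above:

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, BatchSampler

n_iters_per_epoch, batch_size, mu = 4, 8, 2

# Toy stand-ins for the labeled / unlabeled splits.
ds_x = TensorDataset(torch.randn(40, 3), torch.randint(0, 10, (40,)))
ds_u = TensorDataset(torch.randn(400, 3), torch.randint(0, 10, (400,)))

sampler_x = RandomSampler(ds_x, replacement=True,
                          num_samples=n_iters_per_epoch * batch_size)
sampler_u = RandomSampler(ds_u, replacement=True,
                          num_samples=mu * n_iters_per_epoch * batch_size)

dl_x = DataLoader(ds_x, batch_sampler=BatchSampler(sampler_x, batch_size, drop_last=True))
dl_u = DataLoader(ds_u, batch_sampler=BatchSampler(sampler_u, batch_size * mu, drop_last=True))

# Both loaders yield exactly n_iters_per_epoch batches, so they can be zipped.
for (x, y_x), (u, _) in zip(dl_x, dl_u):
    print(x.shape, u.shape)  # torch.Size([8, 3]) torch.Size([16, 3])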
Example #2
File: cifar.py  Project: MLDL/FROST
def get_train_loader(batch_size, mu, mu_c, n_iters_per_epoch, L, root='dataset', seed=0, name=None):
    if name is None:
        name = "dataset/seeds/size" + str(L) + "seed" + str(seed) + ".npy"
    data_x, label_x, data_u, label_u, data_all, label_all = load_data_train(L=L, dspth=root, seed=seed, name=name)
    
    ds_x = Cifar10(
        data=data_x,
        labels=label_x,
        is_train=True
    )
    sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(
        ds_x,
        batch_sampler=batch_sampler_x,
        num_workers=1,
        pin_memory=True
    )
    
    ds_u = Cifar10(
        data=data_u,
        labels=label_u,
        is_train=True
    )
    sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(
        ds_u,
        batch_sampler=batch_sampler_u,
        num_workers=2,
        pin_memory=True
    )
    
    ds_all = Cifar10(
        data=data_all,
        labels=label_all,
        is_train=True
    )
    #sampler_all = RandomSampler(ds_all, replacement=True, num_samples= mu_c * n_iters_per_epoch * batch_size)
    sampler_all = SequentialSampler(ds_all)
    batch_sampler_all = BatchSampler(sampler_all, batch_size * mu_c, drop_last=True)
    dl_all = torch.utils.data.DataLoader(
        ds_all,
        batch_sampler=batch_sampler_all,
        num_workers=2,
        pin_memory=True
    )
    return dl_x, dl_u, dl_all
Example #3
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=default_collate,
                 pin_memory=False,
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
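
The constructor above also documents the contract: batch_sampler is mutually exclusive with batch_size, shuffle, sampler and drop_last, because when it is omitted the loader builds exactly that combination internally. A small self-contained sketch against the standard torch.utils.data API showing that the two spellings are equivalent and that mixing them is rejected:

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, BatchSampler

ds = TensorDataset(torch.arange(10).float().unsqueeze(1))

# Path 1: let DataLoader build RandomSampler + BatchSampler internally.
dl_auto = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)

# Path 2: build the equivalent batch_sampler explicitly.
batch_sampler = BatchSampler(RandomSampler(ds), batch_size=4, drop_last=True)
dl_manual = DataLoader(ds, batch_sampler=batch_sampler)

assert len(dl_auto) == len(dl_manual) == 2  # 10 // 4 full batches each

# Mixing both styles raises, as enforced by the checks in the constructor above.
try:
    DataLoader(ds, batch_sampler=batch_sampler, batch_size=4)
except ValueError as exc:
    print(exc)  # batch_sampler is mutually exclusive with batch_size, shuffle, ...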
Example #4
def model_forward(model, data, sampler=None, is_train=True):
    if sampler is None:
        sampler = RandomSampler(data, FLAGS.batch_size, FLAGS.sample_induced)
    if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
        if "model_init" in FLAGS.init_embds and is_train:
            _get_initial_embd(data, model)
            data.dataset.init_interaction_graph_embds(device=FLAGS.device)
            model.init_x = data.dataset.interaction_combo_nxgraph.init_x.cpu().detach().numpy()

        batch_gids, sampled_gids, subgraph = sampler.sample_next_training_batch()
        batch_data = BatchData(
            batch_gids,
            data.dataset,
            is_train=is_train,
            sampled_gids=sampled_gids,
            enforce_negative_sampling=FLAGS.enforce_negative_sampling,
            unique_graphs=FLAGS.batch_unique_graphs,
            subgraph=subgraph)

        if FLAGS.pair_interaction:
            model.use_layers = 'lower_layers'
            model(batch_data)
        model.use_layers = 'higher_layers'
    else:
        batch_gids, sampled_gids, subgraph = sampler.sample_next_training_batch()
        batch_data = BatchData(
            batch_gids,
            data.dataset,
            is_train=is_train,
            sampled_gids=sampled_gids,
            enforce_negative_sampling=FLAGS.enforce_negative_sampling,
            unique_graphs=FLAGS.batch_unique_graphs,
            subgraph=subgraph)
    return batch_data
Example #5
from sampler import UncertaintySampler, RandomSampler


def label(sampler):
    cont = "T"
    while cont == "T":
        # grab k uncertain samples
        k_indices = sampler.sample(20)

        # get the labelling process going
        sampler.process_k_edus(k_indices)

        # write current state to a file
        sampler.save()

        cont = input("continue? T/F ")

    print("======> END <======")


if __name__ == '__main__':
    sampler = RandomSampler()
    label(sampler)
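
The label() loop above only relies on three sampler methods: sample(k) to pick indices, process_k_edus(indices) to collect annotations, and save() to persist state. A hypothetical minimal sampler satisfying that interface (not the project's actual RandomSampler, just a sketch of the contract):

import json
import random


class MinimalRandomSampler:
    """Hypothetical stand-in implementing only the three methods label() uses."""

    def __init__(self, n_items=100, state_path="labels.json"):
        self.state_path = state_path
        self.unlabeled = list(range(n_items))
        self.labels = {}

    def sample(self, k):
        # Draw up to k random, still-unlabeled indices.
        return random.sample(self.unlabeled, min(k, len(self.unlabeled)))

    def process_k_edus(self, indices):
        # Ask the annotator for a label for each sampled index.
        for i in indices:
            self.labels[i] = input("label for item {}: ".format(i))
            self.unlabeled.remove(i)

    def save(self):
        with open(self.state_path, "w") as f:
            json.dump(self.labels, f)

Passing an instance of this class to label() drives the same prompt/annotate/save cycle.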
Example #6
def run(gen_tree_func,
        msg_ids_path,
        root_sampling_method='random',
        interaction_path=os.path.join(CURDIR, 'data/enron.json'),
        lda_model_path=os.path.join(CURDIR, 'models/model-4-50.lda'),
        corpus_dict_path=os.path.join(CURDIR, 'models/dictionary.pkl'),
        meta_graph_pkl_path_prefix=os.path.join(CURDIR, 'data/enron'),
        meta_graph_pkl_suffix='',
        cand_tree_number=None,  # higher priority than percentage
        cand_tree_percent=0.1,
        result_pkl_path_prefix=os.path.join(CURDIR, 'tmp/results'),
        result_suffix='',
        all_paths_pkl_prefix='',
        all_paths_pkl_suffix='',
        true_events_path='',
        meta_graph_kws={
            'dist_func': cosine,
            'preprune_secs': timedelta(weeks=4),
            'distance_weights': {'topics': 0.2,
                                 'bow': 0.8},
            # 'timestamp_converter': lambda s: s
        },
        gen_tree_kws={
            'timespan': timedelta(weeks=4),
            'U': 0.5,
            'dijkstra': False
        },
        convert_time=True,
        roots=None,
        calculate_graph=False,
        given_topics=False,
        print_summary=False,
        should_binarize_dag=False):
    if isinstance(gen_tree_kws['timespan'], timedelta):
        timespan = gen_tree_kws['timespan'].total_seconds()
    else:
        timespan = gen_tree_kws['timespan']
    U = gen_tree_kws['U']
        
    if interaction_path.endswith(".json"):
        try:
            interactions = json.load(open(interaction_path))
        except ValueError:
            interactions = load_json_by_line(interaction_path)
    elif interaction_path.endswith(".pkl"):
        interactions = pickle.load(open(interaction_path))
    else:
        raise ValueError("invalid path extension: {}".format(interaction_path))


    logger.info('loading lda from {}'.format(lda_model_path))
    if not given_topics:
        lda_model = gensim.models.wrappers.LdaMallet.load(
            os.path.join(CURDIR, lda_model_path)
        )
        dictionary = gensim.corpora.dictionary.Dictionary.load(
            os.path.join(CURDIR, corpus_dict_path)
        )
    else:
        lda_model = None
        dictionary = None

    meta_graph_pkl_path = "{}--{}{}.pkl".format(
        meta_graph_pkl_path_prefix,
        experiment_signature(**meta_graph_kws),
        meta_graph_pkl_suffix
    )
    logger.info('meta_graph_pkl_path: {}'.format(meta_graph_pkl_path))

    if calculate_graph or not os.path.exists(meta_graph_pkl_path):
        # we want to calculate the graph or
        # it's not there so we have to
        logger.info('calculating meta_graph...')
        meta_graph_kws_copied = copy.deepcopy(meta_graph_kws)
        with open(msg_ids_path) as f:
            msg_ids = [l.strip() for l in f]

        if isinstance(meta_graph_kws_copied['preprune_secs'], timedelta):
            meta_graph_kws_copied['preprune_secs'] = meta_graph_kws['preprune_secs'].total_seconds()
        g = IU.get_topic_meta_graph(
            interactions,
            msg_ids=msg_ids,
            lda_model=lda_model,
            dictionary=dictionary,
            undirected=False,  # deprecated
            given_topics=given_topics,
            decompose_interactions=False,
            convert_time=convert_time,
            **meta_graph_kws_copied
        )

        logger.info('pickling...')
        nx.write_gpickle(
            IU.compactize_meta_graph(g, map_nodes=False),
            meta_graph_pkl_path
        )
    else:
        logger.info('loading pickle...')
        g = nx.read_gpickle(meta_graph_pkl_path)
        
    if print_summary:
        logger.debug(get_summary(g))

    assert g.number_of_nodes() > 0, 'empty graph!'

    if not roots:
        cand_tree_number, cand_tree_percent = get_number_and_percentage(
            g.number_of_nodes(),
            cand_tree_number,
            cand_tree_percent
        )
        if root_sampling_method == 'random':
            root_sampler = RandomSampler(g, timespan)
        elif root_sampling_method == 'upperbound':
            root_sampler = UBSampler(g, U, timespan)
        else:
            logger.info('init AdaptiveSampler...')
            root_sampler = AdaptiveSampler(g, U, timespan)
    else:
        logger.info('Roots given')
        cand_tree_number = len(roots)
        root_sampler = DeterministicSampler(g, roots, timespan)
    
    logger.info('#roots: {}'.format(cand_tree_number))
    logger.info('#cand_tree_percent: {}'.format(
        cand_tree_number / float(g.number_of_nodes()))
    )

    trees = []
    dags = []
    for i in xrange(cand_tree_number):
        logger.info("sampling root...")
        try:
            root, dag = root_sampler.take()
        except IndexError:
            logger.warn('not enough root to take, terminate')
            break
        dags.append(dag)
        
        
        start = datetime.now()
        tree = calc_tree(i, root, dag, U,
                         gen_tree_func,
                         gen_tree_kws,
                         print_summary,
                         should_binarize_dag=should_binarize_dag)
        tree.graph['calculation_time'] = (datetime.now() - start).total_seconds()
        
        trees.append(tree)

        logger.info("updating sampler states...")
        root_sampler.update(root, tree)

    def make_detailed_path(prefix, suffix):
        return "{}--{}----{}----{}{}.pkl".format(
            prefix,
            experiment_signature(**gen_tree_kws),
            experiment_signature(**meta_graph_kws),
            experiment_signature(
                cand_tree_percent=cand_tree_percent,
                root_sampling=root_sampling_method
            ),
            suffix
        )
    result_pkl_path = make_detailed_path(result_pkl_path_prefix,
                                         result_suffix)

    logger.info('result_pkl_path: {}'.format(result_pkl_path))
    pickle.dump(trees,
                open(result_pkl_path, 'w'),
                protocol=pickle.HIGHEST_PROTOCOL)
    if False:
        # for debugging purpose
        pickle.dump(dags,
                    open(result_pkl_path+'.dag', 'w'),
                    protocol=pickle.HIGHEST_PROTOCOL)
    
    all_paths_pkl_path = make_detailed_path(all_paths_pkl_prefix,
                                            all_paths_pkl_suffix)
    logger.info('Dumping the paths info to {}'.format(all_paths_pkl_path))
    paths_dict = {'interactions': interaction_path,
                  'meta_graph': meta_graph_pkl_path,
                  'result': result_pkl_path,
                  'true_events': true_events_path,
                  'self': all_paths_pkl_path
    }
    pickle.dump(
        paths_dict,
        open(all_paths_pkl_path, 'w')
    )
    return paths_dict
Example #7
def _train(num_iters_total,
           train_data,
           val_data,
           train_val_links,
           model,
           optimizer,
           saver,
           fold_num,
           retry_num=0):
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    fold_str = fold_str + 'retry_{}_'.format(
        retry_num) if retry_num > 0 else fold_str
    if fold_str == '':
        print("here")
    epoch_timer = Timer()
    total_loss = 0
    curr_num_iters = 0
    val_results = {}
    if FLAGS.sampler == "neighbor_sampler":
        sampler = NeighborSampler(train_data, FLAGS.num_neighbors_sample,
                                  FLAGS.batch_size)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.gs_map) / FLAGS.batch_size))
    elif FLAGS.sampler == "random_sampler":
        sampler = RandomSampler(train_data, FLAGS.batch_size,
                                FLAGS.sample_induced)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.train_pairs) / FLAGS.batch_size))
    else:
        sampler = EverythingSampler(train_data)
        estimated_iters_per_epoch = 1

    moving_avg = MovingAverage(FLAGS.validation_window_size)
    iters_per_validation = FLAGS.iters_per_validation \
        if FLAGS.iters_per_validation != -1 else estimated_iters_per_epoch

    for iter in range(FLAGS.num_iters):
        model.train()
        model.zero_grad()
        batch_data = model_forward(model, train_data, sampler=sampler)
        loss = _train_iter(batch_data, model, optimizer)
        batch_data.restore_interaction_nxgraph()
        total_loss += loss
        num_iters_total_limit = FLAGS.num_iters
        curr_num_iters += 1
        if num_iters_total_limit is not None and \
                num_iters_total == num_iters_total_limit:
            break
        if iter % FLAGS.print_every_iters == 0:
            saver.log_tvt_info("{}Iter {:04d}, Loss: {:.7f}".format(
                fold_str, iter + 1, loss))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metric("{}loss".format(fold_str), loss,
                                            iter + 1)
        if (iter + 1) % iters_per_validation == 0:
            eval_res, supplement = validation(
                model,
                val_data,
                train_val_links,
                saver,
                max_num_examples=FLAGS.max_eval_pairs)
            epoch = iter / estimated_iters_per_epoch
            saver.log_tvt_info('{}Estimated Epoch: {:05f}, Loss: {:.7f} '
                               '({} iters)\t\t{}\n Val Result: {}'.format(
                                   fold_str, epoch,
                                   eval_res["Loss"], curr_num_iters,
                                   epoch_timer.time_and_clear(), eval_res))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(
                    eval_res,
                    prefix="{}validation".format(fold_str),
                    step=iter + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_pred'],
                    name="{}y_pred".format(fold_str),
                    step=iter + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_true'],
                    name='{}y_true'.format(fold_str),
                    step=iter + 1)
                confusion_matrix = supplement.get('confusion_matrix')
                if confusion_matrix is not None:
                    labels = [
                        k for k, v in sorted(
                            batch_data.dataset.interaction_edge_labels.items(),
                            key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels, step=iter + 1)
            curr_num_iters = 0
            val_results[iter + 1] = eval_res
            if len(moving_avg.results) == 0 or (
                    eval_res[FLAGS.validation_metric] - 1e-7) > max(
                        moving_avg.results):
                saver.save_trained_model(model, iter + 1)
            moving_avg.add_to_moving_avg(eval_res[FLAGS.validation_metric])
            if moving_avg.stop():
                break
    return val_results
Example #8
def train_and_save(model,
                   trainset,
                   lr,
                   bs,
                   minimization_time,
                   file_state,
                   file_losses,
                   time_delay=0,
                   time_factor=None,
                   losses_dump=None):
    """This function trains a model by model and saves both its state_dict at
    the end and the losses (on a log scale). It takes care of cuda().

     -- model,          the model (the user should call .cuda() before)
     -- trainset,       the dataset
     -- lr,             the learning rate
     -- bs,             the batch size
     -- min_time,       training time
     -- f_state,        where to save the state_dict
     -- f_losses,       where to save the loss evolution
     -- time_factor,    multiplicative factor for log time intervals (to save data)
     -- losses_dump     it can be passed an open() ref, data will be dumped there;
                        it is useful to train several time the same system
        )
    """

    if cuda.is_available(): model.cuda()
    model.train()  # not necessary in this simple model, but I keep it for the sake of generality
    optimizer = optim.SGD(model.parameters(), lr=lr)  # learning rate

    trainloader = DataLoader(
        trainset,  # dataset
        batch_size=bs,  # batch size
        pin_memory=cuda.is_available(),  # speed-up for gpu's
        sampler=RandomSampler(len(trainset))  # no epochs
    )

    if time_factor is None: time_factor = minimization_time**(1.0 / 200)

    next_t = 1.0 * lr  # Times are multiplied by the LR
    batch = 0

    # NOTE: if losses_dump is an open file, use it, regardless of 'file_losses'
    no_file = losses_dump is None
    if no_file: losses_dump = open(file_losses, 'wb')

    for data, target in load_batch(trainloader, cuda=cuda.is_available()):
        batch += 1
        if batch * lr > minimization_time:  # Times are multiplied by the LR
            break

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, size_average=True)
        loss.backward()
        optimizer.step()

        # Times are multiplied by the LR
        # Save also the last step!
        if batch * lr > next_t or (batch + 1) * lr > minimization_time:
            # I want to save the average loss on the total training set
            avg_loss = 0
            total_trainloader = DataLoader(
                trainset,
                batch_size=1024,  # I don't need small batches for this
                pin_memory=cuda.is_available(),
                sampler=RandomSampler(len(trainset)))

            for data, target in load_batch(total_trainloader,
                                           cuda=cuda.is_available(),
                                           only_one_epoch=True):
                output = model(data)
                avg_loss += F.nll_loss(output, target,
                                       size_average=False).data[0]

            pickle.dump((time_delay + batch * lr, avg_loss / len(trainset)),
                        losses_dump)
            next_t *= time_factor

    if no_file: losses_dump.close()  # close only the file we opened ourselves

    state_dict = model.state_dict()  # == losses[-1]['state_dict']
    torch.save(state_dict, file_state)

    return state_dict
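
The saving schedule in the loop above is geometric: next_t starts at lr and is multiplied by time_factor after every recorded loss, and the default time_factor = minimization_time ** (1.0 / 200) spreads roughly 200 checkpoints evenly on a log time axis. A standalone sketch of just that schedule (simplified: it ignores the "save also the last step" case):

lr = 1.0
minimization_time = 1000.0                      # in the same batch * lr units as above
time_factor = minimization_time ** (1.0 / 200)  # the default used in train_and_save

save_times = []
next_t = 1.0 * lr
batch = 0
while True:
    batch += 1
    t = batch * lr                              # "times are multiplied by the LR"
    if t > minimization_time:
        break
    if t > next_t:
        save_times.append(t)
        next_t *= time_factor

print(len(save_times))                          # roughly 200 log-spaced save points
print(save_times[:3], save_times[-3:])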
Example #9
    def test_random_sampler(self):
        s = RandomSampler(self.g, timespan_secs=3)
        assert_false([s.take()[0] for i in xrange(4)] == range(4))
        assert_equal(0, len(s.nodes))
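
The test above pins down the sampler contract that the run() functions in the other examples rely on: take() returns a (root, dag) pair, draws without replacement, and empties s.nodes after as many calls as there are nodes (raising IndexError afterwards, which run() catches). A hypothetical minimal implementation consistent with that contract, not the project's actual class:

import random


class MinimalRootSampler:
    """Hypothetical sampler that only satisfies the behaviour exercised by the test."""

    def __init__(self, g, timespan_secs):
        self.g = g
        self.timespan_secs = timespan_secs
        self.nodes = list(g)          # assumes g iterates over its node ids

    def take(self):
        if not self.nodes:
            raise IndexError("no more roots to take")
        root = self.nodes.pop(random.randrange(len(self.nodes)))
        dag = None                    # the real sampler builds a DAG rooted at `root`
        return root, dag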
Example #10
import torch
from torch import optim

from trainer import Trainer
from models import Encoder, Decoder
from sampler import RandomSampler

N_FEATURES = 100
HID_SIZE = 10
ENCODER_SIZES = [50, 30, 20]
DECODER_SIZES = []
N_ITERS = 10000
BATCH_SIZE = 1024
LR = 10.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    sampler = RandomSampler(N_FEATURES)

    encoder = Encoder(N_FEATURES, HID_SIZE, ENCODER_SIZES).to(device)

    decoder = Decoder(HID_SIZE, N_FEATURES, DECODER_SIZES).to(device)

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=LR)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=LR)

    encoder_scheduler = optim.lr_scheduler.StepLR(encoder_optimizer,
                                                  step_size=500,
                                                  gamma=0.9)
    decoder_scheduler = optim.lr_scheduler.StepLR(decoder_optimizer,
                                                  step_size=500,
                                                  gamma=0.9)
Example #11
def run(
        gen_tree_func,
        msg_ids_path,
        root_sampling_method='random',
        interaction_path=os.path.join(CURDIR, 'data/enron.json'),
        lda_model_path=os.path.join(CURDIR, 'models/model-4-50.lda'),
        corpus_dict_path=os.path.join(CURDIR, 'models/dictionary.pkl'),
        meta_graph_pkl_path_prefix=os.path.join(CURDIR, 'data/enron'),
        meta_graph_pkl_suffix='',
        cand_tree_number=None,  # higher priority than percentage
        cand_tree_percent=0.1,
        result_pkl_path_prefix=os.path.join(CURDIR, 'tmp/results'),
        result_suffix='',
        all_paths_pkl_prefix='',
        all_paths_pkl_suffix='',
        true_events_path='',
        meta_graph_kws={
            'dist_func': cosine,
            'preprune_secs': timedelta(weeks=4),
            'distance_weights': {
                'topics': 0.2,
                'bow': 0.8
            },
            # 'timestamp_converter': lambda s: s
        },
        gen_tree_kws={
            'timespan': timedelta(weeks=4),
            'U': 0.5,
            'dijkstra': False
        },
        convert_time=True,
        roots=None,
        calculate_graph=False,
        given_topics=False,
        print_summary=False,
        should_binarize_dag=False):
    if isinstance(gen_tree_kws['timespan'], timedelta):
        timespan = gen_tree_kws['timespan'].total_seconds()
    else:
        timespan = gen_tree_kws['timespan']
    U = gen_tree_kws['U']

    if interaction_path.endswith(".json"):
        try:
            interactions = json.load(open(interaction_path))
        except ValueError:
            interactions = load_json_by_line(interaction_path)
    elif interaction_path.endswith(".pkl"):
        interactions = pickle.load(open(interaction_path))
    else:
        raise ValueError("invalid path extension: {}".format(interaction_path))

    logger.info('loading lda from {}'.format(lda_model_path))
    if not given_topics:
        lda_model = gensim.models.wrappers.LdaMallet.load(
            os.path.join(CURDIR, lda_model_path))
        dictionary = gensim.corpora.dictionary.Dictionary.load(
            os.path.join(CURDIR, corpus_dict_path))
    else:
        lda_model = None
        dictionary = None

    meta_graph_pkl_path = "{}--{}{}.pkl".format(
        meta_graph_pkl_path_prefix, experiment_signature(**meta_graph_kws),
        meta_graph_pkl_suffix)
    logger.info('meta_graph_pkl_path: {}'.format(meta_graph_pkl_path))

    if calculate_graph or not os.path.exists(meta_graph_pkl_path):
        # we want to calculate the graph or
        # it's not there so we have to
        logger.info('calculating meta_graph...')
        meta_graph_kws_copied = copy.deepcopy(meta_graph_kws)
        with open(msg_ids_path) as f:
            msg_ids = [l.strip() for l in f]

        if isinstance(meta_graph_kws_copied['preprune_secs'], timedelta):
            meta_graph_kws_copied['preprune_secs'] = meta_graph_kws[
                'preprune_secs'].total_seconds()
        g = IU.get_topic_meta_graph(
            interactions,
            msg_ids=msg_ids,
            lda_model=lda_model,
            dictionary=dictionary,
            undirected=False,  # deprecated
            given_topics=given_topics,
            decompose_interactions=False,
            convert_time=convert_time,
            **meta_graph_kws_copied)

        logger.info('pickling...')
        nx.write_gpickle(IU.compactize_meta_graph(g, map_nodes=False),
                         meta_graph_pkl_path)
    else:
        logger.info('loading pickle...')
        g = nx.read_gpickle(meta_graph_pkl_path)

    if print_summary:
        logger.debug(get_summary(g))

    assert g.number_of_nodes() > 0, 'empty graph!'

    if not roots:
        cand_tree_number, cand_tree_percent = get_number_and_percentage(
            g.number_of_nodes(), cand_tree_number, cand_tree_percent)
        if root_sampling_method == 'random':
            root_sampler = RandomSampler(g, timespan)
        elif root_sampling_method == 'upperbound':
            root_sampler = UBSampler(g, U, timespan)
        else:
            logger.info('init AdaptiveSampler...')
            root_sampler = AdaptiveSampler(g, U, timespan)
    else:
        logger.info('Roots given')
        cand_tree_number = len(roots)
        root_sampler = DeterministicSampler(g, roots, timespan)

    logger.info('#roots: {}'.format(cand_tree_number))
    logger.info('#cand_tree_percent: {}'.format(cand_tree_number /
                                                float(g.number_of_nodes())))

    trees = []
    dags = []
    for i in xrange(cand_tree_number):
        logger.info("sampling root...")
        try:
            root, dag = root_sampler.take()
        except IndexError:
            logger.warn('not enough root to take, terminate')
            break
        dags.append(dag)

        start = datetime.now()
        tree = calc_tree(i,
                         root,
                         dag,
                         U,
                         gen_tree_func,
                         gen_tree_kws,
                         print_summary,
                         should_binarize_dag=should_binarize_dag)
        tree.graph['calculation_time'] = (datetime.now() -
                                          start).total_seconds()

        trees.append(tree)

        logger.info("updating sampler states...")
        root_sampler.update(root, tree)

    def make_detailed_path(prefix, suffix):
        return "{}--{}----{}----{}{}.pkl".format(
            prefix, experiment_signature(**gen_tree_kws),
            experiment_signature(**meta_graph_kws),
            experiment_signature(cand_tree_percent=cand_tree_percent,
                                 root_sampling=root_sampling_method), suffix)

    result_pkl_path = make_detailed_path(result_pkl_path_prefix, result_suffix)

    logger.info('result_pkl_path: {}'.format(result_pkl_path))
    pickle.dump(trees,
                open(result_pkl_path, 'w'),
                protocol=pickle.HIGHEST_PROTOCOL)
    if False:
        # for debugging purpose
        pickle.dump(dags,
                    open(result_pkl_path + '.dag', 'w'),
                    protocol=pickle.HIGHEST_PROTOCOL)

    all_paths_pkl_path = make_detailed_path(all_paths_pkl_prefix,
                                            all_paths_pkl_suffix)
    logger.info('Dumping the paths info to {}'.format(all_paths_pkl_path))
    paths_dict = {
        'interactions': interaction_path,
        'meta_graph': meta_graph_pkl_path,
        'result': result_pkl_path,
        'true_events': true_events_path,
        'self': all_paths_pkl_path
    }
    pickle.dump(paths_dict, open(all_paths_pkl_path, 'w'))
    return paths_dict