Esempio n. 1
0
def main(pid, flags):
    """Run the integrated-view (ECFP8 + GraphConv) DTI simulation for each seed.

    For every entry of the module-level ``seeds`` this seeds the Python,
    NumPy and torch RNGs, loads the "GraphConv" and "ECFP8" data through
    ``get_data`` and then does one of three things: verifies the loaded data
    when ``check_data`` is set, runs a hyperparameter search when
    ``flags["hparam_search"]`` is set, or calls ``invoke_train``.  The
    resulting ``DataNode`` tree is written with ``to_json`` to ``./analysis/``.

    :param pid: not used by this function.
    :param flags: run configuration, read by key; ``flags["tasks"]`` is
        overwritten for every seed.
    """
    sim_label = 'integrated_view_ecfp8_gconv_psc'
    print(sim_label)

    # Simulation data resource tree
    if flags["split_warm"]:
        split_label = "warm"
    elif flags["cold_target"]:
        split_label = "cold_target"
    elif flags["cold_drug"]:
        split_label = "cold_drug"
    else:
        # The literal string "None" (not the None object) is the fallback label.
        split_label = "None"
    dataset_lbl = flags["dataset_name"]
    node_label = json.dumps({
        'model_family': 'IntView',
        'dataset': dataset_lbl,
        'split': split_label,
        'seeds': '-'.join(str(s) for s in seeds),
        'mode': "eval" if flags["eval"] else "train",
        'date': date_label
    })
    sim_data = DataNode(label=node_label)
    nodes_list = []
    sim_data.data = nodes_list

    # None when exactly one device is reported; otherwise device ids
    # 1..n-1 (an empty list when no device is reported).
    num_cuda_dvcs = torch.cuda.device_count()
    cuda_devices = None if num_cuda_dvcs == 1 else list(
        range(1, num_cuda_dvcs))

    prot_desc_dict, prot_seq_dict = load_proteins(flags['prot_desc_path'])

    # For searching over multiple seeds: the search object is created by the
    # first seed that reaches the search branch (with that seed's arguments)
    # and the same object is used for every later seed.
    hparam_search = None

    for seed in seeds:
        # NOTE: `seed` and `sim_label` are read when the callable is invoked,
        # not when it is defined here.
        summary_writer_creator = lambda: SummaryWriter(
            log_dir="tb_int_view/{}_{}_{}/".format(
                sim_label, seed,
                dt.now().strftime("%Y_%m_%d__%H_%M_%S")))

        # for data collection of this round of simulation.
        data_node = DataNode(label="seed_%d" % seed)
        nodes_list.append(data_node)

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # load data
        print('-------------------------------------')
        print('Running on dataset: %s' % dataset_lbl)
        print('-------------------------------------')

        data_dict = dict()
        transformers_dict = dict()

        # Data: element [2] of each get_data result is stored as the view's
        # transformers, element [0] is used as the task list.
        data_dict["gconv"] = get_data("GraphConv",
                                      flags,
                                      prot_sequences=prot_seq_dict,
                                      seed=seed)
        transformers_dict["gconv"] = data_dict["gconv"][2]
        data_dict["ecfp8"] = get_data("ECFP8",
                                      flags,
                                      prot_sequences=prot_seq_dict,
                                      seed=seed)
        transformers_dict["ecfp8"] = data_dict["ecfp8"][2]

        tasks = data_dict["gconv"][0]
        # multi-task or single task is determined by the number of tasks w.r.t. the dataset loaded
        flags["tasks"] = tasks

        trainer = IntegratedViewDTI()

        if flags["cv"]:
            k = flags["fold_num"]
            print("{}, {}-Prot: Training scheme: {}-fold cross-validation".
                  format(tasks, sim_label, k))
        else:
            k = 1
            print("{}, {}-Prot: Training scheme: train, validation".format(
                tasks, sim_label) +
                  (", test split" if flags['test'] else " split"))

        if check_data:
            verify_multiview_data(data_dict)
        elif flags["hparam_search"]:
            print("Hyperparameter search enabled: {}".format(
                flags["hparam_search_alg"]))

            # arguments to callables
            extra_init_args = {
                "mode": "regression",
                "cuda_devices": cuda_devices
            }
            extra_data_args = {"flags": flags, "data_dict": data_dict}
            # NOTE(review): an unused local `n_iters = 3000` used to be
            # defined here but was never added to the train arguments;
            # confirm whether trainer.train is expected to receive it.
            extra_train_args = {
                "transformers_dict": transformers_dict,
                "prot_desc_dict": prot_desc_dict,
                "tasks": tasks,
                "is_hsearch": True,
                "tb_writer": summary_writer_creator
            }

            hparams_conf = get_hparam_config(flags)

            if hparam_search is None:
                # Unknown algorithm names fall back to BayesianOptSearch.
                search_alg = {
                    "random_search": RandomSearch,
                    "bayopt_search": BayesianOptSearch
                }.get(flags["hparam_search_alg"], BayesianOptSearch)
                search_args = GPMinArgs(n_calls=40, random_state=seed)
                hparam_search = search_alg(
                    hparam_config=hparams_conf,
                    num_folds=k,
                    initializer=trainer.initialize,
                    data_provider=trainer.data_provider,
                    train_fn=trainer.train,
                    save_model_fn=jova.utils.io.save_model,
                    alg_args=search_args,
                    init_args=extra_init_args,
                    data_args=extra_data_args,
                    train_args=extra_train_args,
                    data_node=data_node,
                    split_label=split_label,
                    sim_label=sim_label,
                    dataset_label=dataset_lbl,
                    results_file="{}_{}_dti_{}.csv".format(
                        flags["hparam_search_alg"], sim_label, date_label))

            stats = hparam_search.fit(model_dir="models",
                                      model_name="".join(tasks))
            print(stats)
            print("Best params = {}".format(stats.best()))
        else:
            invoke_train(trainer, tasks, data_dict, transformers_dict,
                         flags, prot_desc_dict, data_node, sim_label,
                         summary_writer_creator)

    # save simulation data resource tree to file.
    sim_data.to_json(path="./analysis/")
def main(pid, flags):
    """Train or tune the binary CPI baseline for every selected view and seed.

    For each view in ``flags.views`` a ``DataNode`` tree is built, protein
    resources are loaded, and then for each entry of the module-level
    ``seeds`` the RNGs are seeded, the view's data is loaded via ``get_data``
    and either a hyperparameter search is run (``flags["hparam_search"]``)
    or ``invoke_train`` is called.  The tree of each view is written with
    ``to_json`` to ``./analysis/``.

    :param pid: not used by this function.
    :param flags: run configuration, read both by attribute and by key; the
        keys ``prot_vocab_size``, ``mode``, ``gnn_fingerprint`` and ``tasks``
        are written here.
    """
    if len(flags.views) == 0:
        print("No views selected for training")
    else:
        print("Single views for training:", flags.views)

    # View name -> feature key handed to get_data (None for unlisted views).
    feature_keys = {
        "ecfp4": "ECFP4",
        "ecfp8": "ECFP8",
        "weave": "Weave",
        "gconv": "GraphConv",
        "gnn": "GNN",
    }
    banner = '-------------------------------------'

    for view in flags.views:
        sim_label = "cpi_prediction_baseline_bin"
        print("CUDA={}, view={}".format(cuda, view))

        # Root of the simulation data tree; one child node is added per seed.
        split_label = flags.split
        dataset_lbl = flags["dataset_name"]
        run_mode = "eval" if flags["eval"] else "train"
        sim_data = DataNode(label="{}_{}_{}_{}_{}".format(
            dataset_lbl, view, split_label, run_mode, date_label))
        seed_nodes = []
        sim_data.data = seed_nodes

        device_count = torch.cuda.device_count()
        if device_count == 1:
            cuda_devices = None
        else:
            cuda_devices = list(range(1, device_count))

        # Protein resources needed at run time
        prot_desc_dict, prot_seq_dict = load_proteins(flags['prot_desc_path'])
        prot_profile = load_pickle(file_name=flags.prot_profile)
        prot_vocab = load_pickle(file_name=flags.prot_vocab)
        flags["prot_vocab_size"] = len(prot_vocab)

        flags['mode'] = 'classification'

        # Built by the first seed that runs a search, then reused.
        hparam_search = None

        for seed in seeds:

            def summary_writer_creator():
                # sim_label and seed are looked up when this is called.
                return SummaryWriter(
                    log_dir="tb_cpi_bin/{}_{}_{}/".format(
                        sim_label, seed,
                        dt.now().strftime("%Y_%m_%d__%H_%M_%S")))

            # Collects the results of this seed's round.
            seed_node = DataNode(label="seed_%d" % seed)
            seed_nodes.append(seed_node)

            for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                            torch.cuda.manual_seed_all):
                seed_fn(seed)

            # load data
            print(banner)
            print('Running on dataset: %s' % dataset_lbl)
            print(banner)

            loaded = get_data(feature_keys.get(view),
                              flags,
                              prot_sequences=prot_seq_dict,
                              seed=seed)
            view_data = {view: loaded}
            view_transformers = {view: loaded[2]}
            flags["gnn_fingerprint"] = loaded[3]

            tasks = loaded[0]
            flags["tasks"] = tasks

            trainer = CPIBaseline()

            if flags["cv"]:
                num_folds = flags["fold_num"]
                scheme = "{}, {}-Prot: Training scheme: {}-fold cross-validation".format(
                    tasks, view, num_folds)
            else:
                num_folds = 1
                scheme = "{}, {}-Prot: Training scheme: train, validation".format(
                    tasks, view)
                scheme += ", test split" if flags['test'] else " split"
            print(scheme)

            if not flags["hparam_search"]:
                invoke_train(trainer, tasks, view_data, view_transformers,
                             flags, prot_desc_dict, seed_node, view,
                             prot_profile, summary_writer_creator)
                continue

            print("Hyperparameter search enabled: {}".format(
                flags["hparam_search_alg"]))

            # Keyword arguments forwarded to the search callables.
            init_kwargs = {
                "mode": "regression",
                "cuda_devices": cuda_devices,
                "protein_profile": prot_profile,
            }
            data_kwargs = {
                "flags": flags,
                "data_dict": view_data,
            }
            train_kwargs = {
                "transformers_dict": view_transformers,
                "prot_desc_dict": prot_desc_dict,
                "tasks": tasks,
                "n_iters": 3000,
                "is_hsearch": True,
                "view_lbl": view,
                "tb_writer": summary_writer_creator,
            }

            hparams_conf = get_hparam_config(flags, view)
            if hparam_search is None:
                algorithms = {
                    "random_search": RandomSearch,
                    "bayopt_search": BayesianOptSearch,
                }
                # Unlisted algorithm names fall back to BayesianOptSearch.
                search_cls = algorithms.get(flags["hparam_search_alg"],
                                            BayesianOptSearch)
                search_args = GPMinArgs(n_calls=40, random_state=seed)
                hparam_search = search_cls(
                    hparam_config=hparams_conf,
                    num_folds=num_folds,
                    initializer=trainer.initialize,
                    data_provider=trainer.data_provider,
                    train_fn=trainer.train,
                    save_model_fn=jova.utils.io.save_model,
                    alg_args=search_args,
                    init_args=init_kwargs,
                    data_args=data_kwargs,
                    train_args=train_kwargs,
                    data_node=seed_node,
                    split_label=split_label,
                    sim_label=sim_label,
                    dataset_label=dataset_lbl,
                    results_file="{}_{}_dti_{}.csv".format(
                        flags["hparam_search_alg"], sim_label, date_label))

            stats = hparam_search.fit(model_dir="models",
                                      model_name="".join(tasks))
            print(stats)
            print("Best params = {}".format(stats.best()))

        # Persist this view's simulation data tree.
        sim_data.to_json(path="./analysis/")
def main(pid, flags):
    """Run the two-way attention DTI baseline simulation for each seed.

    Builds a ``DataNode`` tree, loads the protein resources and publishes
    them on the ``two_way_attn`` module, then for every entry of the
    module-level ``seeds`` seeds the RNGs and loads the views enabled by the
    module-level switches ``use_weave``, ``use_gconv`` and ``use_gnn``.
    Depending on ``check_data`` and ``flags["hparam_search"]`` it verifies
    the data, runs a hyperparameter search, or calls ``invoke_train``.  The
    tree is finally written with ``to_json`` to ``./analysis/``.

    :param pid: not used by this function.
    :param flags: run configuration, read both by attribute and by key;
        ``flags["prot_vocab_size"]`` is written here.
    :raises IndexError: if none of ``use_weave``, ``use_gconv`` and
        ``use_gnn`` is set, so no data was loaded for a seed.
    """
    sim_label = "two_way_attn_dti_baseline"
    print("CUDA={}, view={}".format(cuda, sim_label))

    # Simulation data resource tree
    split_label = flags.split
    dataset_lbl = flags["dataset_name"]

    if flags['eval']:
        mode = 'eval'
    elif flags['explain']:
        mode = 'explain'
    else:
        mode = 'train'
    node_label = json.dumps({
        'model_family': '2way-dti',
        'dataset': dataset_lbl,
        'split': split_label,
        'cv': flags["cv"],
        'mode': mode,
        'seeds': '-'.join(str(s) for s in seeds),
        'date': date_label
    })
    sim_data = DataNode(label='_'.join(
        [sim_label, dataset_lbl, split_label, mode, date_label]),
                        metadata=node_label)
    nodes_list = []
    sim_data.data = nodes_list

    # None when exactly one device is reported; otherwise device ids
    # 1..n-1 (an empty list when no device is reported).
    num_cuda_dvcs = torch.cuda.device_count()
    cuda_devices = None if num_cuda_dvcs == 1 else list(
        range(1, num_cuda_dvcs))

    # Runtime Protein stuff
    prot_desc_dict, prot_seq_dict = load_proteins(flags['prot_desc_path'])
    prot_profile, prot_vocab = load_pickle(
        file_name=flags.prot_profile), load_pickle(file_name=flags.prot_vocab)
    # pretrained_embeddings = load_numpy_array(flags.protein_embeddings)
    flags["prot_vocab_size"] = len(prot_vocab)
    # flags["embeddings_dim"] = pretrained_embeddings.shape[-1]

    # set attention hook's protein information
    two_way_attn.protein_profile = prot_profile
    two_way_attn.protein_vocabulary = prot_vocab
    two_way_attn.protein_sequences = prot_seq_dict

    # For searching over multiple seeds: the search object is created by the
    # first seed that reaches the search branch (with that seed's arguments)
    # and the same object is used for every later seed.
    hparam_search = None

    for seed in seeds:
        # for data collection of this round of simulation.
        data_node = DataNode(label="seed_%d" % seed)
        nodes_list.append(data_node)

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # load data
        print('-------------------------------------')
        print('Running on dataset: %s' % dataset_lbl)
        print('-------------------------------------')

        data_dict = dict()
        transformers_dict = dict()

        # Data: element [2] of each get_data result is stored as the view's
        # transformers.
        if use_weave:
            data_dict["weave"] = get_data("Weave",
                                          flags,
                                          prot_sequences=prot_seq_dict,
                                          seed=seed)
            transformers_dict["weave"] = data_dict["weave"][2]
        if use_gconv:
            data_dict["gconv"] = get_data("GraphConv",
                                          flags,
                                          prot_sequences=prot_seq_dict,
                                          seed=seed)
            transformers_dict["gconv"] = data_dict["gconv"][2]
        if use_gnn:
            data_dict["gnn"] = get_data("GNN",
                                        flags,
                                        prot_sequences=prot_seq_dict,
                                        seed=seed)
            transformers_dict["gnn"] = data_dict["gnn"][2]

        if not data_dict:
            # Same exception type as the former first-key lookup on an empty
            # dict, but with a message that names the actual cause.
            raise IndexError(
                "no compound view enabled: set at least one of use_weave, "
                "use_gconv or use_gnn")
        # The task list is taken from the first view that was loaded.
        tasks = next(iter(data_dict.values()))[0]

        trainer = TwoWayAttnBaseline()

        if flags["cv"]:
            k = flags["fold_num"]
            print("{}, {}-Prot: Training scheme: {}-fold cross-validation".
                  format(tasks, sim_label, k))
        else:
            k = 1
            print("{}, {}-Prot: Training scheme: train, validation".format(
                tasks, sim_label) +
                  (", test split" if flags['test'] else " split"))

        if check_data:
            verify_multiview_data(data_dict)
        elif flags["hparam_search"]:
            print("Hyperparameter search enabled: {}".format(
                flags["hparam_search_alg"]))

            # arguments to callables
            extra_init_args = {
                "mode": "regression",
                "cuda_devices": cuda_devices,
                "protein_profile": prot_profile
            }
            extra_data_args = {"flags": flags, "data_dict": data_dict}
            extra_train_args = {
                "transformers_dict": transformers_dict,
                "prot_desc_dict": prot_desc_dict,
                "tasks": tasks,
                "is_hsearch": True,
                "n_iters": 3000
            }

            hparams_conf = get_hparam_config(flags)

            if hparam_search is None:
                # Unknown algorithm names fall back to BayesianOptSearch.
                search_alg = {
                    "random_search": RandomSearch,
                    "bayopt_search": BayesianOptSearch
                }.get(flags["hparam_search_alg"], BayesianOptSearch)
                search_args = GPMinArgs(n_calls=20)
                min_opt = "gbrt"
                hparam_search = search_alg(
                    hparam_config=hparams_conf,
                    num_folds=k,
                    initializer=trainer.initialize,
                    data_provider=trainer.data_provider,
                    train_fn=trainer.train,
                    save_model_fn=jova.utils.io.save_model,
                    init_args=extra_init_args,
                    data_args=extra_data_args,
                    train_args=extra_train_args,
                    alg_args=search_args,
                    data_node=data_node,
                    split_label=split_label,
                    sim_label=sim_label,
                    minimizer=min_opt,
                    dataset_label=dataset_lbl,
                    results_file="{}_{}_dti_{}_{}.csv".format(
                        flags["hparam_search_alg"], sim_label, date_label,
                        min_opt))

            stats = hparam_search.fit()
            print(stats)
            print("Best params = {}".format(stats.best()))
        else:
            invoke_train(trainer, tasks, data_dict, transformers_dict,
                         flags, prot_desc_dict, data_node, sim_label,
                         prot_profile)

    # save simulation data resource tree to file.
    sim_data.to_json(path="./analysis/")
Esempio n. 4
0
def main(flags):
    """Run the matrix-factorization (MF) simulation for each view and seed.

    Every element of ``flags["views"]`` is unpacked into a compound view and
    a protein view.  For each of them a ``DataNode`` tree is built, and for
    every entry of the module-level ``seeds`` the RNGs are seeded, the data
    is loaded via ``get_data`` and either a hyperparameter search is run
    (``flags["hparam_search"]``) or ``invoke_train`` is called.  Each view's
    tree is written with ``to_json`` to ``./analysis/``.

    :param flags: run configuration, read both by attribute and by key.  The
        keys ``splitting_alg``, ``cv``, ``test``, ``fold_num`` and ``tasks``
        are overwritten for every seed.
    """
    if len(flags["views"]) > 0:
        print("Single views for training: {}, num={}".format(
            flags["views"], len(flags["views"])))
    else:
        print("No views selected for training")

    for view in flags["views"]:
        dataset_lbl = flags["dataset_name"]
        cview, pview = view
        sim_label = "MF_{}_{}_{}".format(dataset_lbl, cview, pview)
        print(sim_label)

        # Simulation data resource tree
        split_label = flags.split
        node_label = json.dumps({
            'model_family': 'mf',
            'dataset': dataset_lbl,
            'cview': cview,
            'pview': pview,
            'split': split_label,
            'seeds': '-'.join(str(s) for s in seeds),
            'mode': "eval" if flags["eval"] else "train",
            'date': date_label
        })
        sim_data = DataNode(label=node_label)
        nodes_list = []
        sim_data.data = nodes_list

        # Only the protein sequences are used below; the descriptor dict
        # returned alongside them is discarded.
        _, prot_seq_dict = load_proteins(flags['prot_desc_path'])

        # For searching over multiple seeds: the search object is created by
        # the first seed that reaches the search branch (with that seed's
        # arguments) and the same object is used for every later seed.
        hparam_search = None

        for seed in seeds:
            # for data collection of this round of simulation.
            data_node = DataNode(label="seed_%d" % seed)
            nodes_list.append(data_node)

            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

            # load data
            print('-------------------------------------')
            print('Running on dataset: %s' % dataset_lbl)
            print('-------------------------------------')

            # NOTE(review): a cview other than "ecfp4"/"ecfp8" yields
            # data_key=None, which is passed to get_data unchanged.
            data_key = {"ecfp4": "MF_ECFP4", "ecfp8": "MF_ECFP8"}.get(cview)
            # These settings are forced here regardless of what the caller
            # put in flags.
            flags['splitting_alg'] = 'random'
            flags['cv'] = False
            flags['test'] = False
            flags['fold_num'] = 1
            data = get_data(data_key,
                            flags,
                            prot_sequences=prot_seq_dict,
                            seed=seed)
            transformer = data[2]
            tasks = data[0]
            flags["tasks"] = tasks

            trainer = MF()

            k = 1
            print("{}, {}-{}: Training scheme: train, validation".format(
                tasks, cview, pview) +
                  (", test split" if flags['test'] else " split"))

            if flags["hparam_search"]:
                print("Hyperparameter search enabled: {}".format(
                    flags["hparam_search_alg"]))

                # arguments to callables
                extra_init_args = {}
                extra_data_args = {"flags": flags, "data": data}
                # NOTE(review): an unused local `n_iters = 3000` used to be
                # defined here but was never added to the train arguments;
                # confirm whether trainer.train is expected to receive it.
                extra_train_args = {
                    "transformer": transformer,
                    "tasks": tasks,
                    "is_hsearch": True
                }

                hparams_conf = get_hparam_config(flags)

                if hparam_search is None:
                    # Unknown algorithm names fall back to BayesianOptSearch.
                    search_alg = {
                        "random_search": RandomSearch,
                        "bayopt_search": BayesianOptSearch
                    }.get(flags["hparam_search_alg"], BayesianOptSearch)
                    search_args = GPMinArgs(n_calls=40, random_state=seed)
                    hparam_search = search_alg(
                        hparam_config=hparams_conf,
                        num_folds=k,
                        initializer=trainer.initialize,
                        data_provider=trainer.data_provider,
                        train_fn=trainer.train,
                        save_model_fn=save_mf_model_and_feats,
                        alg_args=search_args,
                        init_args=extra_init_args,
                        data_args=extra_data_args,
                        train_args=extra_train_args,
                        data_node=data_node,
                        split_label=split_label,
                        sim_label=sim_label,
                        dataset_label=dataset_lbl,
                        results_file="{}_{}_dti_{}.csv".format(
                            flags["hparam_search_alg"], sim_label, date_label))

                stats = hparam_search.fit(model_dir="models",
                                          model_name="".join(tasks))
                print(stats)
                print("Best params = {}".format(stats.best()))
            else:
                invoke_train(trainer, tasks, data, transformer, flags,
                             data_node, sim_label, dataset_lbl)

        # save simulation data resource tree to file.
        sim_data.to_json(path="./analysis/")
Esempio n. 5
0
def main(id, flags):
    """Train or tune the single-view DTI model for every view pair and seed.

    Every element of ``flags["views"]`` is unpacked into a compound view and
    a protein view.  For each of them a ``DataNode`` tree is built, protein
    resources are loaded, and for every entry of the module-level ``seeds``
    the RNGs are seeded, the compound view's data is loaded via ``get_data``
    and either a hyperparameter search is run (``flags["hparam_search"]``)
    or ``invoke_train`` is called.  Each view's tree is written with
    ``to_json`` to ``./analysis/``.

    :param id: not used by this function.
    :param flags: run configuration, read both by attribute and by key; the
        keys ``prot_vocab_size``, ``gnn_fingerprint`` and ``tasks`` are
        written here.
    """
    if len(flags["views"]) == 0:
        print("No views selected for training")
    else:
        print("Single views for training: {}, num={}".format(flags["views"], len(flags["views"])))

    # Compound view name -> feature key handed to get_data (None if unlisted).
    feature_keys = {
        "ecfp4": "ECFP4",
        "ecfp8": "ECFP8",
        "weave": "Weave",
        "gconv": "GraphConv",
        "gnn": "GNN",
    }
    banner = '-------------------------------------'

    for view in flags["views"]:
        split_label = flags.split
        dataset_lbl = flags["dataset_name"]
        # "eval"/"train", with a "cv" suffix when cross-validation is on.
        mode = ("eval" if flags["eval"] else "train") + ('cv' if flags.cv else '')
        cview, pview = view
        sim_label = f"{dataset_lbl}_{split_label}_single_view_{cview}_{pview}_{mode}"
        print("CUDA={}, {}".format(cuda, sim_label))

        # Root of the simulation data tree; its label is a JSON description.
        sim_data = DataNode(label=json.dumps({
            'model_family': 'singleview',
            'dataset': dataset_lbl,
            'cview': cview,
            'pview': pview,
            'split': split_label,
            'mode': mode,
            'seeds': '-'.join([str(s) for s in seeds]),
            'date': date_label,
        }))
        seed_nodes = []
        sim_data.data = seed_nodes

        device_count = torch.cuda.device_count()
        if device_count == 1:
            cuda_devices = None
        else:
            cuda_devices = list(range(1, device_count))

        prot_desc_dict, prot_seq_dict = load_proteins(flags['prot_desc_path'])
        prot_profile = load_pickle(file_name=flags['prot_profile'])
        prot_vocab = load_pickle(file_name=flags['prot_vocab'])
        flags["prot_vocab_size"] = len(prot_vocab)

        # Built by the first seed that runs a search, then reused.
        hparam_search = None

        for seed in seeds:

            def summary_writer_creator():
                # sim_label and seed are looked up when this is called.
                return SummaryWriter(
                    log_dir="tb_singles_hs/{}_{}_{}/".format(
                        sim_label, seed, dt.now().strftime("%Y_%m_%d__%H_%M_%S")))

            # Collects the results of this seed's round.
            seed_node = DataNode(label="seed_%d" % seed)
            seed_nodes.append(seed_node)

            for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                            torch.cuda.manual_seed_all):
                seed_fn(seed)

            # load data
            print(banner)
            print('Running on dataset: %s' % dataset_lbl)
            print(banner)

            loaded = get_data(feature_keys.get(cview), flags, prot_sequences=prot_seq_dict, seed=seed)
            view_data = {cview: loaded}
            view_transformers = {cview: loaded[2]}
            flags["gnn_fingerprint"] = loaded[3]

            tasks = loaded[0]
            # Whether this is multi-task or single-task follows from the
            # number of tasks in the loaded dataset.
            flags["tasks"] = tasks

            trainer = SingleViewDTI()

            if flags["cv"]:
                num_folds = flags["fold_num"]
                scheme = "{}, {}-{}: Training scheme: {}-fold cross-validation".format(
                    tasks, cview, pview, num_folds)
            else:
                num_folds = 1
                scheme = "{}, {}-{}: Training scheme: train, validation".format(tasks, cview, pview)
                scheme += ", test split" if flags['test'] else " split"
            print(scheme)

            if not flags["hparam_search"]:
                invoke_train(trainer, tasks, view_data, view_transformers, flags, prot_desc_dict,
                             seed_node, view, prot_profile, summary_writer_creator)
                continue

            print("Hyperparameter search enabled: {}".format(flags["hparam_search_alg"]))

            # Keyword arguments forwarded to the search callables.
            init_kwargs = {
                "mode": "regression",
                "cuda_devices": cuda_devices,
                "protein_profile": prot_profile,
            }
            data_kwargs = {
                "flags": flags,
                "data_dict": view_data,
            }
            train_kwargs = {
                "transformers_dict": view_transformers,
                "prot_desc_dict": prot_desc_dict,
                "tasks": tasks,
                "n_iters": 3000,
                "is_hsearch": True,
                "view": view,
                "tb_writer": summary_writer_creator,
            }

            hparams_conf = get_hparam_config(flags, cview, pview)

            if hparam_search is None:
                algorithms = {
                    "random_search": RandomSearch,
                    "bayopt_search": BayesianOptSearch,
                }
                # Unlisted algorithm names fall back to BayesianOptSearch.
                search_cls = algorithms.get(flags["hparam_search_alg"], BayesianOptSearch)
                search_args = GPMinArgs(n_calls=40, random_state=seed)
                hparam_search = search_cls(
                    hparam_config=hparams_conf,
                    num_folds=num_folds,
                    initializer=trainer.initialize,
                    data_provider=trainer.data_provider,
                    train_fn=trainer.train,
                    save_model_fn=jova.utils.io.save_model,
                    alg_args=search_args,
                    init_args=init_kwargs,
                    data_args=data_kwargs,
                    train_args=train_kwargs,
                    data_node=seed_node,
                    split_label=split_label,
                    sim_label=sim_label,
                    dataset_label=dataset_lbl,
                    results_file="{}_{}_dti_{}.csv".format(
                        flags["hparam_search_alg"], sim_label, date_label))

            stats = hparam_search.fit(model_dir="models", model_name="".join(tasks))
            print(stats)
            print("Best params = {}".format(stats.best()))

        # Persist this view's simulation data tree.
        sim_data.to_json(path="./analysis/")
Esempio n. 6
0
def main(pid, flags):
    """Train or tune the CPI prediction baseline for every view and seed.

    For each view in ``flags.views`` a ``DataNode`` tree (with JSON
    metadata) is built, protein resources are loaded, and for every entry of
    the module-level ``seeds`` the RNGs are seeded, the view's data is
    loaded via ``get_data`` and either a hyperparameter search is run
    (``flags["hparam_search"]``) or ``invoke_train`` is called.  Each view's
    tree is written with ``to_json`` to ``./analysis/``.

    :param pid: not used by this function.
    :param flags: run configuration, read both by attribute and by key; the
        keys ``prot_vocab_size``, ``gnn_fingerprint`` and ``tasks`` are
        written here.
    """
    if len(flags.views) == 0:
        print("No views selected for training")
    else:
        print("Single views for training:", flags.views)

    # View name -> feature key handed to get_data (None for unlisted views).
    feature_keys = {
        "ecfp4": "ECFP4",
        "ecfp8": "ECFP8",
        "weave": "Weave",
        "gconv": "GraphConv",
        "gnn": "GNN",
    }
    banner = '-------------------------------------'

    for view in flags.views:
        sim_label = "cpi_prediction_baseline"
        print("CUDA={}, view={}".format(cuda, view))

        # Simulation data resource tree
        split_label = flags.split
        dataset_lbl = flags["dataset_name"]
        mode = 'train'
        if flags['eval']:
            mode = 'eval'
        elif flags['explain']:
            mode = 'explain'
        # NOTE(review): 'cview' and 'pview' are fixed strings here and do not
        # follow the `view` being iterated — confirm this is intended.
        metadata = json.dumps({
            'model_family': 'cpi',
            'dataset': dataset_lbl,
            'cview': 'gnn',
            'pview': 'pcnna',
            'split': split_label,
            'cv': flags["cv"],
            'seeds': '-'.join([str(s) for s in seeds]),
            'mode': mode,
            'date': date_label,
        })
        root_label = '_'.join(
            [sim_label, dataset_lbl, split_label, mode, date_label])
        sim_data = DataNode(label=root_label, metadata=metadata)
        seed_nodes = []
        sim_data.data = seed_nodes

        device_count = torch.cuda.device_count()
        if device_count == 1:
            cuda_devices = None
        else:
            cuda_devices = list(range(1, device_count))

        # Protein resources needed at run time
        prot_desc_dict, prot_seq_dict = load_proteins(flags['prot_desc_path'])
        prot_profile = load_pickle(file_name=flags.prot_profile)
        prot_vocab = load_pickle(file_name=flags.prot_vocab)
        flags["prot_vocab_size"] = len(prot_vocab)

        # Built by the first seed that runs a search, then reused.
        hparam_search = None

        for seed in seeds:
            # Collects the results of this seed's round.
            seed_node = DataNode(label="seed_%d" % seed)
            seed_nodes.append(seed_node)

            for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                            torch.cuda.manual_seed_all):
                seed_fn(seed)

            # load data
            print(banner)
            print('Running on dataset: %s' % dataset_lbl)
            print(banner)

            loaded = get_data(feature_keys.get(view),
                              flags,
                              prot_sequences=prot_seq_dict,
                              seed=seed)
            view_data = {view: loaded}
            view_transformers = {view: loaded[2]}
            flags["gnn_fingerprint"] = loaded[3]

            tasks = loaded[0]
            flags["tasks"] = tasks

            trainer = CPIBaseline()

            if flags["cv"]:
                num_folds = flags["fold_num"]
                scheme = "{}, {}-Prot: Training scheme: {}-fold cross-validation".format(
                    tasks, view, num_folds)
            else:
                num_folds = 1
                scheme = "{}, {}-Prot: Training scheme: train, validation".format(
                    tasks, view)
                scheme += ", test split" if flags['test'] else " split"
            print(scheme)

            if not flags["hparam_search"]:
                invoke_train(trainer, tasks, view_data, view_transformers,
                             flags, prot_desc_dict, seed_node, view,
                             prot_profile)
                continue

            print("Hyperparameter search enabled: {}".format(
                flags["hparam_search_alg"]))

            # Keyword arguments forwarded to the search callables.
            init_kwargs = {
                "mode": "regression",
                "cuda_devices": cuda_devices,
                "protein_profile": prot_profile,
            }
            data_kwargs = {"flags": flags, "data_dict": view_data}
            train_kwargs = {
                "transformers_dict": view_transformers,
                "prot_desc_dict": prot_desc_dict,
                "tasks": tasks,
                "n_iters": 3000,
                "is_hsearch": True,
                "view_lbl": view,
            }

            hparams_conf = get_hparam_config(flags, view)
            if hparam_search is None:
                algorithms = {
                    "random_search": RandomSearch,
                    "bayopt_search": BayesianOptSearch,
                }
                # Unlisted algorithm names fall back to BayesianOptSearch.
                search_cls = algorithms.get(flags["hparam_search_alg"],
                                            BayesianOptSearch)
                search_args = GPMinArgs(n_calls=20)
                min_opt = "gbrt"
                hparam_search = search_cls(
                    hparam_config=hparams_conf,
                    num_folds=num_folds,
                    initializer=trainer.initialize,
                    data_provider=trainer.data_provider,
                    train_fn=trainer.train,
                    save_model_fn=jova.utils.io.save_model,
                    init_args=init_kwargs,
                    data_args=data_kwargs,
                    train_args=train_kwargs,
                    alg_args=search_args,
                    data_node=seed_node,
                    split_label=split_label,
                    sim_label=sim_label,
                    minimizer=min_opt,
                    dataset_label=dataset_lbl,
                    results_file="{}_{}_dti_{}_{}.csv".format(
                        flags["hparam_search_alg"], sim_label, date_label,
                        min_opt))

            stats = hparam_search.fit()
            print(stats)
            print("Best params = {}".format(stats.best()))

        # Persist this view's simulation data tree.
        sim_data.to_json(path="./analysis/")