Example #1
def register_module_dataclass(cs: ConfigStore, registry: Dict[str, Any],
                              group: str) -> None:
    """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc."""
    # note that if `group == model`, we register all model archs, not the model name.
    for k, v in registry.items():
        if v is not None:
            node_ = v(_name=k)
            cs.store(name=k, group=group, node=node_)
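
For context, here is a minimal usage sketch; the dataclass and registry below are hypothetical stand-ins (not part of the original code) that only mirror the `v(_name=k)` pattern the helper expects:

from dataclasses import dataclass
from typing import Any, Dict, Optional

from hydra.core.config_store import ConfigStore


@dataclass
class TransformerModelConfig:
    # Hypothetical model-architecture config; `_name` mirrors the `v(_name=k)` call above.
    _name: Optional[str] = None
    num_layers: int = 6


# Hypothetical registry mapping architecture names to their config dataclasses.
MODEL_DATACLASS_REGISTRY: Dict[str, Any] = {"transformer": TransformerModelConfig}

cs = ConfigStore.instance()
register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, group="model")
# Each architecture is now a selectable entry in the `model` config group,
# e.g. `model=transformer` on the command line.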
Example #2
def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
    """cs: config store instance, register common training configs"""

    for k, v in CONFIGS.items():
        try:
            cs.store(name=k, node=v())
        except BaseException:
            logger.error(f"{k} - {v()}")
            raise

    register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
    register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model")

    for k, v in REGISTRIES.items():
        register_module_dataclass(cs, v["dataclass_registry"], k)
Example #3
def register_params_dataclass(
    cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
) -> None:
    """register params dataclass in config store"""
    node_ = data_class(_name=data_class.name())
    cs.store(name=name, group=group, node=node_)
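
A hedged usage sketch; `TranslationParams` below is hypothetical and only mirrors the interface the helper relies on (an `_name` field plus a `name()` accessor), standing in for a real `FairseqDataclass` subclass:

from dataclasses import dataclass
from typing import Optional

from hydra.core.config_store import ConfigStore


@dataclass
class TranslationParams:
    # Hypothetical params dataclass exposing the interface used above.
    _name: Optional[str] = None
    max_tokens: int = 4096

    @classmethod
    def name(cls) -> str:
        return "translation_params"


cs = ConfigStore.instance()
register_params_dataclass(
    cs, name="translation_params", group="params", data_class=TranslationParams
)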
Example #4
        different_person = self.loss(sub1_target, sub2_truth)
        logits = softmax(torch.tensor([same_person, different_person]))
        acc = self.accuracy(logits, torch.tensor([1, 0]))

        self.log('same person', same_person)
        self.log('different person', different_person)
        self.log('accuracy', acc)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)


@hydra.main(config_name='eeg')
def main(cfg: EEGLearnerConfig):
    cfg.correlation = True
    cfg.data.load = True

    data = EEGDataModule(cfg, use_criteria=True)
    model = DomainAdaptation(nn.Embedding(164, 10), nn.Embedding(164, 5))
    trainer = pl.Trainer()

    trainer.fit(model, datamodule=data)

    trainer.test(model, datamodule=data)


if __name__ == '__main__':
    cs = ConfigStore()
    cs.store(name='eeg', node=EEGLearnerConfig)
    main()
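
For `cs.store(name='eeg', node=EEGLearnerConfig)` to succeed, `EEGLearnerConfig` must be a structured config (a dataclass). A minimal hypothetical sketch covering only the fields touched in `main` above:

from dataclasses import dataclass, field


@dataclass
class EEGDataConfig:
    load: bool = False  # flipped via `cfg.data.load = True` in main()


@dataclass
class EEGLearnerConfig:
    correlation: bool = False  # flipped via `cfg.correlation = True` in main()
    data: EEGDataConfig = field(default_factory=EEGDataConfig)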
Example #5
"""Module for training path + cycle model on OGB datasets."""

import torch
import torch.optim

import hydra
from hydra.core.config_store import ConfigStore

from torch_geometric.transforms import Compose
from autobahn.transform import Pathifier, Cyclifier, OGBTransform
from autobahn.experiments.data import OGBDataModule
from autobahn.experiments import combo_models, utils

cs = ConfigStore()
cs.store(name='config_ogb', node=combo_models.OGBTrainingConfiguration)


@hydra.main(config_name='config_ogb', config_path='conf')
def train_with_conf(config: combo_models.OGBTrainingConfiguration):
    utils.ensure_config_defaults(config)

    torch.manual_seed(config.seed)
    transform = Compose([
        OGBTransform(),
        Pathifier(list(config.model.path_lengths)),
        Cyclifier(list(config.model.cycle_lengths))
    ])
    batch_split = max(config.num_gpus, 1)
    dataset = OGBDataModule(config.data,
                            transform=transform,
                            batch_size=config.batch_size // batch_split)
Example #6
import pytorch_lightning as pl
from data.dataset import TradeDataModule
from model.model import TradeModule
import neptune
from pytorch_lightning.loggers.neptune import NeptuneLogger
import hydra
from hydra.core.config_store import ConfigStore
from config.trade_config import TradeConfig


@hydra.main(config_name='trade')
def main(cfg: TradeConfig):
    data = TradeDataModule(**dict(cfg.data))
    model = TradeModule(cfg.data.look_back, data.ds.data.shape[1],
                        data.ds.full_data.shape[1])

    trainer = pl.Trainer(logger=[NeptuneLogger(project_name='yoniosin/Trade')],
                         max_epochs=cfg.max_epochs,
                         fast_dev_run=True)

    trainer.fit(model, datamodule=data)


if __name__ == '__main__':
    # neptune.set_project('yoniosin/Trade')
    cs = ConfigStore()
    cs.store(name='trade', node=TradeConfig)
    main()
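
As in Example #4, `TradeConfig` has to be a structured config; a hypothetical sketch covering just the fields the script reads (`cfg.data`, `cfg.data.look_back`, `cfg.max_epochs`):

from dataclasses import dataclass, field


@dataclass
class TradeDataConfig:
    look_back: int = 30  # hypothetical default; read as `cfg.data.look_back` above
    # any further fields are forwarded to TradeDataModule via **dict(cfg.data)


@dataclass
class TradeConfig:
    data: TradeDataConfig = field(default_factory=TradeDataConfig)
    max_epochs: int = 10  # hypothetical default; passed to pl.Trainer above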
Example #7
@hydra.main(config_name='config_zinc', config_path='conf')
def train_with_conf(config: combo_models.ZincTrainingConfiguration):
    trainer = utils.make_trainer(config)

    torch.manual_seed(config.seed)

    path_lengths, cycle_lengths = _expand_to_default(
        config.model.path_lengths, config.model.cycle_lengths)

    transform = Compose([Pathifier(path_lengths), Cyclifier(cycle_lengths)])
    batch_split = max(config.num_gpus, 1)
    dataset = ZincDataModule(config.data,
                             transform=transform,
                             batch_size=config.batch_size // batch_split)

    dataset.prepare_data()
    dataset.setup()

    config.model.atom_feature_cardinality = dataset.atom_feature_cardinality
    fixture = combo_models.ZincPathAndCycleModel(config)

    trainer.fit(fixture, dataset)


if __name__ == '__main__':
    from hydra.core.config_store import ConfigStore
    cs = ConfigStore()
    cs.store(name='base_config_zinc',
             node=combo_models.ZincTrainingConfiguration)
    train_with_conf()