Code example #1
    def __init__(self, hparams: om.DictConfig):
        super().__init__()

        if not isinstance(hparams, om.DictConfig):
            hparams = om.DictConfig(hparams)
        self.hparams = om.OmegaConf.to_container(hparams, resolve=True)

        # Instantiate datasets (Hydra compat)
        self.dataset = hu.instantiate(hparams.dataset)

        # Instantiate network modules
        self.gen = hu.instantiate(hparams.gen)
        self.dis = hu.instantiate(hparams.dis)
        self.type = hparams.var.type
        if hparams.var.ema:
            self.gen_ema = hu.instantiate(hparams.gen)

        # Instantiate optimizers & schedulers
        self.gen_optim = hu.instantiate(hparams.gen_opt, self.gen.parameters())
        self.dis_optim = hu.instantiate(hparams.dis_opt, self.dis.parameters())

        # Instantiate losses
        self.lambda_recon = hparams.var.lambda_recon
        self.lambda_gp = hparams.var.lambda_gp
        self.recon_loss = hu.instantiate(hparams.recon_loss, self.dis)
        self.adv_loss = hu.instantiate(hparams.adv_loss)
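For orientation, here is a minimal sketch (not taken from the project above; every class path and parameter below is an assumption) of the kind of config such a constructor consumes. With Hydra 1.x, each hu.instantiate call reads a _target_ key naming the class to build and passes the remaining keys as constructor arguments, while plain values such as var.lambda_recon are read directly.

import omegaconf as om

# Hypothetical hparams layout for the GAN module above.
hparams = om.OmegaConf.create({
    "dataset": {"_target_": "my_project.data.ImageDataset", "root": "./data"},
    "gen": {"_target_": "my_project.models.Generator", "latent_dim": 128},
    "dis": {"_target_": "my_project.models.Discriminator"},
    "gen_opt": {"_target_": "torch.optim.Adam", "lr": 2e-4, "betas": [0.5, 0.999]},
    "dis_opt": {"_target_": "torch.optim.Adam", "lr": 2e-4, "betas": [0.5, 0.999]},
    "recon_loss": {"_target_": "my_project.losses.ReconstructionLoss"},
    "adv_loss": {"_target_": "my_project.losses.AdversarialLoss"},
    "var": {"type": "hinge", "ema": True, "lambda_recon": 10.0, "lambda_gp": 10.0},
})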
Code example #2
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        typecheck.set_typecheck_enabled(enabled=False)

        cfg = self._cfg
        self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
            **cfg.train_ds.dataset.vocab)
        self.preprocessor = instantiate(cfg.preprocessor)
        self.embed = GaussianEmbedding(self.vocab, cfg.d_char)
        self.norm_f0 = MaskedInstanceNorm1d(1)
        self.res_f0 = StyleResidual(cfg.d_char, 1, kernel_size=3)
        self.model = instantiate(cfg.model)
        d_out = cfg.model.jasper[-1].filters
        self.proj = nn.Conv1d(d_out, cfg.n_mels, kernel_size=1)
Code example #3
def eval_fhmap(cfg: EvalFhmapConfig) -> None:
    """Evalutate Fourier heat map. The result is saved under outpus/eval_fhmap.

    Note:
        Currently, only even-sized input images are supported.

    Args:
        cfg (EvalFhmapConfig): The config of Fourier heat map evaluation.

    """
    # Make config read-only.
    # Without this, config values might be changed accidentally.
    OmegaConf.set_readonly(cfg, True)  # type: ignore
    logger.info(OmegaConf.to_yaml(cfg))

    # Set constants.
    # device: The device used for calculation.
    # cwd: The original current working directory (Hydra automatically changes it).
    # weightpath: The path of target trained weight.
    device: Final = cfg.env.device
    cwd: Final[pathlib.Path] = pathlib.Path(hydra.utils.get_original_cwd())
    weightpath: Final[pathlib.Path] = pathlib.Path(cfg.weightpath)

    # Setup datamodule
    root: Final[pathlib.Path] = cwd / "data"
    datamodule = instantiate(cfg.dataset, cfg.batch_size, cfg.env.num_workers,
                             root)
    datamodule.prepare_data()
    datamodule.setup()
    logger.info("datamodule setup: done.")

    # Setup model
    arch = instantiate(cfg.arch, num_classes=datamodule.num_classes)
    arch.load_state_dict(torch.load(weightpath))
    arch = arch.to(device)
    arch.eval()
    logger.info("architecture setup: done.")

    fhmap.eval_fourier_heatmap(
        datamodule.input_size,
        cfg.ignore_edge_size,
        cfg.eps,
        arch,
        datamodule.test_dataset,
        cfg.batch_size,
        cast(torch.device, device),  # needed for passing mypy check.
        cfg.topk,
        pathlib.Path("."),
    )
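eval_fhmap calls hydra.utils.get_original_cwd(), so it is meant to run inside a Hydra application. A minimal entry-point sketch follows; the config_path and config_name values are assumptions, not taken from the original repository.

import hydra

# Hypothetical launcher for eval_fhmap.
@hydra.main(config_path="../configs", config_name="eval_fhmap")
def main(cfg: EvalFhmapConfig) -> None:
    eval_fhmap(cfg)

if __name__ == "__main__":
    main()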
Code example #4
File: fastpitch.py  Project: quuhua911/NeMo
    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        # TODO(Oktai15): remove it in 1.8.0 version
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            dataset = instantiate(cfg.dataset, parser=self.parser)
        elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(
                    prob=None if name == "val" else self.vocab.phoneme_probability)

            with phon_mode:
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            # TODO(Oktai15): remove it in 1.8.0 version
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)
Code example #5
def train(cfg: DictConfig) -> None:
    tb_logger = CustomTensorBoardLogger('results',
                                        name=cfg.name,
                                        default_hp_metric=False)
    model = instantiate(cfg.lm, cfg, logging_dir=tb_logger.log_dir)

    callbacks = [
        instantiate(fig,
                    pl_module=model,
                    cfg=cfg.figure_details,
                    parent_dir=tb_logger.log_dir)
        for fig in cfg.figures.values()
    ]
    pl_trainer = pl.Trainer(gpus=1, callbacks=callbacks, logger=tb_logger)
    pl_trainer.fit(model)
Code example #6
def test_class_warning() -> None:
    expected = Bar(10, 20, 30, 40)
    with pytest.warns(UserWarning):
        config = OmegaConf.structured(
            {
                "class": "tests.test_utils.Bar",
                "params": {"a": 10, "b": 20, "c": 30, "d": 40},
            }
        )
        assert utils.instantiate(config) == expected

    config = OmegaConf.structured(
        {"cls": "tests.test_utils.Bar", "params": {"a": 10, "b": 20, "c": 30, "d": 40}}
    )
    assert utils.instantiate(config) == expected
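This test exercises Hydra's old object-config layout, in which the target class was named by a "class" key (deprecated, hence the expected UserWarning) or a "cls" key, with the constructor arguments nested under "params". In Hydra 1.x the equivalent config names the class with "_target_" and inlines the arguments; a minimal sketch against the same Bar test class:

# Hydra 1.x equivalent of the deprecated class/cls + params layout.
config = OmegaConf.structured(
    {"_target_": "tests.test_utils.Bar", "a": 10, "b": 20, "c": 30, "d": 40}
)
assert utils.instantiate(config) == Bar(10, 20, 30, 40)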
Code example #7
File: aligner.py  Project: silencelearner/NeMo
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        typecheck.set_typecheck_enabled(enabled=False)

        cfg = self._cfg
        self.vocab = AudioToCharWithDursF0Dataset.make_vocab(**cfg.train_ds.dataset.vocab)
        self.embed = nn.Embedding(len(self.vocab.labels), cfg.d_char)
        self.preprocessor = instantiate(cfg.preprocessor)
        self.alignment_encoder = instantiate(cfg.alignment_encoder)

        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()

        self.bin_start_ratio = cfg.bin_start_ratio
        self.add_bin_loss = False
Code example #8
def get_data_loader(dataset,
                    batch_size,
                    sampler,
                    num_workers,
                    is_distributed,
                    seed,
                    split_path=None,
                    split_key=None,
                    transforms=None):
    """
    Returns a data loader given dataset and sampler configs.

    Args:
        dataset         ... hydra config specifying dataset object
        batch_size      ... batch size
        sampler         ... hydra config specifying sampler object
        num_workers     ... number of workers to use in dataloading
        is_distributed  ... whether running in multiprocessing mode, used to wrap sampler using DistributedSamplerWrapper
        seed            ... seed used to coordinate samplers in distributed mode
        split_path      ... path to indices specifying splitting of dataset among train/val/test
        split_key       ... string key to select indices
        transforms      ... list of transforms to apply
    
    Returns: dataloader created with instantiated dataset and (possibly wrapped) sampler
    """
    dataset = instantiate(dataset,
                          transforms=transforms,
                          is_distributed=is_distributed)

    if split_path is not None and split_key is not None:
        split_indices = np.load(split_path, allow_pickle=True)[split_key]
        sampler = instantiate(sampler, split_indices)
    else:
        sampler = instantiate(sampler)

    if is_distributed:
        ngpus = torch.distributed.get_world_size()

        batch_size = int(batch_size / ngpus)

        sampler = DistributedSamplerWrapper(sampler=sampler, seed=seed)

    # TODO: added drop_last, should decide if we want to keep this
    return DataLoader(dataset,
                      sampler=sampler,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      drop_last=False)
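A hedged usage sketch (the class paths and file names below are hypothetical): both dataset and sampler arrive as Hydra configs with a _target_ key, so get_data_loader can instantiate them itself and, in distributed mode, wrap the sampler and scale down the batch size.

from omegaconf import OmegaConf

dataset_cfg = OmegaConf.create(
    {"_target_": "my_project.datasets.HitsDataset", "h5file": "digi_hits.h5"})
sampler_cfg = OmegaConf.create(
    {"_target_": "my_project.samplers.SubsetSequentialSampler"})

loader = get_data_loader(dataset_cfg, batch_size=64, sampler=sampler_cfg,
                         num_workers=4, is_distributed=False, seed=0,
                         split_path="splits.npz", split_key="train_idxs")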
Code example #9
File: train.py  Project: vliu15/biggan
def main():
    args = parse_arguments()
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
        config = OmegaConf.create(config)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    generator = instantiate(config.generator).to(device).apply(weights_init)
    discriminator = instantiate(
        config.discriminator).to(device).apply(weights_init)

    g_optimizer = torch.optim.Adam(generator.parameters(), **config.g_optim)
    d_optimizer = torch.optim.Adam(discriminator.parameters(),
                                   **config.d_optim)

    start_epoch = 0
    if config.resume_checkpoint is not None:
        state_dict = torch.load(config.resume_checkpoint)

        generator.load_state_dict(state_dict['g_model_dict'])
        discriminator.load_state_dict(state_dict['d_model_dict'])
        g_optimizer.load_state_dict(state_dict['g_optim_dict'])
        d_optimizer.load_state_dict(state_dict['d_optim_dict'])
        start_epoch = state_dict['epoch']
        print('Starting BigGAN training from checkpoint')
    else:
        print('Starting BigGAN training from random initialization')

    train_dataloader = torch.utils.data.DataLoader(
        instantiate(config.train_dataset),
        collate_fn=collate_fn,
        **config.train_dataloader,
    )
    val_dataloader = torch.utils.data.DataLoader(
        instantiate(config.val_dataset),
        collate_fn=collate_fn,
        **config.val_dataloader,
    )

    train(
        [train_dataloader, val_dataloader],
        [generator, discriminator],
        [g_optimizer, d_optimizer],
        config.train,
        device,
        start_epoch,
    )
Code example #10
def main_worker_function(rank, ngpus_per_node, is_distributed, config):
    # Infer rank from gpu and ngpus, rank is position in gpu list
    gpu = config.gpu_list[rank]

    print("Running main worker function on device: {}".format(gpu))
    torch.cuda.set_device(gpu)

    world_size = ngpus_per_node

    if is_distributed:
        torch.distributed.init_process_group(
            'nccl',
            init_method='env://',
            world_size=world_size,
            rank=rank,
        )

    # Instantiate model and engine
    model = instantiate(config.model).to(gpu)

    # Configure the device to be used for model training and inference
    if is_distributed:
        # Convert model batch norms to SyncBatchNorm
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # Instantiate the engine
    engine = instantiate(config.engine,
                         model=model,
                         rank=rank,
                         gpu=gpu,
                         dump_path=config.dump_path)

    # Configure data loaders
    for task, task_config in config.tasks.items():
        if 'data_loaders' in task_config:
            engine.configure_data_loaders(config.data,
                                          task_config.data_loaders,
                                          is_distributed, config.seed)

    # Configure optimizers
    for task, task_config in config.tasks.items():
        if 'optimizers' in task_config:
            engine.configure_optimizers(task_config.optimizers)

    # Perform tasks
    for task, task_config in config.tasks.items():
        getattr(engine, task)(task_config)
Code example #11
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        # We use a separate preprocessor for training because we need to pass grads and remove the pitch fmax limitation
        self.trg_melspec_fn = instantiate(cfg.preprocessor,
                                          highfreq=None,
                                          use_grads=True)
        self.generator = instantiate(
            cfg.generator,
            n_mel_channels=cfg.preprocessor.nfilt,
            hop_length=cfg.preprocessor.n_window_stride)
        self.mpd = MultiPeriodDiscriminator(
            cfg.discriminator.mpd,
            debug=cfg.debug if "debug" in cfg else False)
        self.mrd = MultiResolutionDiscriminator(
            cfg.discriminator.mrd,
            debug=cfg.debug if "debug" in cfg else False)

        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()

        # Reshape the MRD resolutions hyperparameter and apply it to the MRSTFT loss
        self.stft_resolutions = cfg.discriminator.mrd.resolutions
        self.fft_sizes = [res[0] for res in self.stft_resolutions]
        self.hop_sizes = [res[1] for res in self.stft_resolutions]
        self.win_lengths = [res[2] for res in self.stft_resolutions]
        self.mrstft_loss = MultiResolutionSTFTLoss(self.fft_sizes,
                                                   self.hop_sizes,
                                                   self.win_lengths)
        self.stft_lamb = cfg.stft_lamb

        self.sample_rate = self._cfg.preprocessor.sample_rate
        self.stft_bias = None

        self.input_as_mel = False
        if self._train_dl:
            # TODO(Oktai15): remove it in 1.8.0 version
            if isinstance(self._train_dl.dataset, MelAudioDataset):
                self.input_as_mel = True
            elif isinstance(self._train_dl.dataset, VocoderDataset):
                self.input_as_mel = self._train_dl.dataset.load_precomputed_mel

        self.automatic_optimization = False
Code example #12
    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        kwargs_dict = {}
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            kwargs_dict["parser"] = self.parser
        dataset = instantiate(cfg.dataset, **kwargs_dict)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)
Code example #13
 def __init__(self,
              digi_dataset_config,
              true_hits_h5file,
              digi_truth_mapping_file,
              valid_parents=(1, 2, 3),
              parent_type="max",
              transforms=True,
              is_distributed=False):
     """
     Args:
         digi_dataset_config     ... config for dataset for digitized hits
         true_hits_h5file        ... path to h5 dataset file for true hits
         digi_truth_mapping_file ... path to file with a pickled list mapping digitized hit events to true hit events
         valid_parents           ... valid ID values for hit parents
         parent_type             ... which parent-selection method to use: "max" or "only"
         transforms              ... whether to adopt the transforms of the digi dataset
         is_distributed          ... whether running in multiprocessing mode
     """
     self.digi_dataset = instantiate(digi_dataset_config,
                                     is_distributed=is_distributed)
     self.mpmt_positions = self.digi_dataset.mpmt_positions
     if transforms:
         self.transforms = self.digi_dataset.transforms
         self.digi_dataset.transforms = None
     else:
         self.transforms = None
     self.truth_dataset = H5TrueDataset(true_hits_h5file,
                                        transforms=None,
                                        digitize_hits=False)
     with open(digi_truth_mapping_file, 'rb') as f:
         self.digi_truth_mapping = pickle.load(f)
     self.valid_parents = np.array(valid_parents)
     if parent_type == "only":
         self.get_digi_hit_parent = self.get_digi_hit_only_parent
     elif parent_type == "max":
         self.get_digi_hit_parent = self.get_digi_hit_max_parent
Code example #14
 def train_dataloader(self):
     train_ds = instantiate(self.dataset_conf.train)
     train_dl = DataLoader(train_ds,
                           self.train_conf.batch_size,
                           shuffle=True,
                           num_workers=self.hparams['num_workers'])
     return train_dl
Code example #15
 def test_dataloader(self):
     test_conf = self.test_conf
     test_ds = instantiate(self.dataset_conf.test)
     test_dl = DataLoader(test_ds,
                          test_conf.batch_size,
                          num_workers=self.hparams['num_workers'])
     return test_dl
Code example #16
File: test_generate.py  Project: nng555/hydra
def test_instantiate_classes(classname: str, params: Any, args: Any,
                             kwargs: Any, expected: Any) -> None:
    full_class = f"{MODULE_NAME}.generated.{classname}Conf"
    schema = OmegaConf.structured(get_class(full_class))
    cfg = OmegaConf.merge(schema, params)
    obj = instantiate(config=cfg, *args, **kwargs)
    assert obj == expected
Code example #17
def create_simple_dataset(conf, transforms):
    # type: (DictConfig, DictConfig) -> JustImages
    transforms = T.Compose([instantiate(v) for k, v in transforms.items()])
    ds = JustImages(conf.root,
                    extensions=tuple(conf.extensions),
                    transform=transforms)
    return ds
Code example #18
File: uniglow.py  Project: vinayphadnis/NeMo
    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)
Code example #19
    def __init__(self, cfg):
        '''
        AutoEncoder built from fully connected networks.
        
        Parameters:
        --------------------------------------
        cfg:
            Constants loaded from the config file.
            Network structure and other parameters are also configured there.
        
        '''
        super(AutoEncoder, self).__init__()

        self.config = cfg
        if 'input_shape' in cfg.data:
            self.n_features = cfg.data.input_shape[0]
            if len(cfg.data.feature_cols) != self.n_features:
                self.n_features = len(cfg.data.feature_cols)
        else:
            self.n_features = len(cfg.data.feature_cols)
            #self.config.data.input_shape = [self.n_features]
        self.n_timesteps = cfg.data.window_size
        self.hidden_size = OmegaConf.to_container(cfg.net.hidden_size)
        self.z_dim = cfg.net.z_dim
        self.dropout = cfg.net.dropout

        self.anomaly_scores = None
        self.recon_x = None

        act_f = None
        if 'act_f' in self.config.net:
            act_f = eval(self.config.net.act_f)
        ''' Whether the input is arranged along the time axis or the feature axis '''
        if 'input_vec' in self.config.data.keys():
            if self.config.data.input_vec == 'time':
                self.encoder = Encoder(self.n_timesteps, self.hidden_size,
                                       self.z_dim, self.n_features,
                                       self.dropout)
                self.decoder = Decoder(self.n_timesteps, self.hidden_size,
                                       self.z_dim, self.n_features,
                                       self.dropout)
            else:
                self.encoder = Encoder(self.n_features, self.hidden_size,
                                       self.z_dim, self.n_timesteps,
                                       self.dropout)
                self.decoder = Decoder(self.n_features, self.hidden_size,
                                       self.z_dim, self.n_timesteps,
                                       self.dropout)
        else:
            raise ValueError("config['data'] does not have 'input_vec'")

        self.loss = F.mse_loss
        if 'loss_f' in self.config.net:
            self.loss = eval(self.config.net.loss_f)  # would instantiate be better?

        self.optimizer = instantiate(self.config.optimizer,
                                     params=self.parameters())

        self.save_hyperparameters(
            OmegaConf.to_container(self.config.net, resolve=True))
Code example #20
File: callbacks.py  Project: Jasha10/hydra
    def __init__(self, config: Optional[DictConfig] = None) -> None:
        self.callbacks = []
        from hydra.utils import instantiate

        if config is not None and OmegaConf.select(config, "hydra.callbacks"):
            for params in config.hydra.callbacks.values():
                self.callbacks.append(instantiate(params))
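For reference, a minimal sketch of the config section this loader walks: each entry under hydra.callbacks is itself an instantiate target. The callback class path below is an assumption.

from omegaconf import OmegaConf

config = OmegaConf.create({
    "hydra": {
        "callbacks": {
            "log_jobs": {"_target_": "my_app.callbacks.LogJobCallback", "verbose": True}
        }
    }
})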
Code example #21
def run_scene_optimizer(args) -> None:
    """ Run GTSFM over images from an Argoverse vehicle log"""
    with initialize_config_module(config_module="gtsfm.configs"):
        # config is relative to the gtsfm module
        cfg = compose(config_name="default_lund_door_set1_config.yaml")
        scene_optimizer: SceneOptimizer = instantiate(cfg.SceneOptimizer)

        loader = ArgoverseDatasetLoader(
            dataset_dir=args.dataset_dir,
            log_id=args.log_id,
            stride=args.stride,
            max_num_imgs=args.max_num_imgs,
            max_lookahead_sec=args.max_lookahead_sec,
            camera_name=args.camera_name,
        )

        sfm_result_graph = scene_optimizer.create_computation_graph(
            len(loader),
            loader.get_valid_pairs(),
            loader.create_computation_graph_for_images(),
            loader.create_computation_graph_for_intrinsics(),
            use_intrinsics_in_verification=True,
            gt_pose_graph=loader.create_computation_graph_for_poses(),
        )

        # create dask client
        cluster = LocalCluster(n_workers=2, threads_per_worker=4)

        with Client(cluster), performance_report(filename="dask-report.html"):
            sfm_result = sfm_result_graph.compute()

        assert isinstance(sfm_result, GtsfmData)
        scene_avg_reproj_error = sfm_result.get_scene_avg_reprojection_error()
        logger.info('Scene avg reproj error: {}'.format(
            str(np.round(scene_avg_reproj_error, 3))))
Code example #22
def build_optimizer(conf: DictConfig, model: nn.Module) -> Optimizer:
    parameters = model.parameters()
    p = conf.params
    if 'weight_decay' in p and p.weight_decay > 0:
        parameters = add_weight_decay(model, p.weight_decay)
        p.weight_decay = 0.0
    return instantiate(conf, parameters)
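A sketch of a config build_optimizer could consume; the values are assumptions. The nested params node matches the conf.params access above, i.e. the older Hydra layout in which instantiate still read constructor arguments from a params block, so whether this runs unchanged depends on the Hydra version in use.

from omegaconf import OmegaConf
import torch.nn as nn

# Hypothetical optimizer config with a nested params node.
conf = OmegaConf.create({
    "_target_": "torch.optim.SGD",
    "params": {"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4},
})
model = nn.Linear(16, 4)
optimizer = build_optimizer(conf, model)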
Code example #23
def setup_ema(conf: DictConfig, model: nn.Module, device=None, master_node=False):
    ema = conf.smoothing
    model_ema = None
    def _update(): pass

    if master_node and ema.enabled:
        model_ema = instantiate(conf.model)
        if not ema.use_cpu:
            model_ema = model_ema.to(device)
        model_ema.load_state_dict(model.state_dict())
        model_ema.requires_grad_(False)

        beta = 1 - ema.alpha ** ema.interval_it

        def _update():
            states = itertools.chain(
                zip(model_ema.parameters(), model.parameters()),
                zip(model_ema.buffers(), model.buffers()))

            with torch.no_grad():
                for t_ema, t in states:
                    # filter out *.bn1.num_batches_tracked
                    if t.dtype != torch.int64:
                        t = t.to(dtype=t_ema.dtype, device=t_ema.device)
                        t_ema.lerp_(t, beta)

    return model_ema, _update
Code example #24
    def __init__(self, cfg: Config) -> None:
        super().__init__()  # type: ignore

        self.logger: Union[LoggerCollection, WandbLogger, Any]
        self.wandb: Run

        self.cfg = cfg

        self.model: LightningModule = instantiate(self.cfg.experiment.model,
                                                  self.cfg)

        self.criterion = MSELoss()

        # Metrics
        self.train_mse = MeanSquaredError()
        self.train_mae = MeanAbsoluteError()
        self.val_mse = MeanSquaredError()
        self.val_mae = MeanAbsoluteError()
        self.test_mse = MeanSquaredError()
        self.test_mae = MeanAbsoluteError()
        self.test_results = []
        train_params = self.cfg.experiment.synop_train_features
        target_param = self.cfg.experiment.target_parameter
        all_params = add_param_to_train_params(train_params, target_param)
        feature_names = list(list(zip(*all_params))[1])
        self.target_param_index = feature_names.index(target_param)
Code example #25
 def test_instantiate_with_missing_values(self, tmpdir):
     """Check if raising error with missing values."""
     ToyModel.configen(tmpdir)
     config = OmegaConf.load(Path(tmpdir).joinpath('model/toy.yaml'))
     validate(config, ModelConfig)
     with pytest.raises(MissingMandatoryValue):
         _ = instantiate(config)
Code example #26
def my_app(cfg: Config) -> None:
    if OmegaConf.get_type(cfg.model) is MlpConfig:
        mlp_cfg = cast(MlpConfig, cfg.model)
        print("using MLP")
        print(f"{mlp_cfg.layers=}")
        print(f"{mlp_cfg.hidden_units=}")
    elif OmegaConf.get_type(cfg.model) is SVMConfig:
        svm_cfg = cast(SVMConfig, cfg.model)
        print("using SVM")
        print(f"{svm_cfg.kernel=}")
        print(f"{svm_cfg.C=}")

    print()

    data_dir: Path = instantiate(cfg.dataset.dir)
    print(data_dir)
    if OmegaConf.get_type(cfg.dataset) is AdultConfig:
        adult_cfg = cast(AdultConfig, cfg.dataset)
        print("using Adult dataset")
        print(f"{adult_cfg.drop_native=}")
    elif OmegaConf.get_type(cfg.dataset) is CmnistConfig:
        cmnist_cfg = cast(CmnistConfig, cfg.dataset)
        print("using CMNIST dataset")
        print(f"{cmnist_cfg.padding=}")

    print()
    print(f"{cfg.seed=}")
    print(f"{cfg.use_wandb=}")
    print(f"{cfg.data_pcnt=}")

    print()
    print("Config as flat dictionary:")
    print(flatten(OmegaConf.to_container(cfg, enum_to_str=True)))
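The OmegaConf.get_type checks above dispatch on which structured config was composed into each group. A minimal sketch of how such schemas are typically registered with Hydra's ConfigStore; the group and entry names are assumptions.

from hydra.core.config_store import ConfigStore

cs = ConfigStore.instance()
cs.store(name="config_schema", node=Config)
cs.store(group="model", name="mlp", node=MlpConfig)
cs.store(group="model", name="svm", node=SVMConfig)
cs.store(group="dataset", name="adult", node=AdultConfig)
cs.store(group="dataset", name="cmnist", node=CmnistConfig)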
Code example #27
File: test_utils.py  Project: vporta/hydra
def test_interpolation_accessing_parent_deprecated(input_conf: Any,
                                                   passthrough: Dict[str, Any],
                                                   expected: Any,
                                                   recwarn: Any) -> Any:
    input_conf = OmegaConf.create(input_conf)
    obj = utils.instantiate(input_conf.node, **passthrough)
    assert obj == expected
Code example #28
def run_scene_optimizer() -> None:
    """ """
    with initialize_config_module(config_module="gtsfm.configs"):
        # config is relative to the gtsfm module
        cfg = compose(config_name="default_lund_door_set1_config.yaml")
        scene_optimizer: SceneOptimizer = instantiate(cfg.SceneOptimizer)

        loader = OlssonLoader(os.path.join(DATA_ROOT, "set1_lund_door"),
                              image_extension="JPG")

        sfm_result_graph = scene_optimizer.create_computation_graph(
            num_images=len(loader),
            image_pair_indices=loader.get_valid_pairs(),
            image_graph=loader.create_computation_graph_for_images(),
            camera_intrinsics_graph=loader.create_computation_graph_for_intrinsics(),
            gt_pose_graph=loader.create_computation_graph_for_poses(),
        )

        # create dask client
        cluster = LocalCluster(n_workers=2, threads_per_worker=4)

        with Client(cluster), performance_report(filename="dask-report.html"):
            sfm_result = sfm_result_graph.compute()

        assert isinstance(sfm_result, GtsfmData)
Code example #29
File: uniglow.py  Project: vinayphadnis/NeMo
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(WaveglowConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.pad_value = self._cfg.preprocessor.params.pad_value
        self.sigma = self._cfg.sigma
        self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)
        self.model = UniGlowModule(
            self._cfg.uniglow.n_mel_channels,
            self._cfg.uniglow.n_flows,
            self._cfg.uniglow.n_group,
            self._cfg.uniglow.n_wn_channels,
            self._cfg.uniglow.n_wn_layers,
            self._cfg.uniglow.wn_kernel_size,
            self.get_upsample_factor(),
        )
        self.mode = OperationMode.infer
        self.loss = UniGlowLoss(self._cfg.uniglow.stft_loss_coef)
        self.removed_weightnorm = False
Code example #30
File: talknet.py  Project: silencelearner/NeMo
 def _loader(cfg):
     dataset = instantiate(cfg.dataset)
     return torch.utils.data.DataLoader(  # noqa
         dataset=dataset,
         collate_fn=dataset.collate_fn,
         **cfg.dataloader_params,
     )