def __init__(self, hparams: om.DictConfig):
    """Build dataset, networks, optimizers and losses from ``hparams``.

    Args:
        hparams: Configuration tree. A value that is not already an
            ``om.DictConfig`` is wrapped in one before use. The nodes read
            here are ``dataset``, ``gen``, ``dis``, ``gen_opt``, ``dis_opt``,
            ``recon_loss``, ``adv_loss`` and ``var`` (``type``, ``ema``,
            ``lambda_recon``, ``lambda_gp``).
    """
    super().__init__()

    config = hparams
    if not isinstance(config, om.DictConfig):
        config = om.DictConfig(config)
    # Keep a plain-container copy with interpolations resolved.
    self.hparams = om.OmegaConf.to_container(config, resolve=True)

    # Dataset (Hydra compat)
    self.dataset = hu.instantiate(config.dataset)

    # Generator / discriminator networks
    self.gen = hu.instantiate(config.gen)
    self.dis = hu.instantiate(config.dis)
    settings = config.var
    self.type = settings.type
    if settings.ema:
        # Second generator built from the same config node.
        self.gen_ema = hu.instantiate(config.gen)

    # Optimizers receive the matching network's parameters positionally.
    gen_params = self.gen.parameters()
    self.gen_optim = hu.instantiate(config.gen_opt, gen_params)
    dis_params = self.dis.parameters()
    self.dis_optim = hu.instantiate(config.dis_opt, dis_params)

    # Loss weights and loss objects
    self.lambda_recon = settings.lambda_recon
    self.lambda_gp = settings.lambda_gp
    self.recon_loss = hu.instantiate(config.recon_loss, self.dis)
    self.adv_loss = hu.instantiate(config.adv_loss)
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Create the vocabulary, preprocessor and network layers.

    Args:
        cfg: Model configuration; after the base-class constructor runs the
            stored ``self._cfg`` is what is actually read.
        trainer: Optional trainer forwarded to the base class.
    """
    super().__init__(cfg=cfg, trainer=trainer)
    typecheck.set_typecheck_enabled(enabled=False)

    model_cfg = self._cfg
    vocab_kwargs = model_cfg.train_ds.dataset.vocab
    self.vocab = AudioToCharWithDursF0Dataset.make_vocab(**vocab_kwargs)
    self.preprocessor = instantiate(model_cfg.preprocessor)
    self.embed = GaussianEmbedding(self.vocab, model_cfg.d_char)
    self.norm_f0 = MaskedInstanceNorm1d(1)
    self.res_f0 = StyleResidual(model_cfg.d_char, 1, kernel_size=3)
    self.model = instantiate(model_cfg.model)
    # The projection input width is the filter count of the last jasper block.
    last_block = model_cfg.model.jasper[-1]
    self.proj = nn.Conv1d(last_block.filters, model_cfg.n_mels, kernel_size=1)
def eval_fhmap(cfg: EvalFhmapConfig) -> None:
    """Evaluate Fourier heat map.

    The result is saved under outputs/eval_fhmap.

    Note:
        Currently, only even-sized input images are supported.

    Args:
        cfg (EvalFhmapConfig): The config of Fourier heat map evaluation.

    """
    # Make config read only.
    # Without this, config values might be changed accidentally.
    OmegaConf.set_readonly(cfg, True)  # type: ignore
    logger.info(OmegaConf.to_yaml(cfg))

    # Set constants.
    # device: The device which is used in calculation.
    # cwd: The original current working directory. hydra automatically changes it.
    # weightpath: The path of target trained weight.
    device: Final = cfg.env.device
    cwd: Final[pathlib.Path] = pathlib.Path(hydra.utils.get_original_cwd())
    weightpath: Final[pathlib.Path] = pathlib.Path(cfg.weightpath)

    # Setup datamodule
    root: Final[pathlib.Path] = cwd / "data"
    datamodule = instantiate(cfg.dataset, cfg.batch_size, cfg.env.num_workers, root)
    datamodule.prepare_data()
    datamodule.setup()
    logger.info("datamodule setup: done.")

    # Setup model
    arch = instantiate(cfg.arch, num_classes=datamodule.num_classes)
    # map_location remaps the stored tensors onto the evaluation device, so a
    # checkpoint written on a different device (e.g. GPU -> CPU) still loads.
    arch.load_state_dict(torch.load(weightpath, map_location=device))
    arch = arch.to(device)
    arch.eval()
    logger.info("architecture setup: done.")

    fhmap.eval_fourier_heatmap(
        datamodule.input_size,
        cfg.ignore_edge_size,
        cfg.eps,
        arch,
        datamodule.test_dataset,
        cfg.batch_size,
        cast(torch.device, device),  # needed for passing mypy check.
        cfg.topk,
        pathlib.Path("."),
    )
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    """Instantiate the dataset described by ``cfg`` and wrap it in a DataLoader.

    Args:
        cfg: Config with a ``dataset`` node and a ``dataloader_params`` node,
            both of which must be ``DictConfig`` instances.
        shuffle_should_be: Expected value of ``dataloader_params.shuffle``.
            When True and the key is absent it is added as True; any other
            mismatch is only logged.
        name: Split name used in log and error messages; ``"val"`` also
            selects ``prob=None`` for the phoneme-probability context.

    Returns:
        A ``torch.utils.data.DataLoader`` over the instantiated dataset, using
        the dataset's own ``collate_fn`` and ``cfg.dataloader_params``.

    Raises:
        ValueError: If ``dataset`` or ``dataloader_params`` is missing or is
            not a ``DictConfig``.
    """
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloder_params for {name}")

    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg.dataloader_params):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    # Use .get() so a config that simply omits ``shuffle`` is treated as
    # "not shuffling" instead of raising on the missing key.
    elif cfg.dataloader_params.get("shuffle", False):
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    # TODO(Oktai15): remove it in 1.8.0 version
    if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
        dataset = instantiate(cfg.dataset, parser=self.parser)
    elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
        phon_mode = contextlib.nullcontext()
        if hasattr(self.vocab, "set_phone_prob"):
            # Validation passes prob=None; other splits use the vocab's own value.
            phon_mode = self.vocab.set_phone_prob(
                prob=None if name == "val" else self.vocab.phoneme_probability
            )

        with phon_mode:
            dataset = instantiate(
                cfg.dataset,
                text_normalizer=self.normalizer,
                text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                text_tokenizer=self.vocab,
            )
    else:
        # TODO(Oktai15): remove it in 1.8.0 version
        dataset = instantiate(cfg.dataset)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def train(cfg: DictConfig) -> None:
    """Instantiate the model and figure callbacks from ``cfg`` and fit.

    Args:
        cfg: Config providing ``name``, ``lm`` (the model target),
            ``figures`` (mapping of callback configs) and ``figure_details``.
    """
    board = CustomTensorBoardLogger('results', name=cfg.name, default_hp_metric=False)
    module = instantiate(cfg.lm, cfg, logging_dir=board.log_dir)

    # One callback per entry of cfg.figures, all sharing the same model.
    figure_callbacks = []
    for figure_cfg in cfg.figures.values():
        figure_callbacks.append(
            instantiate(
                figure_cfg,
                pl_module=module,
                cfg=cfg.figure_details,
                parent_dir=board.log_dir,
            )
        )

    pl.Trainer(gpus=1, callbacks=figure_callbacks, logger=board).fit(module)
def test_class_warning() -> None:
    """Instantiate ``Bar`` from ``class``- and ``cls``-keyed configs.

    The ``class``-keyed config is expected to emit a ``UserWarning``; both
    configs must produce an object equal to ``Bar(10, 20, 30, 40)``.
    """
    target = "tests.test_utils.Bar"
    expected = Bar(10, 20, 30, 40)

    with pytest.warns(UserWarning):
        class_keyed = OmegaConf.structured(
            {"class": target, "params": {"a": 10, "b": 20, "c": 30, "d": 40}}
        )
        assert utils.instantiate(class_keyed) == expected

    cls_keyed = OmegaConf.structured(
        {"cls": target, "params": {"a": 10, "b": 20, "c": 30, "d": 40}}
    )
    assert utils.instantiate(cls_keyed) == expected
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Create the vocabulary, embedding, preprocessor, encoder and losses.

    Args:
        cfg: Model configuration; the stored ``self._cfg`` is read after the
            base-class constructor has run.
        trainer: Optional trainer forwarded to the base class.
    """
    super().__init__(cfg=cfg, trainer=trainer)
    typecheck.set_typecheck_enabled(enabled=False)

    model_cfg = self._cfg
    vocab_kwargs = model_cfg.train_ds.dataset.vocab
    self.vocab = AudioToCharWithDursF0Dataset.make_vocab(**vocab_kwargs)
    # One embedding row per vocabulary label.
    n_labels = len(self.vocab.labels)
    self.embed = nn.Embedding(n_labels, model_cfg.d_char)
    self.preprocessor = instantiate(model_cfg.preprocessor)
    self.alignment_encoder = instantiate(model_cfg.alignment_encoder)

    self.forward_sum_loss = ForwardSumLoss()
    self.bin_loss = BinLoss()
    self.bin_start_ratio = model_cfg.bin_start_ratio
    self.add_bin_loss = False
def get_data_loader(dataset, batch_size, sampler, num_workers, is_distributed, seed, split_path=None, split_key=None, transforms=None):
    """Build a DataLoader from dataset and sampler configs.

    Args:
        dataset: Hydra config specifying the dataset object.
        batch_size: Batch size; divided by the world size when distributed.
        sampler: Hydra config specifying the sampler object.
        num_workers: Number of workers to use in data loading.
        is_distributed: Whether running in multiprocessing mode; when true the
            sampler is wrapped in ``DistributedSamplerWrapper``.
        seed: Seed passed to the distributed sampler wrapper.
        split_path: Path to a file of indices splitting the dataset.
        split_key: Key selecting which set of indices to load.
        transforms: Transforms forwarded to the dataset constructor.

    Returns:
        DataLoader over the instantiated dataset with the (possibly wrapped)
        sampler.
    """
    data = instantiate(dataset, transforms=transforms, is_distributed=is_distributed)

    # Indices are only loaded when both the file and the key are given.
    if split_path is None or split_key is None:
        data_sampler = instantiate(sampler)
    else:
        indices = np.load(split_path, allow_pickle=True)[split_key]
        data_sampler = instantiate(sampler, indices)

    per_process_batch = batch_size
    if is_distributed:
        world_size = torch.distributed.get_world_size()
        per_process_batch = int(batch_size / world_size)
        data_sampler = DistributedSamplerWrapper(sampler=data_sampler, seed=seed)

    # TODO: added drop_last, should decide if we want to keep this
    return DataLoader(
        data,
        sampler=data_sampler,
        batch_size=per_process_batch,
        num_workers=num_workers,
        drop_last=False,
    )
def main():
    """Load the YAML config, build models/optimizers/loaders and train.

    Reads the config file named by the parsed ``--config`` argument, creates
    the generator and discriminator with their Adam optimizers, optionally
    restores all four from ``config.resume_checkpoint``, builds the train and
    validation dataloaders and hands everything to ``train``.
    """
    args = parse_arguments()
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    config = OmegaConf.create(config)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    generator = instantiate(config.generator).to(device).apply(weights_init)
    discriminator = instantiate(config.discriminator).to(device).apply(weights_init)

    g_optimizer = torch.optim.Adam(generator.parameters(), **config.g_optim)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), **config.d_optim)

    start_epoch = 0
    if config.resume_checkpoint is not None:
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only host (and vice versa) instead of failing to deserialize.
        state_dict = torch.load(config.resume_checkpoint, map_location=device)
        generator.load_state_dict(state_dict['g_model_dict'])
        discriminator.load_state_dict(state_dict['d_model_dict'])
        g_optimizer.load_state_dict(state_dict['g_optim_dict'])
        d_optimizer.load_state_dict(state_dict['d_optim_dict'])
        start_epoch = state_dict['epoch']
        print('Starting BigGAN training from checkpoint')
    else:
        print('Starting BigGAN training from random initialization')

    train_dataloader = torch.utils.data.DataLoader(
        instantiate(config.train_dataset),
        collate_fn=collate_fn,
        **config.train_dataloader,
    )
    val_dataloader = torch.utils.data.DataLoader(
        instantiate(config.val_dataset),
        collate_fn=collate_fn,
        **config.val_dataloader,
    )

    train(
        [train_dataloader, val_dataloader],
        [generator, discriminator],
        [g_optimizer, d_optimizer],
        config.train,
        device,
        start_epoch,
    )
def main_worker_function(rank, ngpus_per_node, is_distributed, config):
    """Set up model and engine on one device, then run every configured task.

    Args:
        rank: Index of this worker; also the position in ``config.gpu_list``.
        ngpus_per_node: Used as the process-group world size.
        is_distributed: When true, the NCCL process group is initialised and
            the model is wrapped for distributed training.
        config: Config with ``gpu_list``, ``model``, ``engine``, ``dump_path``,
            ``data``, ``seed`` and a ``tasks`` mapping.
    """
    # Infer rank from gpu and ngpus, rank is position in gpu list
    device_id = config.gpu_list[rank]
    print("Running main worker function on device: {}".format(device_id))
    torch.cuda.set_device(device_id)

    if is_distributed:
        torch.distributed.init_process_group(
            'nccl',
            init_method='env://',
            world_size=ngpus_per_node,
            rank=rank,
        )

    # Instantiate model and move it to this worker's device
    network = instantiate(config.model).to(device_id)

    if is_distributed:
        # Convert model batch norms to synchbatchnorm
        network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network)
        network = DDP(network, device_ids=[device_id], find_unused_parameters=True)

    # Instantiate the engine
    engine = instantiate(
        config.engine,
        model=network,
        rank=rank,
        gpu=device_id,
        dump_path=config.dump_path,
    )

    # Configure data loaders
    for task_cfg in config.tasks.values():
        if 'data_loaders' not in task_cfg:
            continue
        engine.configure_data_loaders(config.data, task_cfg.data_loaders, is_distributed, config.seed)

    # Configure optimizers
    for task_cfg in config.tasks.values():
        if 'optimizers' not in task_cfg:
            continue
        engine.configure_optimizers(task_cfg.optimizers)

    # Perform tasks: each task name is the engine method to call
    for task_name, task_cfg in config.tasks.items():
        run_task = getattr(engine, task_name)
        run_task(task_cfg)
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Build preprocessors, generator, discriminators and losses.

    Args:
        cfg: Model configuration; normalised through ``model_utils`` before
            being passed to the base class.
        trainer: Optional trainer forwarded to the base class.
    """
    # Convert to Hydra 1.0 compatible DictConfig
    cfg = model_utils.convert_model_config_to_dict_config(cfg)
    cfg = model_utils.maybe_update_config_version(cfg)
    super().__init__(cfg=cfg, trainer=trainer)

    prep_cfg = cfg.preprocessor
    self.audio_to_melspec_precessor = instantiate(prep_cfg)
    # We use separate preprocessor for training, because we need to pass grads and remove pitch fmax limitation
    self.trg_melspec_fn = instantiate(prep_cfg, highfreq=None, use_grads=True)
    self.generator = instantiate(
        cfg.generator,
        n_mel_channels=prep_cfg.nfilt,
        hop_length=prep_cfg.n_window_stride,
    )

    # ``debug`` is optional in the config and defaults to False.
    debug = cfg.debug if "debug" in cfg else False
    self.mpd = MultiPeriodDiscriminator(cfg.discriminator.mpd, debug=debug)
    self.mrd = MultiResolutionDiscriminator(cfg.discriminator.mrd, debug=debug)

    self.discriminator_loss = DiscriminatorLoss()
    self.generator_loss = GeneratorLoss()

    # Reshape MRD resolutions hyperparameter and apply them to MRSTFT loss
    resolutions = cfg.discriminator.mrd.resolutions
    self.stft_resolutions = resolutions
    self.fft_sizes = [entry[0] for entry in resolutions]
    self.hop_sizes = [entry[1] for entry in resolutions]
    self.win_lengths = [entry[2] for entry in resolutions]
    self.mrstft_loss = MultiResolutionSTFTLoss(self.fft_sizes, self.hop_sizes, self.win_lengths)
    self.stft_lamb = cfg.stft_lamb

    self.sample_rate = self._cfg.preprocessor.sample_rate
    self.stft_bias = None

    self.input_as_mel = False
    train_dl = self._train_dl
    if train_dl:
        # TODO(Oktai15): remove it in 1.8.0 version
        train_set = train_dl.dataset
        if isinstance(train_set, MelAudioDataset):
            self.input_as_mel = True
        elif isinstance(train_set, VocoderDataset):
            self.input_as_mel = train_set.load_precomputed_mel

    self.automatic_optimization = False
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    """Instantiate the dataset described by ``cfg`` and wrap it in a DataLoader.

    Args:
        cfg: Config with a ``dataset`` node and a ``dataloader_params`` node,
            both of which must be ``DictConfig`` instances.
        shuffle_should_be: Expected value of ``dataloader_params.shuffle``.
            When True and the key is absent it is added as True; any other
            mismatch is only logged.
        name: Split name used in log and error messages.

    Returns:
        A ``torch.utils.data.DataLoader`` over the instantiated dataset, using
        the dataset's own ``collate_fn`` and ``cfg.dataloader_params``.

    Raises:
        ValueError: If ``dataset`` or ``dataloader_params`` is missing or is
            not a ``DictConfig``.
    """
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloder_params for {name}")

    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg.dataloader_params):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    # Use .get() so a config that simply omits ``shuffle`` is treated as
    # "not shuffling" instead of raising on the missing key.
    elif cfg.dataloader_params.get("shuffle", False):
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    # Only the FastPitchDataset target receives the model's parser.
    kwargs_dict = {}
    if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
        kwargs_dict["parser"] = self.parser
    dataset = instantiate(cfg.dataset, **kwargs_dict)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def __init__(self, digi_dataset_config, true_hits_h5file, digi_truth_mapping_file, valid_parents=(1, 2, 3), parent_type="max", transforms=True, is_distributed=False):
    """Pair a digitized-hit dataset with its true-hit dataset.

    Args:
        digi_dataset_config: Config for the dataset of digitized hits.
        true_hits_h5file: Path to the h5 dataset file for true hits.
        digi_truth_mapping_file: Path to a file with a pickled list mapping
            digitized hit events to true hit events.
        valid_parents: Valid ID values for hit parents.
        parent_type: ``"only"`` or ``"max"``; selects which method is bound
            to ``get_digi_hit_parent``.
        transforms: When truthy, the digitized dataset's transforms are moved
            onto this object and cleared on the digitized dataset.
        is_distributed: Forwarded to the digitized dataset constructor.
    """
    digi = instantiate(digi_dataset_config, is_distributed=is_distributed)
    self.digi_dataset = digi
    self.mpmt_positions = digi.mpmt_positions

    if not transforms:
        self.transforms = None
    else:
        # Take over the transforms so the wrapped dataset no longer holds them.
        self.transforms = digi.transforms
        digi.transforms = None

    self.truth_dataset = H5TrueDataset(true_hits_h5file, transforms=None, digitize_hits=False)

    # NOTE(review): pickle.load executes arbitrary code from the file; the
    # mapping file must come from a trusted source.
    with open(digi_truth_mapping_file, 'rb') as mapping_file:
        self.digi_truth_mapping = pickle.load(mapping_file)

    self.valid_parents = np.array(valid_parents)

    # NOTE(review): any other parent_type leaves get_digi_hit_parent unset
    # here - confirm whether that is intended.
    if parent_type == "max":
        self.get_digi_hit_parent = self.get_digi_hit_max_parent
    elif parent_type == "only":
        self.get_digi_hit_parent = self.get_digi_hit_only_parent
def train_dataloader(self):
    """Return a shuffling DataLoader over the configured training dataset."""
    dataset = instantiate(self.dataset_conf.train)
    batch_size = self.train_conf.batch_size
    workers = self.hparams['num_workers']
    return DataLoader(dataset, batch_size, shuffle=True, num_workers=workers)
def test_dataloader(self):
    """Return a DataLoader over the configured test dataset."""
    settings = self.test_conf
    dataset = instantiate(self.dataset_conf.test)
    batch_size = settings.batch_size
    workers = self.hparams['num_workers']
    return DataLoader(dataset, batch_size, num_workers=workers)
def test_instantiate_classes(classname: str, params: Any, args: Any, kwargs: Any, expected: Any) -> None:
    """Instantiate a generated ``<classname>Conf`` and compare with ``expected``.

    Args:
        classname: Class name used to build the generated config class path.
        params: Values merged over the structured schema.
        args: Extra positional arguments forwarded to ``instantiate``.
        kwargs: Extra keyword arguments forwarded to ``instantiate``.
        expected: Object the instantiated result must equal.
    """
    full_class = f"{MODULE_NAME}.generated.{classname}Conf"
    schema = OmegaConf.structured(get_class(full_class))
    cfg = OmegaConf.merge(schema, params)
    # Pass the config positionally: with ``config=cfg`` any non-empty ``args``
    # would bind to the same parameter and raise TypeError.
    obj = instantiate(cfg, *args, **kwargs)
    assert obj == expected
def create_simple_dataset(conf, transforms):
    # type: (DictConfig, DictConfig) -> JustImages
    """Create a ``JustImages`` dataset with a composed transform pipeline.

    Every value of ``transforms`` is instantiated, in mapping order, and the
    results are chained with ``T.Compose``. ``conf`` supplies ``root`` and
    ``extensions``.
    """
    steps = [instantiate(step_conf) for step_conf in transforms.values()]
    pipeline = T.Compose(steps)
    return JustImages(conf.root, extensions=tuple(conf.extensions), transform=pipeline)
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    """Instantiate the dataset described by ``cfg`` and wrap it in a DataLoader.

    Args:
        cfg: Config with a ``dataset`` node and a ``dataloader_params`` node,
            both of which must be ``DictConfig`` instances.
        shuffle_should_be: Expected value of ``dataloader_params.shuffle``.
            When True and the key is absent it is added as True; any other
            mismatch is only logged.
        name: Split name used in log and error messages.

    Returns:
        A ``torch.utils.data.DataLoader`` over the instantiated dataset, using
        the dataset's own ``collate_fn`` and ``cfg.dataloader_params``.

    Raises:
        ValueError: If ``dataset`` or ``dataloader_params`` is missing or is
            not a ``DictConfig``.
    """
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")  # TODO
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloder_params for {name}")  # TODO

    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg["dataloader_params"]):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    # Use .get() so a config that simply omits ``shuffle`` is treated as
    # "not shuffling" instead of raising on the missing key.
    elif cfg.dataloader_params.get("shuffle", False):
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    dataset = instantiate(cfg.dataset)
    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def __init__(self, cfg):
    '''
    AutoEncoder built from fully connected networks.

    Parameters:
    --------------------------------------
    cfg: collection of constants read from the configuration file.
        Parameters such as the network structure are also, as a rule,
        set through the configuration file.
    '''
    super(AutoEncoder, self).__init__()
    self.config = cfg

    # The feature count ends up following feature_cols whenever it disagrees
    # with input_shape[0].
    data_cfg = cfg.data
    if 'input_shape' not in data_cfg:
        self.n_features = len(data_cfg.feature_cols)
    else:
        self.n_features = data_cfg.input_shape[0]
        n_cols = len(data_cfg.feature_cols)
        if n_cols != self.n_features:
            self.n_features = n_cols
    #self.config.data.input_shape = [self.n_features]

    self.n_timesteps = cfg.data.window_size
    self.hidden_size = OmegaConf.to_container(cfg.net.hidden_size)
    self.z_dim = cfg.net.z_dim
    self.dropout = cfg.net.dropout

    self.anomaly_scores = None
    self.recon_x = None

    # NOTE(review): eval() runs whatever expression the config holds; the
    # config must be trusted. The result is not used further in this method.
    act_f = None
    if 'act_f' in self.config.net:
        act_f = eval(self.config.net.act_f)

    # Choose whether the input vector runs along the time axis or the
    # feature axis.
    if 'input_vec' not in self.config.data.keys():
        raise ValueError(f'config[\'data\'] has not \'input_vec\'')
    if self.config.data.input_vec == 'time':
        lead_dim, other_dim = self.n_timesteps, self.n_features
    else:
        lead_dim, other_dim = self.n_features, self.n_timesteps
    self.encoder = Encoder(lead_dim, self.hidden_size, self.z_dim, other_dim, self.dropout)
    self.decoder = Decoder(lead_dim, self.hidden_size, self.z_dim, other_dim, self.dropout)

    self.loss = F.mse_loss
    if 'loss_f' in self.config.net:
        # NOTE(review): same eval() caveat as above. Would instantiate be better?
        self.loss = eval(self.config.net.loss_f)

    self.optimizer = instantiate(self.config.optimizer, params=self.parameters())
    self.save_hyperparameters(OmegaConf.to_container(self.config.net, resolve=True))
def __init__(self, config: Optional[DictConfig] = None) -> None:
    """Instantiate every callback listed under ``hydra.callbacks``.

    Args:
        config: Optional config. When it is None, or ``hydra.callbacks``
            selects nothing truthy, the callback list stays empty.
    """
    self.callbacks = []
    from hydra.utils import instantiate

    if config is None or not OmegaConf.select(config, "hydra.callbacks"):
        return
    callback_configs = config.hydra.callbacks.values()
    self.callbacks.extend(instantiate(entry) for entry in callback_configs)
def run_scene_optimizer(args) -> None:
    """Run GTSFM over images from an Argoverse vehicle log.

    Args:
        args: Parsed arguments providing ``dataset_dir``, ``log_id``,
            ``stride``, ``max_num_imgs``, ``max_lookahead_sec`` and
            ``camera_name`` for the loader.
    """
    with initialize_config_module(config_module="gtsfm.configs"):
        # config is relative to the gtsfm module
        cfg = compose(config_name="default_lund_door_set1_config.yaml")
        scene_optimizer: SceneOptimizer = instantiate(cfg.SceneOptimizer)

        loader = ArgoverseDatasetLoader(
            dataset_dir=args.dataset_dir,
            log_id=args.log_id,
            stride=args.stride,
            max_num_imgs=args.max_num_imgs,
            max_lookahead_sec=args.max_lookahead_sec,
            camera_name=args.camera_name,
        )

        # Gather the graph inputs in the order they are passed below.
        num_images = len(loader)
        valid_pairs = loader.get_valid_pairs()
        image_graph = loader.create_computation_graph_for_images()
        intrinsics_graph = loader.create_computation_graph_for_intrinsics()
        pose_graph = loader.create_computation_graph_for_poses()
        sfm_result_graph = scene_optimizer.create_computation_graph(
            num_images,
            valid_pairs,
            image_graph,
            intrinsics_graph,
            use_intrinsics_in_verification=True,
            gt_pose_graph=pose_graph,
        )

        # create dask client
        cluster = LocalCluster(n_workers=2, threads_per_worker=4)

        with Client(cluster):
            with performance_report(filename="dask-report.html"):
                sfm_result = sfm_result_graph.compute()

        assert isinstance(sfm_result, GtsfmData)
        scene_avg_reproj_error = sfm_result.get_scene_avg_reprojection_error()
        rounded_error = str(np.round(scene_avg_reproj_error, 3))
        logger.info('Scene avg reproj error: {}'.format(rounded_error))
def build_optimizer(conf: DictConfig, model: nn.Module) -> Optimizer:
    """Instantiate the optimizer described by ``conf`` for ``model``.

    When ``conf.params.weight_decay`` is positive, the parameters come from
    ``add_weight_decay(model, weight_decay)`` and the optimizer itself is
    built with ``weight_decay=0.0``.

    The caller's ``conf`` is left untouched: the zeroed ``weight_decay`` is
    written to a private copy. Mutating the shared config would make a second
    call with the same ``conf`` skip ``add_weight_decay``.

    Args:
        conf: Optimizer config with a ``params`` node.
        model: Module whose parameters are optimized.

    Returns:
        The optimizer produced by ``instantiate``.
    """
    import copy

    parameters = model.parameters()
    p = conf.params
    if 'weight_decay' in p and p.weight_decay > 0:
        parameters = add_weight_decay(model, p.weight_decay)
        # Work on a copy so the shared config keeps its original value.
        conf = copy.deepcopy(conf)
        conf.params.weight_decay = 0.0
    return instantiate(conf, parameters)
def setup_ema(conf: DictConfig, model: nn.Module, device=None, master_node=False):
    """Create an averaged copy of ``model`` and a function that updates it.

    Args:
        conf: Config with a ``smoothing`` node (``enabled``, ``use_cpu``,
            ``alpha``, ``interval_it``) and a ``model`` node.
        model: Source model whose parameters and buffers are tracked.
        device: Device for the averaged model unless ``smoothing.use_cpu``.
        master_node: The averaged model is only built when this is true and
            ``smoothing.enabled`` is set.

    Returns:
        ``(model_ema, update)``. When averaging is inactive ``model_ema`` is
        None and ``update`` does nothing.
    """
    smoothing = conf.smoothing

    if not (master_node and smoothing.enabled):
        def _update():
            pass

        return None, _update

    model_ema = instantiate(conf.model)
    if not smoothing.use_cpu:
        model_ema = model_ema.to(device)
    model_ema.load_state_dict(model.state_dict())
    model_ema.requires_grad_(False)

    weight = 1 - smoothing.alpha ** smoothing.interval_it

    def _update():
        pairs = itertools.chain(
            zip(model_ema.parameters(), model.parameters()),
            zip(model_ema.buffers(), model.buffers()),
        )
        with torch.no_grad():
            for averaged, current in pairs:
                # filter out *.bn1.num_batches_tracked
                if current.dtype == torch.int64:
                    continue
                current = current.to(dtype=averaged.dtype, device=averaged.device)
                averaged.lerp_(current, weight)

    return model_ema, _update
def __init__(self, cfg: Config) -> None:
    """Build the wrapped model, loss, metrics and target-parameter index.

    Args:
        cfg: Config whose ``experiment`` node provides ``model``,
            ``synop_train_features`` and ``target_parameter``.
    """
    super().__init__()  # type: ignore
    self.logger: Union[LoggerCollection, WandbLogger, Any]
    self.wandb: Run

    self.cfg = cfg
    self.model: LightningModule = instantiate(self.cfg.experiment.model, self.cfg)
    self.criterion = MSELoss()

    # Metrics: one MSE / MAE pair per stage
    self.train_mse = MeanSquaredError()
    self.train_mae = MeanAbsoluteError()
    self.val_mse = MeanSquaredError()
    self.val_mae = MeanAbsoluteError()
    self.test_mse = MeanSquaredError()
    self.test_mae = MeanAbsoluteError()

    self.test_results = []

    features = self.cfg.experiment.synop_train_features
    target = self.cfg.experiment.target_parameter
    combined = add_param_to_train_params(features, target)
    # Transpose the parameter entries and take the second column as names.
    names = list(zip(*combined))[1]
    self.target_param_index = list(names).index(target)
def test_instantiate_with_missing_values(self, tmpdir):
    """Check if raising error with missing values."""
    ToyModel.configen(tmpdir)
    yaml_path = Path(tmpdir) / 'model/toy.yaml'
    config = OmegaConf.load(yaml_path)
    validate(config, ModelConfig)
    # Instantiating the loaded config must fail on the missing value.
    with pytest.raises(MissingMandatoryValue):
        instantiate(config)
def my_app(cfg: Config) -> None:
    """Print the selected model, dataset and top-level settings of ``cfg``.

    Args:
        cfg: Application config with ``model``, ``dataset``, ``seed``,
            ``use_wandb`` and ``data_pcnt``.
    """
    # Local names below are part of the printed ``name=value`` output.
    model_type = OmegaConf.get_type(cfg.model)
    if model_type is MlpConfig:
        mlp_cfg = cast(MlpConfig, cfg.model)
        print("using MLP")
        print(f"{mlp_cfg.layers=}")
        print(f"{mlp_cfg.hidden_units=}")
    elif model_type is SVMConfig:
        svm_cfg = cast(SVMConfig, cfg.model)
        print("using SVM")
        print(f"{svm_cfg.kernel=}")
        print(f"{svm_cfg.C=}")
    print()

    data_dir: Path = instantiate(cfg.dataset.dir)
    print(data_dir)

    dataset_type = OmegaConf.get_type(cfg.dataset)
    if dataset_type is AdultConfig:
        adult_cfg = cast(AdultConfig, cfg.dataset)
        print("using Adult dataset")
        print(f"{adult_cfg.drop_native=}")
    elif dataset_type is CmnistConfig:
        cmnist_cfg = cast(CmnistConfig, cfg.dataset)
        print("using CMNIST dataset")
        print(f"{cmnist_cfg.padding=}")
    print()

    print(f"{cfg.seed=}")
    print(f"{cfg.use_wandb=}")
    print(f"{cfg.data_pcnt=}")
    print()

    print("Config as flat dictionary:")
    as_container = OmegaConf.to_container(cfg, enum_to_str=True)
    print(flatten(as_container))
def test_interpolation_accessing_parent_deprecated(input_conf: Any, passthrough: Dict[str, Any], expected: Any, recwarn: Any) -> Any:
    """Instantiate ``input_conf.node`` with pass-through kwargs.

    Args:
        input_conf: Raw config; converted with ``OmegaConf.create``.
        passthrough: Keyword arguments forwarded to ``utils.instantiate``.
        expected: Object the result must equal.
        recwarn: Warning-recorder fixture; requested but not inspected here.
    """
    cfg = OmegaConf.create(input_conf)
    result = utils.instantiate(cfg.node, **passthrough)
    assert result == expected
def run_scene_optimizer() -> None:
    """Build the scene optimizer from config and run it on ``set1_lund_door``.

    Composes ``default_lund_door_set1_config.yaml``, loads the images with
    ``OlssonLoader`` and computes the resulting graph on a local dask cluster
    while writing ``dask-report.html``.
    """
    with initialize_config_module(config_module="gtsfm.configs"):
        # config is relative to the gtsfm module
        cfg = compose(config_name="default_lund_door_set1_config.yaml")
        scene_optimizer: SceneOptimizer = instantiate(cfg.SceneOptimizer)

        dataset_root = os.path.join(DATA_ROOT, "set1_lund_door")
        loader = OlssonLoader(dataset_root, image_extension="JPG")

        # Gather the graph inputs in the order they are passed below.
        num_images = len(loader)
        valid_pairs = loader.get_valid_pairs()
        image_graph = loader.create_computation_graph_for_images()
        intrinsics_graph = loader.create_computation_graph_for_intrinsics()
        pose_graph = loader.create_computation_graph_for_poses()
        sfm_result_graph = scene_optimizer.create_computation_graph(
            num_images=num_images,
            image_pair_indices=valid_pairs,
            image_graph=image_graph,
            camera_intrinsics_graph=intrinsics_graph,
            gt_pose_graph=pose_graph,
        )

        # create dask client
        cluster = LocalCluster(n_workers=2, threads_per_worker=4)

        with Client(cluster):
            with performance_report(filename="dask-report.html"):
                sfm_result = sfm_result_graph.compute()

        assert isinstance(sfm_result, GtsfmData)
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Validate ``cfg`` and build the preprocessor, UniGlow module and loss.

    Args:
        cfg: Model configuration as a ``dict`` or ``DictConfig``.
        trainer: Optional trainer forwarded to the base class.

    Raises:
        ValueError: If ``cfg`` is neither a ``dict`` nor a ``DictConfig``.
    """
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    super().__init__(cfg=cfg, trainer=trainer)

    schema = OmegaConf.structured(WaveglowConfig)
    # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
    if not isinstance(cfg, (dict, DictConfig)):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    # Ensure passed cfg is compliant with schema
    OmegaConf.merge(cfg, schema)

    self.pad_value = self._cfg.preprocessor.params.pad_value
    self.sigma = self._cfg.sigma
    self.audio_to_melspec_precessor = instantiate(self._cfg.preprocessor)

    glow_cfg = self._cfg.uniglow
    self.model = UniGlowModule(
        glow_cfg.n_mel_channels,
        glow_cfg.n_flows,
        glow_cfg.n_group,
        glow_cfg.n_wn_channels,
        glow_cfg.n_wn_layers,
        glow_cfg.wn_kernel_size,
        self.get_upsample_factor(),
    )
    self.mode = OperationMode.infer
    self.loss = UniGlowLoss(self._cfg.uniglow.stft_loss_coef)
    self.removed_weightnorm = False
def _loader(cfg):
    """Return a DataLoader over the dataset instantiated from ``cfg.dataset``.

    The dataset's own ``collate_fn`` is used and ``cfg.dataloader_params`` is
    expanded into the remaining DataLoader keyword arguments.
    """
    data = instantiate(cfg.dataset)
    collate = data.collate_fn
    loader_cls = torch.utils.data.DataLoader  # noqa
    return loader_cls(dataset=data, collate_fn=collate, **cfg.dataloader_params)