def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """Construct the model's preprocessor, encoder, decoder, loss and metrics from ``cfg``.

    Args:
        cfg: Model configuration. Reads the ``preprocessor``, ``encoder`` and
            ``decoder`` sub-configs, an optional ``angular`` entry of ``decoder``,
            ``loss.scale`` / ``loss.margin`` when angular loss is selected, and an
            optional ``spec_augment`` entry (looked up on ``self._cfg``).
        trainer: Optional trainer. When given, ``world_size`` becomes
            ``trainer.num_nodes * trainer.num_devices``; otherwise it is 1.
    """
    # Assigned before the parent constructor runs -- presumably so that code executed
    # inside super().__init__ can already read it (TODO confirm).
    self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_devices

    super().__init__(cfg=cfg, trainer=trainer)

    # Every sub-module is instantiated through the same class-level factory.
    build = EncDecSpeakerLabelModel.from_config_dict
    self.preprocessor = build(cfg.preprocessor)
    self.encoder = build(cfg.encoder)
    self.decoder = build(cfg.decoder)

    # Angular softmax is used only when the decoder config has a truthy 'angular' entry.
    use_angular = 'angular' in cfg.decoder and cfg.decoder['angular']
    if not use_angular:
        logging.info("loss is Softmax-CrossEntropy")
        self.loss = CELoss()
    else:
        logging.info("loss is Angular Softmax")
        self.loss = AngularSoftmaxLoss(scale=cfg.loss.scale, margin=cfg.loss.margin)

    self.task = None
    self._accuracy = TopKClassificationAccuracy(top_k=[1])
    # NOTE(review): reset after super().__init__; if the parent constructor already
    # populated self.labels (e.g. during data setup), this overwrites it -- confirm.
    self.labels = None

    # Spec augmentation is optional: built only when the stored config has a non-None entry.
    spec_cfg = self._cfg.spec_augment if hasattr(self._cfg, 'spec_augment') else None
    self.spec_augmentation = None if spec_cfg is None else build(spec_cfg)
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """Construct the model's preprocessor, encoder, decoder, loss and accuracy metric.

    Args:
        cfg: Model configuration. Reads the ``preprocessor``, ``encoder`` and
            ``decoder`` sub-configs, an optional ``angular`` entry of ``decoder``,
            and ``loss.scale`` / ``loss.margin`` when angular loss is selected.
        trainer: Optional trainer, forwarded unchanged to the parent constructor.
    """
    super().__init__(cfg=cfg, trainer=trainer)

    # Every sub-module is instantiated through the same class-level factory.
    build = EncDecSpeakerLabelModel.from_config_dict
    self.preprocessor = build(cfg.preprocessor)
    self.encoder = build(cfg.encoder)
    self.decoder = build(cfg.decoder)

    # Angular softmax is used only when the decoder config has a truthy 'angular' entry.
    use_angular = 'angular' in cfg.decoder and cfg.decoder['angular']
    if not use_angular:
        logging.info("Training with Softmax-CrossEntropy loss")
        self.loss = CELoss()
    else:
        logging.info("Training with Angular Softmax Loss")
        self.loss = AngularSoftmaxLoss(scale=cfg.loss.scale, margin=cfg.loss.margin)

    self.task = None
    self._accuracy = TopKClassificationAccuracy(top_k=[1])