def __init__(self, cfg):
        """Prepare everything the training loop needs.

        Builds three STFT front ends, the train/validation datasets and
        their loaders, a ``FeatExtractorBlstm_pp_v3`` model on the selected
        device, a ``Clip_SDR`` criterion, an Adam optimiser (lr 1e-3) and an
        ``EarlyStopping`` helper (patience 10).

        Args:
            cfg: configuration mapping. Keys read here: 'stft_params',
                'stft_params_ex1', 'stft_params_ex2', 'train_data_num',
                'valid_data_num', 'sample_len', 'epoch_num',
                'train_batch_size', 'valid_batch_size',
                'train_full_data_num', 'valid_full_data_num', 'save_path'
                and 'dnn_cfg'.
        """
        # Pin to the first CUDA device when one is available, else use CPU.
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        self.dtype = torch.float32
        self.eps = 1e-8
        device = self.device

        # One STFT front end per parameter set in the config.
        self.stft_module = STFTModule(cfg['stft_params'], device)
        self.stft_module_ex1 = STFTModule(cfg['stft_params_ex1'], device)
        self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], device)

        # Run and size settings, copied verbatim from the config.
        (self.train_data_num, self.valid_data_num,
         self.sample_len, self.epoch_num,
         self.train_batch_size, self.valid_batch_size) = (
            cfg['train_data_num'], cfg['valid_data_num'],
            cfg['sample_len'], cfg['epoch_num'],
            cfg['train_batch_size'], cfg['valid_batch_size'])
        self.train_full_data_num, self.valid_full_data_num, self.save_path = (
            cfg['train_full_data_num'], cfg['valid_full_data_num'],
            cfg['save_path'])

        # Both splits share the clip length and the target device; only the
        # training split has augmentation enabled.
        common = {'sample_len': self.sample_len, 'device': device}
        self.train_dataset = VoicebankDemandDataset(
            data_num=self.train_data_num,
            full_data_num=self.train_full_data_num,
            folder_type='train', shuffle=True, augmentation=True,
            **common)
        self.valid_dataset = VoicebankDemandDataset(
            data_num=self.valid_data_num,
            full_data_num=self.valid_full_data_num,
            folder_type='validation', shuffle=True, augmentation=False,
            **common)

        self.train_data_loader = FastDataLoader(
            self.train_dataset, shuffle=True,
            batch_size=self.train_batch_size)
        self.valid_data_loader = FastDataLoader(
            self.valid_dataset, shuffle=True,
            batch_size=self.valid_batch_size)

        # Network, loss, optimiser and stopping criterion.
        network = FeatExtractorBlstm_pp_v3(cfg['dnn_cfg'])
        self.model = network.to(device)
        self.criterion = Clip_SDR()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.early_stopping = EarlyStopping(patience=10)
    def __init__(self, cfg):
        """Load a trained ``CNNOpenUnmix_p2`` checkpoint and prepare test data.

        Builds the model on the selected device, switches it to eval mode
        and loads the state dict stored at ``cfg['eval_path']``. Also
        creates two STFT front ends, the unshuffled test dataset and loader,
        and four empty metric arrays.

        Args:
            cfg: configuration mapping. Keys read here: 'eval_path',
                'dnn_cfg', 'stft_params', 'stft_params_ex2',
                'test_data_num', 'test_batch_size' and 'sample_len'.
        """
        has_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_cuda else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4
        self.eval_path = cfg['eval_path']

        # Build the network, put it in inference mode, then restore weights
        # mapped onto the selected device.
        # NOTE(review): torch.load unpickles the file, so eval_path should
        # only ever point at trusted checkpoints.
        self.model = CNNOpenUnmix_p2(cfg['dnn_cfg']).to(self.device)
        self.model.eval()
        checkpoint = torch.load(self.eval_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)

        self.stft_module = STFTModule(cfg['stft_params'], self.device)
        self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], self.device)

        self.test_data_num, self.test_batch_size, self.sample_len = (
            cfg['test_data_num'], cfg['test_batch_size'], cfg['sample_len'])

        # Deterministic ordering: neither the dataset nor the loader shuffles.
        self.test_dataset = VoicebankDemandDataset(
            data_num=self.test_data_num, sample_len=self.sample_len,
            folder_type='test', shuffle=False)
        self.test_data_loader = FastDataLoader(
            self.test_dataset, batch_size=self.test_batch_size, shuffle=False)

        # One empty array per metric (STOI, PESQ, SI-SDR, SI-SDR improvement),
        # presumably appended to during evaluation -- confirm in the caller.
        (self.stoi_list, self.pesq_list,
         self.si_sdr_list, self.si_sdr_improve_list) = (
            np.array([]) for _ in range(4))
    def __init__(self, cfg):
        """Prepare datasets, loaders, a ``DemandUNet`` model and its optimiser.

        The model is placed on CUDA when available (CPU otherwise). It is
        trained with a ``PSA`` criterion, Adam (lr 1e-3) and an
        ``EarlyStopping`` helper (patience 10).

        Args:
            cfg: configuration mapping. Keys read here: 'stft_params',
                'train_data_num', 'valid_data_num', 'sample_len',
                'epoch_num', 'train_batch_size', 'valid_batch_size',
                'train_full_data_num', 'valid_full_data_num' and
                'save_path'.
        """
        has_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_cuda else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-8

        self.stft_module = STFTModule(cfg['stft_params'], self.device)

        # Run and size settings, copied verbatim from the config.
        (self.train_data_num, self.valid_data_num,
         self.sample_len, self.epoch_num,
         self.train_batch_size, self.valid_batch_size,
         self.train_full_data_num, self.valid_full_data_num) = (
            cfg['train_data_num'], cfg['valid_data_num'],
            cfg['sample_len'], cfg['epoch_num'],
            cfg['train_batch_size'], cfg['valid_batch_size'],
            cfg['train_full_data_num'], cfg['valid_full_data_num'])

        self.train_dataset = VoicebankDemandDataset(
            data_num=self.train_data_num,
            full_data_num=self.train_full_data_num,
            sample_len=self.sample_len, folder_type='train')
        self.valid_dataset = VoicebankDemandDataset(
            data_num=self.valid_data_num,
            full_data_num=self.valid_full_data_num,
            sample_len=self.sample_len, folder_type='validation')

        self.train_data_loader = FastDataLoader(
            self.train_dataset, shuffle=True,
            batch_size=self.train_batch_size)
        self.valid_data_loader = FastDataLoader(
            self.valid_dataset, shuffle=True,
            batch_size=self.valid_batch_size)

        # Network, loss and optimiser; the save path is read only after the
        # model is built, matching the original lookup order.
        network = DemandUNet()
        self.model = network.to(self.device)
        self.criterion = PSA()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.save_path = cfg['save_path']
        self.early_stopping = EarlyStopping(patience=10)