コード例 #1
0
def evaluation():
    """Test a pretrained ensemble (gating + specialists) on held-out mixtures.

    Builds an ``EnsembleNetwork`` from the module-level ``file_gating`` and
    ``files_specialists`` paths, records those paths in ``output_directory``,
    and scores the network per test condition selected by
    ``args.latent_space``:

    * ``'gender'``: one condition per entry of ``C.gender_all``; speech is
      drawn from ``F.filter_by_gender(te_utterances, gender)`` and mixed with
      ``snr_db=C.snr_all``.
    * ``'snr'``: one condition per entry of ``C.snr_all``; speech is drawn
      from all of ``te_utterances`` and mixed at that SNR.

    Each condition's score is the mean of
    ``F.calculate_sisdr(s, s_hat, offset=<SI-SDR of the unprocessed mixture>)``
    over ``C.te_batch_size`` mixtures, weighted by mixture length in samples.
    The per-condition scores plus a ``'mean'`` entry are logged as JSON and
    written to ``test_results.txt``.

    Depends on module-level globals defined elsewhere (``args``, ``C``, ``F``,
    ``te_utterances``, ``te_noises``, ``output_directory``, the
    ``*_gating`` / ``*_specialist`` hyper-parameters, ...).

    Raises:
        ValueError: if ``args.latent_space`` is neither ``'gender'`` nor
            ``'snr'`` (instead of failing later with a ``NameError``).
    """
    if args.latent_space not in ('gender', 'snr'):
        raise ValueError(f'Unsupported latent_space {args.latent_space!r}; '
                         "expected 'gender' or 'snr'.")

    def _length_weighted_sisdr(network, files_speech, files_noise, snr_db):
        """Return the length-weighted mean SI-SDR improvement over a batch.

        Each (speech, noise) file pair is loaded in full, trimmed to the
        shorter of the two, mixed with ``snr_db``, denoised by ``network``
        and scored against the clean source, with the unprocessed mixture's
        SI-SDR passed as ``offset``.
        """
        batch_sisdr = list()
        batch_durations = list()
        for (fs, fn) in zip(files_speech, files_noise):

            source = F.load_audio(fs,
                                  duration=None,
                                  random_offset=False,
                                  device=args.device_id)
            noise = F.load_audio(fn,
                                 duration=None,
                                 random_offset=False,
                                 device=args.device_id)
            min_length = min(len(source), len(noise))
            stft_frames = ceil(min_length / C.hop_size)
            source = source[:min_length]
            noise = noise[:min_length]

            (x, s, _) = F.mix_signals(source, noise, snr_db=snr_db)
            (X, X_mag) = F.stft(x)

            # keep only ceil(min_length / hop_size) frames; presumably this
            # trims frames F.stft yields past the signal end -- TODO confirm
            X = X.permute(1, 0, 2)[:stft_frames]  # (seq_len, num_features, channel)
            X_mag = X_mag.permute(1, 0)[:stft_frames]  # (seq_len, num_features)

            # add a leading batch axis of size 1
            X = torch.unsqueeze(X, dim=0)
            X_mag = torch.unsqueeze(X_mag, dim=0)
            s = torch.unsqueeze(s, dim=0)
            x = torch.unsqueeze(x, dim=0)

            actual_sisdr = float(F.calculate_sisdr(s, x).item())

            # feed-forward
            M_hat = network(X_mag)
            s_hat = F.istft(X, mask=M_hat)

            batch_sisdr.append(
                F.calculate_sisdr(s, s_hat,
                                  offset=actual_sisdr).mean().item())
            batch_durations.append(min_length)

        return np.average(batch_sisdr, weights=batch_durations)

    with torch.no_grad():

        #
        # initialize network
        #
        np.random.seed(0)
        torch.manual_seed(0)

        network = EnsembleNetwork(
            filepath_gating=file_gating,
            filepaths_denoising=files_specialists,
            g_hs=hidden_size_gating,
            g_nl=num_layers_gating,
            s_hs=hidden_size_specialist,
            s_nl=num_layers_specialist,
            ct=args.latent_space,
        ).to(device=args.device_id)
        # Evaluate in inference mode: without this the module stays in
        # train mode and any train-only behavior remains active while testing.
        network.eval()

        F.write_data(filename=os.path.join(output_directory,
                                           'files_gating.txt'),
                     data=str(file_gating))
        F.write_data(filename=os.path.join(output_directory,
                                           'files_specialist.txt'),
                     data=str(files_specialists))

        with torch.cuda.device(args.device_id):
            torch.cuda.empty_cache()

        te_sisdr = dict()
        if args.latent_space == 'gender':
            for te_gender in C.gender_all:
                files_speech = np.random.choice(
                    F.filter_by_gender(te_utterances, te_gender),
                    size=C.te_batch_size)
                files_noise = np.random.choice(te_noises, size=C.te_batch_size)
                # store the length-weighted average result
                te_sisdr[str(te_gender)] = _length_weighted_sisdr(
                    network, files_speech, files_noise, snr_db=C.snr_all)
        else:  # 'snr' (validated above)
            for te_snr in C.snr_all:
                files_speech = np.random.choice(te_utterances,
                                                size=C.te_batch_size)
                files_noise = np.random.choice(te_noises, size=C.te_batch_size)
                # store the length-weighted average result
                te_sisdr[str(te_snr)] = _length_weighted_sisdr(
                    network, files_speech, files_noise, snr_db=te_snr)

        te_sisdr['mean'] = np.mean(list(te_sisdr.values()))

        logging.info(json.dumps(te_sisdr, sort_keys=True, indent=4))
        F.write_data(filename=os.path.join(output_directory,
                                           'test_results.txt'),
                     data=te_sisdr)

    return
コード例 #2
0
ファイル: train_denoising.py プロジェクト: mtxing/sparse_mle
def experiment():
    """Train a single denoising network and checkpoint the best model.

    Trains ``M.DenoisingNetwork(args.hidden_size, args.num_layers)`` with Adam
    on ``F.loss_sisdr``, 100 mini-batches per iteration, until
    ``C.stopping_criteria(iteration, iteration_best)`` is satisfied. After
    each iteration the network is validated per gender
    (``args.latent_space == 'gender'``) or per SNR (otherwise). Whenever the
    selection score improves, the score is written to ``validation_sisdr.txt``
    and the weights to ``model.pt`` in ``output_directory``.

    The selection score is ``sisdr_batch['mean']`` unless
    ``args.latent_space == 'snr'``, in which case it is the validation score
    at the training SNR ``tr_snr``.

    Depends on module-level globals defined elsewhere (``args``, ``C``, ``F``,
    ``M``, the ``tr_*`` / ``vl_*`` file lists, ``tr_snr``,
    ``output_directory``, ``ts_start``).
    """

    #
    # initialize network
    #
    np.random.seed(0)
    torch.manual_seed(0)

    network = M.DenoisingNetwork(args.hidden_size,
                                 args.num_layers).to(device=args.device_id)

    network_params = F.count_parameters(network)

    optimizer = torch.optim.Adam(
        params=network.parameters(),
        lr=args.learning_rate,
    )

    criterion = F.loss_sisdr

    F.write_data(filename=os.path.join(output_directory, 'num_parameters.txt'),
                 data=network_params)

    with torch.cuda.device(args.device_id):
        torch.cuda.empty_cache()

    #
    # log experiment configuration
    #
    os.system('cls' if os.name == 'nt' else 'clear')
    logging.info(f'Training Denoising network' + (
        f' specializing in {F.fmt_specialty(args.specialization)} mixtures'
        if args.specialization else '') + '...')
    logging.info(f'\u2022 {args.hidden_size} hidden units')
    logging.info(f'\u2022 {args.num_layers} layers')
    logging.info(f'\u2022 {network_params} learnable parameters')
    logging.info(f'\u2022 {args.learning_rate:.3e} learning rate')
    logging.info(f'Results will be saved in "{output_directory}".')
    logging.info(f'Using GPU device {args.device_id}...')

    #
    # experiment loop
    #
    # key of `sisdr_batch` used for model selection (loop-invariant)
    selection_key = 'mean' if args.latent_space != 'snr' else tr_snr

    (iteration, iteration_best) = (0, 0)
    # Start at -inf so the first validation always yields a checkpoint;
    # starting at 0 meant a model whose SI-SDR improvement never exceeded
    # 0 dB was never saved.
    sisdr_best = float('-inf')

    while not C.stopping_criteria(iteration, iteration_best):

        network.train()
        np.random.seed(iteration)
        torch.manual_seed(iteration)

        # training
        for _ in range(100):

            # forward propagation
            batch = F.generate_batch(
                np.random.choice(tr_utterances, size=C.tr_batch_size),
                np.random.choice(tr_noises, size=C.tr_batch_size),
                mixture_snr=tr_snr,
                device=args.device_id,
            )
            M_hat = network(batch.X_mag)
            s_hat = F.istft(batch.X, mask=M_hat)

            # backward propagation
            optimizer.zero_grad()
            criterion(batch.s, s_hat).backward()
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1e-4)
            optimizer.step()

        # validation uses a fixed seed so every iteration sees the same batches
        network.eval()
        np.random.seed(0)
        torch.manual_seed(0)

        # validation
        with torch.no_grad():

            if args.latent_space == 'gender':

                sisdr_batch = {k: 0 for k in C.gender_all}
                for vl_gender in C.gender_all:

                    vl_filtered_files = F.filter_by_gender(
                        vl_utterances, vl_gender)
                    batch = F.generate_batch(
                        np.random.choice(vl_filtered_files,
                                         size=C.vl_batch_size),
                        np.random.choice(vl_noises, size=C.vl_batch_size),
                        device=args.device_id,
                    )
                    M_hat = network(batch.X_mag)
                    s_hat = F.istft(batch.X, mask=M_hat)
                    sisdr_batch[vl_gender] = float(
                        F.calculate_sisdr(
                            batch.s, s_hat,
                            offset=batch.actual_sisdr).mean().item())

            else:

                sisdr_batch = {k: 0 for k in C.snr_all}
                for vl_snr in C.snr_all:

                    batch = F.generate_batch(
                        np.random.choice(vl_utterances, size=C.vl_batch_size),
                        np.random.choice(vl_noises, size=C.vl_batch_size),
                        mixture_snr=vl_snr,
                        device=args.device_id,
                    )
                    M_hat = network(batch.X_mag)
                    s_hat = F.istft(batch.X, mask=M_hat)
                    sisdr_batch[vl_snr] = float(
                        F.calculate_sisdr(
                            batch.s, s_hat,
                            offset=batch.actual_sisdr).mean().item())

        sisdr_batch['mean'] = np.mean(list(sisdr_batch.values()))

        # checkpoint on improvement
        if sisdr_batch[selection_key] > sisdr_best:
            sisdr_best = sisdr_batch[selection_key]
            iteration_best = iteration

            # plain decimal value; the previous ':%' format multiplied the
            # score by 100 and appended '%', which is wrong for a dB value
            F.write_data(filename=os.path.join(output_directory,
                                               'validation_sisdr.txt'),
                         data=f'{sisdr_best:.6f}')
            torch.save(network.state_dict(),
                       os.path.join(output_directory, 'model.pt'))
            checkmark = ' | \033[32m\u2714\033[39m'
        else:
            checkmark = ''

        # print results
        status = ''.join(f'\033[33m{k}: {v:>6.3f} dB\033[39m, '
                         for (k, v) in sisdr_batch.items())
        ts_end = int(round(time.time())) - ts_start
        status += f'Time Elapsed: {int(ts_end/60)} minutes' + checkmark
        logging.info(status)
        iteration += 1

    return
コード例 #3
0
def evaluation():
    """Evaluate a gating network that routes inputs to gender specialists.

    Loads a ``GatingNetwork`` from ``args.state_dict_file_gating`` and one
    ``SpecialistNetwork`` per entry of ``gender_all``. Each specialist's path
    is ``args.state_dict_file_specialist`` with its ``gender_M``/``gender_F``
    token replaced by ``F.fmt_gender(gender)``. For every input (a validation
    batch, or a single test mixture) the specialist with the largest gating
    output summed over axis 0 is applied to the whole input.

    Writes ``num_parameters.txt``, ``validation_results.txt`` and
    ``test_results.txt`` (tables with the columns in ``fields``) to
    ``output_directory``. Test metrics are averages over ``te_batch_size``
    mixtures, weighted by mixture length in samples.

    Depends on module-level globals defined elsewhere (``args``, ``F``,
    ``gender_all``, ``vl_*`` / ``te_*`` data, ``hop_size``, ``snr_all``,
    the architecture hyper-parameters, ...).

    Raises:
        ValueError: if ``args.state_dict_file_specialist`` contains no
            ``gender_M``/``gender_F`` token (previously an ``assert``, which
            is skipped under ``python -O``).
    """
    # Checked once up front rather than asserted on every loop iteration.
    specialist_token = r'gender_[MF]'
    if re.search(specialist_token, args.state_dict_file_specialist) is None:
        raise ValueError(
            'args.state_dict_file_specialist must contain "gender_M" or '
            f'"gender_F"; got {args.state_dict_file_specialist!r}')

    def _format_status(prefix, row):
        """Format one results row (laid out as `fields`) for logging."""
        return (f'{prefix} -- '
                f'SDR: {row[2]:>6.3f} dB, '
                f'\033[33mSISDR: {row[3]:>6.3f} dB\033[39m, '
                f'MSE: {row[4]:>6.3f}, '
                f'BCE: {row[5]:>6.3f}, '
                f'Accuracy: {row[6]:>6.3f}')

    with torch.no_grad():

        sum_num_params = 0

        #
        # initialize gating network
        #
        gating = GatingNetwork(hidden_size_gating, num_layers_gating, len(gender_all)).to(device=args.device_id)
        gating.load_state_dict(torch.load(
            args.state_dict_file_gating, map_location=torch.device(args.device_id))
        )
        gating.eval()
        sum_num_params += F.count_parameters(gating)

        #
        # initialize specialist networks (as a hashed list of networks)
        #
        specialists = {
            i: SpecialistNetwork(hidden_size_specialist, num_layers_specialist).to(device=args.device_id)
            for i in range(len(gender_all))
        }
        for (i, gender) in enumerate(gender_all):
            filepath = re.sub(specialist_token, F.fmt_gender(gender), args.state_dict_file_specialist)
            specialists[i].load_state_dict(torch.load(
                filepath, map_location=torch.device(args.device_id))
            )
            specialists[i].eval()
            sum_num_params += F.count_parameters(specialists[i])

        F.write_data(filename=os.path.join(output_directory, 'num_parameters.txt'),
                     data=sum_num_params)
        with torch.cuda.device(args.device_id):
            torch.cuda.empty_cache()

        #
        # log experiment configuration
        #
        logging.info('All results will be stored in "{}".'.format(
            output_directory))
        logging.info('Testing {} model (with Gating architecture {} and Specialist architecture {}) to denoise {} gendered mixtures...'.format(
            model_name, architecture_gating, architecture_specialist, gender_all))
        logging.info('Using GPU device {}...'.format(
            args.device_id))

        # NOTE(review): the first column holds the gender condition
        # (`gender_val`) even though its header reads 'snr_val'; kept as-is
        # because downstream readers may depend on the column name.
        fields = ['snr_val','num_mixtures','sdr','sisdr','mse','bce','accuracy']

        #
        # validation
        #
        results_validation = []
        np.random.seed(0)
        torch.manual_seed(0)
        for gender_val in gender_all:

            # construct a batch
            batch = F.generate_batch(
                vl_utterances, vl_noises,
                batch_size=vl_batch_size,
                gender=gender_val,
                device=args.device_id,
            )
            Y = batch.index_gender

            # compute batch-wise specialist probabilities
            Y_hat = gating(batch.X_mag)

            # pick the best specialist to apply to the whole batch (based on batch probabilities sum)
            k = int(Y_hat.sum(dim=0).argmax().item())

            # apply the best specialist to the entire batch
            M_hat = specialists[k](batch.X_mag)
            s_hat = F.istft(batch.X, mask=M_hat)

            results_validation.append([
                gender_val,
                vl_batch_size,
                float(F.calculate_sdr(batch.s, s_hat, offset=batch.actual_sdr).mean().item()),
                float(F.calculate_sisdr(batch.s, s_hat, offset=batch.actual_sisdr).mean().item()),
                float(F.calculate_mse(batch.M, M_hat).item()),
                float(F.calculate_bce(batch.M, M_hat).item()),
                float(F.calculate_accuracy(Y, Y_hat)),
            ])
            logging.info(_format_status(
                f'Validation Data (Gender={gender_val})', results_validation[-1]))
        F.write_table(filename=os.path.join(output_directory, 'validation_results.txt'),
                      table_data=results_validation, headers=fields)

        #
        # testing
        #
        results_testing = []

        for gender_val in gender_all:
            # reseeded per gender, so each condition's file draws do not
            # depend on the conditions evaluated before it
            np.random.seed(0)
            torch.manual_seed(0)
            te_utterances_filtered = te_utterances[np.array([(F.get_gender(row) in gender_val) for row in te_utterances])]
            files_speech = np.random.choice(te_utterances_filtered, size=te_batch_size)
            files_noise = np.random.choice(te_noises, size=te_batch_size)
            te_m_durations = list()
            te_m_sdr = list()
            te_m_sisdr = list()
            te_m_mse = list()
            te_m_bce = list()
            te_m_accuracy = list()
            for (fs, fn) in zip(files_speech, files_noise):

                source = F.load_audio(fs, duration=None, random_offset=False, device=args.device_id)
                noise = F.load_audio(fn, duration=None, random_offset=False, device=args.device_id)
                min_length = min(len(source), len(noise))
                stft_frames = ceil(min_length/hop_size)
                source = source[:min_length]
                noise = noise[:min_length]

                (x, s, n) = F.mix_signals(source, noise, snr_db=snr_all)
                # only the magnitudes of the clean/noise STFTs are needed,
                # to build the masking target
                (_, S_mag) = F.stft(s)
                (_, N_mag) = F.stft(n)
                (X, X_mag) = F.stft(x)
                M = F.calculate_masking_target(S_mag, N_mag)

                X = X.permute(1, 0, 2)[:stft_frames] # (seq_len, num_features, channel)
                X_mag = X_mag.permute(1, 0)[:stft_frames]  # (seq_len, num_features)
                M = M.permute(1, 0)[:stft_frames]  # (seq_len, num_features)

                actual_sdr = float(F.calculate_sdr(s, x).item())
                actual_sisdr = float(F.calculate_sisdr(s, x).item())

                # one-hot gender label; NOTE(review): assumes 'F' sits at
                # index 1 of `gender_all` -- confirm
                gender_index = int(F.get_gender(fs)=='F')
                Y = torch.zeros(1, len(gender_all), device=args.device_id)
                Y[..., gender_index] = 1

                # add a fake batch axis to everything used below
                s = torch.unsqueeze(s, dim=0)
                X = torch.unsqueeze(X, dim=0)
                X_mag = torch.unsqueeze(X_mag, dim=0)
                M = torch.unsqueeze(M, dim=0)

                # compute batch-wise specialist probabilities
                Y_hat = gating(X_mag)

                # pick the best specialist to apply to the whole batch (based on batch probabilities sum)
                k = int(Y_hat.sum(dim=0).argmax().item())

                # apply the best specialist to the entire batch
                M_hat = specialists[k](X_mag)
                s_hat = F.istft(X, mask=M_hat)

                te_m_sdr.append(F.calculate_sdr(s, s_hat, offset=actual_sdr).mean().item())
                te_m_sisdr.append(F.calculate_sisdr(s, s_hat, offset=actual_sisdr).mean().item())
                te_m_mse.append(F.calculate_mse(M, M_hat).item())
                te_m_bce.append(F.calculate_bce(M, M_hat).item())
                te_m_accuracy.append(float(torch.prod(Y==torch.round(Y_hat), dim=-1).sum().item()/float(len(Y))))
                te_m_durations.append(min_length)

            # store the weighted average results
            results_testing.append([
                gender_val,
                te_batch_size,
                np.average(te_m_sdr, weights=te_m_durations),
                np.average(te_m_sisdr, weights=te_m_durations),
                np.average(te_m_mse, weights=te_m_durations),
                np.average(te_m_bce, weights=te_m_durations),
                np.average(te_m_accuracy, weights=te_m_durations),
            ])
            logging.info(_format_status(
                f'Test Data (Gender={gender_val})', results_testing[-1]))
        F.write_table(filename=os.path.join(output_directory, 'test_results.txt'),
                      table_data=results_testing, headers=fields)
    return
コード例 #4
0
                                 duration=None)
            min_length = min(len(source), len(noise))
            (x, s, n) = F.mix_signals(source[:min_length],
                                      noise[:min_length],
                                      snr_db=mixture_snr)
            (S, S_mag) = F.stft(s)
            (N, N_mag) = F.stft(n)
            (X, X_mag) = F.stft(x)
            (M) = F.calculate_masking_target(S_mag, N_mag)
            X = X.permute(1, 0, 2)
            S_mag = S_mag.permute(1, 0)
            N_mag = N_mag.permute(1, 0)
            X_mag = X_mag.permute(1, 0)
            M = M.permute(1, 0)
            actual_sdr = float(F.calculate_sdr(s, x).item())
            actual_sisdr = float(F.calculate_sisdr(s, x).item())

            # inference
            M = network(X_mag.unsqueeze(0))
            y = F.istft(X.unsqueeze(0), mask=M)
            (Y, Y_mag) = F.stft(y)
            y = y.squeeze()
            output_sdr = float(F.calculate_sdr(s, y).item())
            output_sisdr = float(F.calculate_sisdr(s, y).item())

            # convert everything back into numpy types
            max_amplitude = 1e-30 + 1.15 * float(
                max(s.max(), n.max(), x.max(), y.max()))
            s = s.detach().cpu().numpy() / max_amplitude
            n = n.detach().cpu().numpy() / max_amplitude
            x = x.detach().cpu().numpy() / max_amplitude
コード例 #5
0
ファイル: test_denoising.py プロジェクト: mtxing/sparse_mle
def evaluation():
    """Test a trained denoising network on held-out mixtures.

    Loads ``DenoisingNetwork(args.hidden_size, args.num_layers)`` weights from
    ``args.state_dict_file`` (strict) and scores it on test conditions
    selected by ``args.latent_space``:

    * ``'gender'`` or ``'all'``: one condition per entry of ``C.gender_all``
      (speech from ``F.filter_by_gender``, mixed with ``snr_db=C.snr_all``),
      plus a ``'mean_gender'`` entry.
    * ``'snr'`` or ``'all'``: one condition per entry of ``C.snr_all``,
      plus a ``'mean_sisdr'`` entry.

    Each condition's score is the mean of
    ``F.calculate_sisdr(s, s_hat, offset=<SI-SDR of the unprocessed mixture>)``
    over ``C.te_batch_size`` mixtures, weighted by mixture length in samples.
    Results are logged as JSON and written to ``test_results.txt`` in
    ``output_directory``.

    Depends on module-level globals defined elsewhere (``args``, ``C``, ``F``,
    ``te_utterances``, ``te_noises``, ``output_directory``).

    Raises:
        ValueError: if ``args.latent_space`` is not ``'gender'``, ``'snr'``
            or ``'all'`` (previously an empty result dict was written).
    """
    if args.latent_space not in ('gender', 'snr', 'all'):
        raise ValueError(f'Unsupported latent_space {args.latent_space!r}; '
                         "expected 'gender', 'snr' or 'all'.")

    def _length_weighted_sisdr(network, files_speech, files_noise, snr_db):
        """Return the length-weighted mean SI-SDR improvement over a batch.

        Each (speech, noise) file pair is loaded in full, trimmed to the
        shorter of the two, mixed with ``snr_db``, denoised by ``network``
        and scored against the clean source, with the unprocessed mixture's
        SI-SDR passed as ``offset``.
        """
        batch_sisdr = list()
        batch_durations = list()
        for (fs, fn) in zip(files_speech, files_noise):

            source = F.load_audio(fs,
                                  duration=None,
                                  random_offset=False,
                                  device=args.device_id)
            noise = F.load_audio(fn,
                                 duration=None,
                                 random_offset=False,
                                 device=args.device_id)
            min_length = min(len(source), len(noise))
            stft_frames = ceil(min_length / C.hop_size)
            source = source[:min_length]
            noise = noise[:min_length]

            (x, s, _) = F.mix_signals(source, noise, snr_db=snr_db)
            (X, X_mag) = F.stft(x)

            # keep only ceil(min_length / hop_size) frames; presumably this
            # trims frames F.stft yields past the signal end -- TODO confirm
            X = X.permute(1, 0, 2)[:stft_frames]  # (seq_len, num_features, channel)
            X_mag = X_mag.permute(1, 0)[:stft_frames]  # (seq_len, num_features)

            # add a leading batch axis of size 1
            X = torch.unsqueeze(X, dim=0)
            X_mag = torch.unsqueeze(X_mag, dim=0)
            s = torch.unsqueeze(s, dim=0)
            x = torch.unsqueeze(x, dim=0)

            actual_sisdr = float(F.calculate_sisdr(s, x).item())

            # feed-forward
            M_hat = network(X_mag)
            s_hat = F.istft(X, mask=M_hat)

            batch_sisdr.append(
                F.calculate_sisdr(s, s_hat,
                                  offset=actual_sisdr).mean().item())
            batch_durations.append(min_length)

        return np.average(batch_sisdr, weights=batch_durations)

    with torch.no_grad():

        #
        # initialize network
        #
        np.random.seed(0)
        torch.manual_seed(0)

        network = DenoisingNetwork(args.hidden_size,
                                   args.num_layers).to(device=args.device_id)

        network_params = F.count_parameters(network)

        network.load_state_dict(torch.load(
            args.state_dict_file,
            map_location=torch.device(args.device_id),
        ),
                                strict=True)
        network.eval()

        F.write_data(filename=os.path.join(output_directory,
                                           'num_parameters.txt'),
                     data=network_params)

        with torch.cuda.device(args.device_id):
            torch.cuda.empty_cache()

        te_sisdr = dict()

        if args.latent_space in ('gender', 'all'):
            # reseed so these draws do not depend on anything run earlier
            np.random.seed(0)
            torch.manual_seed(0)

            for te_gender in C.gender_all:
                logging.info(
                    f'Now testing model with {te_gender}-gender inputs...')
                files_speech = np.random.choice(
                    F.filter_by_gender(te_utterances, te_gender),
                    size=C.te_batch_size)
                files_noise = np.random.choice(te_noises, size=C.te_batch_size)
                # store the length-weighted average result
                te_sisdr[str(te_gender)] = _length_weighted_sisdr(
                    network, files_speech, files_noise, snr_db=C.snr_all)

            te_sisdr['mean_gender'] = np.mean(
                [te_sisdr[str(g)] for g in C.gender_all])

        if args.latent_space in ('snr', 'all'):
            # reseed so the SNR sweep is independent of the gender sweep
            np.random.seed(0)
            torch.manual_seed(0)

            for te_snr in C.snr_all:
                logging.info(
                    f'Now testing model with {te_snr} dB mixture SDR inputs...'
                )
                files_speech = np.random.choice(te_utterances,
                                                size=C.te_batch_size)
                files_noise = np.random.choice(te_noises, size=C.te_batch_size)
                # store the length-weighted average result
                te_sisdr[str(te_snr)] = _length_weighted_sisdr(
                    network, files_speech, files_noise, snr_db=te_snr)

            te_sisdr['mean_sisdr'] = np.mean(
                [te_sisdr[str(v)] for v in C.snr_all])

        logging.info(json.dumps(te_sisdr, sort_keys=True, indent=4))
        F.write_data(filename=os.path.join(output_directory,
                                           'test_results.txt'),
                     data=te_sisdr)

    return
コード例 #6
0
def experiment():
    """Train an ensemble (gating + specialists) network and checkpoint it.

    Builds an ``EnsembleNetwork`` from ``file_gating`` / ``files_specialists``
    and trains it with Adam on ``F.loss_sisdr``. The forward pass receives
    ``args.softmax_annealing`` during training. Each iteration runs 100
    mini-batches, and training continues until
    ``C.stopping_criteria(iteration, iteration_best)`` is satisfied.

    After every iteration the network is validated per gender
    (``args.latent_space == 'gender'``) or per SNR (otherwise). Whenever
    ``sisdr_batch['mean']`` improves, it is written to
    ``validation_sisdr.txt`` and the weights to ``model.pt`` in
    ``output_directory``.

    Depends on module-level globals defined elsewhere (``args``, ``C``, ``F``,
    ``tr_*`` / ``vl_*`` file lists, architecture settings, ``output_directory``,
    ``ts_start``).

    Returns:
        Path of the ``model.pt`` checkpoint in ``output_directory``.
    """

    #
    # ensure reproducibility
    #
    np.random.seed(0)
    torch.manual_seed(0)

    #
    # initialize network
    #
    network = EnsembleNetwork(
        filepath_gating=file_gating,
        filepaths_denoising=files_specialists,
        g_hs=hidden_size_gating,
        g_nl=num_layers_gating,
        s_hs=hidden_size_specialist,
        s_nl=num_layers_specialist,
        ct=args.latent_space,
    ).to(device=args.device_id)

    optimizer = torch.optim.Adam(
        params=network.parameters(),
        lr=args.learning_rate,
    )

    # counts the gating network plus only the first specialist;
    # NOTE(review): presumably intentional (size of the path one input
    # takes) -- confirm
    network_params = F.count_parameters(network.gating) + F.count_parameters(
        network.specialists[0])

    F.write_data(filename=os.path.join(output_directory, 'num_parameters.txt'),
                 data=network_params)
    F.write_data(filename=os.path.join(output_directory, 'files_gating.txt'),
                 data=file_gating)
    F.write_data(filename=os.path.join(output_directory,
                                       'files_specialist.txt'),
                 data=files_specialists)
    with torch.cuda.device(args.device_id):
        torch.cuda.empty_cache()

    #
    # log experiment configuration
    #
    os.system('cls' if os.name == 'nt' else 'clear')
    logging.info(
        f'Training Ensemble network composed of {args.latent_space} specialists...'
    )
    logging.info(f'\u2022 {architecture_gating} gating architecture')
    logging.info(f'\u2022 {architecture_specialist} specialist architecture')
    logging.info(
        f'\u2022 Softmax annealing strategy = {args.softmax_annealing if args.softmax_annealing else None}'
    )
    logging.info(f'\u2022 {network_params} learnable parameters')
    logging.info(f'\u2022 {args.learning_rate:.3e} learning rate')
    logging.info(f'Results will be saved in "{output_directory}".')
    logging.info(f'Using GPU device {args.device_id}...')

    #
    # experiment loop
    #
    (iteration, iteration_best) = (0, 0)
    # Start at -inf so the first validation always writes model.pt; starting
    # at 0 could return a path to a checkpoint that was never saved.
    sisdr_best = float('-inf')

    while not C.stopping_criteria(iteration, iteration_best):

        network.train()
        np.random.seed(iteration)
        torch.manual_seed(iteration)

        # training
        for _ in range(100):

            # forward propagation
            batch = F.generate_batch(
                np.random.choice(tr_utterances, size=C.tr_batch_size),
                np.random.choice(tr_noises, size=C.tr_batch_size),
                device=args.device_id,
            )
            M_hat = network(batch.X_mag, args.softmax_annealing)
            s_hat = F.istft(batch.X, mask=M_hat)

            # backward propagation
            optimizer.zero_grad()
            F.loss_sisdr(batch.s, s_hat).backward()
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1e-4)
            optimizer.step()

        # validation uses a fixed seed so every iteration sees the same batches
        network.eval()
        np.random.seed(0)
        torch.manual_seed(0)

        # validation
        with torch.no_grad():

            if args.latent_space == 'gender':

                sisdr_batch = {k: 0 for k in C.gender_all}
                for vl_gender in C.gender_all:

                    vl_filtered_files = F.filter_by_gender(
                        vl_utterances, vl_gender)
                    batch = F.generate_batch(
                        np.random.choice(vl_filtered_files,
                                         size=C.vl_batch_size),
                        np.random.choice(vl_noises, size=C.vl_batch_size),
                        device=args.device_id,
                    )
                    M_hat = network(batch.X_mag)
                    s_hat = F.istft(batch.X, mask=M_hat)
                    sisdr_batch[vl_gender] = float(
                        F.calculate_sisdr(
                            batch.s, s_hat,
                            offset=batch.actual_sisdr).mean().item())

            else:

                sisdr_batch = {k: 0 for k in C.snr_all}
                for vl_snr in C.snr_all:

                    batch = F.generate_batch(
                        np.random.choice(vl_utterances, size=C.vl_batch_size),
                        np.random.choice(vl_noises, size=C.vl_batch_size),
                        mixture_snr=vl_snr,
                        device=args.device_id,
                    )
                    M_hat = network(batch.X_mag)
                    s_hat = F.istft(batch.X, mask=M_hat)
                    sisdr_batch[vl_snr] = float(
                        F.calculate_sisdr(
                            batch.s, s_hat,
                            offset=batch.actual_sisdr).mean().item())

        sisdr_batch['mean'] = np.mean(list(sisdr_batch.values()))

        # checkpoint on improvement
        if sisdr_batch['mean'] > sisdr_best:
            sisdr_best = sisdr_batch['mean']
            iteration_best = iteration

            # plain decimal value; the previous ':%' format multiplied the
            # score by 100 and appended '%', which is wrong for a dB value
            F.write_data(filename=os.path.join(output_directory,
                                               'validation_sisdr.txt'),
                         data=f'{sisdr_best:.6f}')
            torch.save(network.state_dict(),
                       os.path.join(output_directory, 'model.pt'))
            checkmark = ' | \033[32m\u2714\033[39m'
        else:
            checkmark = ''

        # print results
        status = ''.join(f'\033[33m{k}: {v:>6.3f} dB\033[39m, '
                         for (k, v) in sisdr_batch.items())
        ts_end = int(round(time.time())) - ts_start
        status += f'Time Elapsed: {int(ts_end/60)} minutes' + checkmark
        logging.info(status)
        logging.info(
            f'Network # of forwards: {network.num_forwards} \t Alpha: {network.alpha}'
        )
        iteration += 1

    return os.path.join(output_directory, 'model.pt')