Example #1
import glob
import json
import logging
import multiprocessing
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import termtables
import tqdm
from torch import optim

from nussl import utils

# ----------------------------------------------------
# ------------------- SETTING UP ---------------------
# ----------------------------------------------------

# seed this recipe for reproducibility
utils.seed(0)

# set up logging
logging.basicConfig(
    format='%(asctime)s,%(msecs)d %(levelname)-8s '
           '[%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d:%H:%M:%S',
    level=logging.INFO)

# WHAM_ROOT must point to the root of the WHAM dataset; CACHE_ROOT to a cache directory
WHAM_ROOT = os.getenv("WHAM_ROOT")
CACHE_ROOT = os.getenv("CACHE_ROOT")
NUM_WORKERS = multiprocessing.cpu_count() // 4
OUTPUT_DIR = os.path.expanduser('~/.nussl/recipes/wham_mi/run2')
RESULTS_DIR = os.path.join(OUTPUT_DIR, 'results')
MODEL_PATH = os.path.join(OUTPUT_DIR, 'checkpoints', 'best.model.pth')
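
# Optional sanity checks (a minimal sketch; the original recipe may handle this
# elsewhere): fail fast if the environment variables are unset, and create the
# output folders up front so later saves do not fail.
assert WHAM_ROOT is not None, "WHAM_ROOT must point to the WHAM dataset root"
assert CACHE_ROOT is not None, "CACHE_ROOT must point to a cache directory"
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)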
Example #2
import os

import matplotlib.pyplot as plt
import torch

from nussl import datasets, ml, utils

# DEVICE is assumed here: use the GPU when one is available, otherwise the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test_gradients(mix_source_folder):
    os.makedirs('tests/local/', exist_ok=True)

    utils.seed(0)

    tfms = datasets.transforms.Compose([
        datasets.transforms.GetAudio(),
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.MagnitudeWeights(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(50),
        datasets.transforms.GetExcerpt(3136,
                                       time_dim=1,
                                       tf_keys=['mix_audio', 'source_audio'])
    ])
    dataset = datasets.MixSourceFolder(mix_source_folder, transform=tfms)

    # create the model based on the first item in the dataset;
    # the second dimension of its shape is the number of features
    n_features = dataset[0]['mix_magnitude'].shape[1]

    # make some configs
    names = [
        'dpcl', 'mask_inference_l1', 'mask_inference_mse_loss', 'chimera',
        'open_unmix', 'end_to_end', 'dual_path'
    ]
    config_has_batch_norm = ['open_unmix', 'dual_path']
    configs = [
        ml.networks.builders.build_recurrent_dpcl(
            n_features,
            50,
            1,
            True,
            0.0,
            20, ['sigmoid'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_mask_inference(
            n_features,
            50,
            1,
            True,
            0.0,
            2, ['softmax'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_mask_inference(
            n_features,
            50,
            1,
            True,
            0.0,
            2, ['softmax'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_recurrent_chimera(
            n_features,
            50,
            1,
            True,
            0.0,
            20, ['sigmoid'],
            2, ['softmax'],
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_open_unmix_like(
            n_features,
            50,
            1,
            True,
            .4,
            2,
            1,
            add_embedding=True,
            embedding_size=20,
            embedding_activation=['sigmoid', 'unit_norm'],
        ),
        ml.networks.builders.build_recurrent_end_to_end(
            256,
            256,
            64,
            'sqrt_hann',
            50,
            2,
            True,
            0.0,
            2,
            'sigmoid',
            num_audio_channels=1,
            mask_complex=False,
            rnn_type='lstm',
            mix_key='mix_audio',
            normalization_class='InstanceNorm'),
        ml.networks.builders.build_dual_path_recurrent_end_to_end(
            64,
            16,
            8,
            60,
            30,
            50,
            2,
            True,
            25,
            2,
            'sigmoid',
        )
    ]

    loss_dictionaries = [
        {
            'DeepClusteringLoss': {
                'weight': 1.0
            }
        },
        {
            'L1Loss': {
                'weight': 1.0
            }
        },
        {
            'MSELoss': {
                'weight': 1.0
            }
        },
        {
            'DeepClusteringLoss': {
                'weight': 0.2
            },
            'PermutationInvariantLoss': {
                'args': ['L1Loss'],
                'weight': 0.8
            }
        },
        {
            'DeepClusteringLoss': {
                'weight': 0.2
            },
            'PermutationInvariantLoss': {
                'args': ['L1Loss'],
                'weight': 0.8
            }
        },
        {
            'SISDRLoss': {
                'weight': 1.0,
                'keys': {
                    'audio': 'estimates',
                    'source_audio': 'references'
                }
            }
        },
        {
            'SISDRLoss': {
                'weight': 1.0,
                'keys': {
                    'audio': 'estimates',
                    'source_audio': 'references'
                }
            }
        },
    ]

    def append_keys_to_model(name, model):
        if name == 'end_to_end':
            model.output_keys.extend(
                ['audio', 'recurrent_stack', 'mask', 'estimates'])
        elif name == 'dual_path':
            model.output_keys.extend(
                ['audio', 'mixture_weights', 'dual_path', 'mask', 'estimates'])

    for name, config, loss_dictionary in zip(names, configs,
                                             loss_dictionaries):
        loss_closure = ml.train.closures.Closure(loss_dictionary)

        utils.seed(0, set_cudnn=True)
        model_grad = ml.SeparationModel(config, verbose=True).to(DEVICE)
        append_keys_to_model(name, model_grad)

        all_data = {}
        for data in dataset:
            for key in data:
                if torch.is_tensor(data[key]):
                    data[key] = data[key].float().unsqueeze(0).contiguous().to(
                        DEVICE)
                    if key not in all_data:
                        all_data[key] = data[key]
                    else:
                        all_data[key] = torch.cat([all_data[key], data[key]],
                                                  dim=0)

        # do a forward pass in batched mode
        output_grad = model_grad(all_data)
        _loss = loss_closure.compute_loss(output_grad, all_data)
        # do a backward pass in batched mode
        _loss['loss'].backward()

        plt.figure(figsize=(10, 10))
        utils.visualize_gradient_flow(model_grad.named_parameters())
        plt.tight_layout()
        plt.savefig(f'tests/local/{name}:batch_gradient.png')

        utils.seed(0, set_cudnn=True)
        model_acc = ml.SeparationModel(config).to(DEVICE)
        append_keys_to_model(name, model_acc)

        for i, data in enumerate(dataset):
            for key in data:
                if torch.is_tensor(data[key]):
                    data[key] = data[key].float().unsqueeze(0).contiguous().to(
                        DEVICE)
            # do a forward pass on each item individually
            output_acc = model_acc(data)
            for key in output_acc:
                # make sure the forward pass in batch and forward pass individually match
                # if they don't, then items in a minibatch are talking to each other
                # somehow...
                _data_a = output_acc[key]
                _data_b = output_grad[key][i].unsqueeze(0)
                if name not in config_has_batch_norm:
                    assert torch.allclose(_data_a, _data_b, atol=1e-3)

            _loss = loss_closure.compute_loss(output_acc, data)
            # do a backward pass on each item individually
            _loss['loss'] = _loss['loss'] / len(dataset)
            _loss['loss'].backward()

        plt.figure(figsize=(10, 10))
        utils.visualize_gradient_flow(model_acc.named_parameters())
        plt.tight_layout()
        plt.savefig(f'tests/local/{name}:accumulated_gradient.png')

        # make sure the gradients match between batched and accumulated gradients
        # if they don't, then the items in a batch are talking to each other in the loss
        for param1, param2 in zip(model_grad.parameters(),
                                  model_acc.parameters()):
            assert torch.allclose(param1, param2)
            if name not in config_has_batch_norm:
                if param1.requires_grad and param2.requires_grad:
                    assert torch.allclose(param1.grad.mean(),
                                          param2.grad.mean(),
                                          atol=1e-3)
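
The key trick in the accumulation branch above is dividing each per-item loss by len(dataset) before calling backward(), so the summed per-item gradients equal the gradient of the batch-mean loss. Below is a minimal, self-contained sketch of that equivalence using a plain torch.nn.Linear model; the names model_batch and model_accum are illustrative and not from nussl.

import torch

torch.manual_seed(0)

# two copies of the same toy model, so the batched and accumulated runs start
# from identical weights
model_batch = torch.nn.Linear(8, 2)
model_accum = torch.nn.Linear(8, 2)
model_accum.load_state_dict(model_batch.state_dict())

data = torch.randn(4, 8)
targets = torch.randn(4, 2)
loss_fn = torch.nn.MSELoss()

# batched: one forward/backward over all items at once
loss_fn(model_batch(data), targets).backward()

# accumulated: one forward/backward per item, each loss divided by the number
# of items so the summed gradients match the batch-mean gradient
for x, y in zip(data, targets):
    item_loss = loss_fn(model_accum(x.unsqueeze(0)), y.unsqueeze(0)) / len(data)
    item_loss.backward()

for p_batch, p_accum in zip(model_batch.parameters(), model_accum.parameters()):
    assert torch.allclose(p_batch.grad, p_accum.grad, atol=1e-5)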