Ejemplo n.º 1
0
def batch_norm_update(model,
                      dataset,
                      feature_key,
                      batch_dim=0,
                      device=0 if torch.cuda.is_available() else 'cpu'):
    r"""Re-estimates the BatchNorm running buffers of ``model`` from ``dataset``.

    The function first resets the BN statistics via ``_reset_bn``. It then
    performs one forward pass over every example in ``dataset``, in train mode
    and under ``torch.no_grad()``. Before each pass, the momentum of every
    module recorded by ``_get_momenta`` (presumably the BatchNorm layers) is
    set to ``b / (n + b)``:

    - ``b`` is the size of the current batch;
    - ``n`` is the number of samples seen so far.

    With PyTorch's update rule ``running = (1 - m) * running + m * batch``,
    this turns the running buffers into a sample-weighted cumulative average
    over the whole dataset instead of an exponential moving average.
    Batches of size 0 are skipped because they carry no statistics.

    The original momenta and the original train/eval mode of ``model`` are
    restored afterwards, also when a forward pass raises. ``model`` is moved
    to ``device`` and is *not* moved back.

    Args:
        model: model for which we seek to update BatchNorm statistics. The
            function returns immediately, doing nothing, if
            ``_check_bn(model)`` is falsy (presumably: no BN layers).
        dataset: iterable of examples to compute the activation statistics
            on. Each example is passed through ``example_to_device`` and
            then directly to ``model``.
        feature_key: key of the tensor in each (device-transferred) example
            whose size along ``batch_dim`` is used as the batch size.
        batch_dim: dimension of ``example[feature_key]`` that holds the
            batch size. Defaults to 0.
        device: device that the model and every example are transferred to.
            The default is evaluated once, when this function is defined:
            ``0`` (first CUDA device) if CUDA was available, else ``'cpu'``.
    """
    if not _check_bn(model):
        return
    was_training = model.training
    model.train()

    model.to(device)

    momenta = {}
    model.apply(_reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    # The try starts only after ``momenta`` is fully populated, so the
    # restoration in ``finally`` never sees a partially collected dict.
    try:
        with torch.no_grad():
            for example in dataset:
                example = example_to_device(example, device)
                b = example[feature_key].size(batch_dim)
                if b == 0:
                    # An empty batch adds nothing to the statistics; skipping
                    # it also avoids ZeroDivisionError in b / (n + b) at n == 0.
                    continue

                # Weight this batch by its share of all samples seen so far,
                # giving a cumulative (not exponential) average.
                momentum = b / float(n + b)
                for module in momenta:
                    module.momentum = momentum

                model(example)

                n += b
    finally:
        # Put back the caller's momenta and train/eval mode, even on error.
        model.apply(lambda module: _set_momenta(module, momenta))
        model.train(was_training)
Ejemplo n.º 2
0
def evaluate_masks(example, model, stft):
    """Beamforms one example with the model's predicted masks and scores it.

    Steps:

    1. Run ``model`` on ``example_to_device(example)``.
    2. Pass ``beamforming`` the following, all taken at index 0:
       the observation STFT, the predicted speech and noise masks
       (converted to NumPy), and the STFTs of the speech and noise images.
    3. Transform the beamformer output back with ``stft.inverse`` and score
       it with PESQ against ``speech_image[0]``.
    4. Compute an SNR value from the speech and noise contributions that
       ``beamforming`` returns.

    NOTE(review): the SNR is the mean over all bins of
    ``-10 * log10(|image|**2 / |noise|**2)``. That is the *negated* per-bin
    power ratio in dB. Confirm the sign is intended.

    Args:
        example: dict-like example; only index 0 of each used entry is read
            (presumably a batch of size 1 — confirm with the caller).
        model: callable returning a dict that holds the predicted masks.
        stft: STFT object that is callable (forward) and has ``inverse``.

    Returns:
        Tuple ``(example_id, snr, pesq)``. The same three values are also
        printed.
    """
    prediction = model(example_to_device(example))
    speech_image = example[DB_K.SPEECH_IMAGE][0]

    # Gather the beamformer inputs in the same order as they are passed.
    observation = example[M_K.OBSERVATION_STFT][0]
    speech_mask = prediction[M_K.SPEECH_MASK_PRED][0].detach().numpy()
    noise_mask = prediction[M_K.NOISE_MASK_PRED][0].detach().numpy()
    speech_image_stft = stft(speech_image)
    noise_image_stft = stft(example[DB_K.NOISE_IMAGE][0])

    enhanced_stft, image_contribution, noise_contribution = beamforming(
        observation,
        speech_mask,
        noise_mask,
        speech_image_stft,
        noise_image_stft,
    )

    example_id = example[DB_K.EXAMPLE_ID][0]
    pesq_score = pb.evaluation.pesq(
        speech_image[0], stft.inverse(enhanced_stft)
    )[0]

    power_ratio = (np.abs(image_contribution) ** 2
                   / np.abs(noise_contribution) ** 2)
    snr_db = np.mean(-10 * np.log10(power_ratio))

    print(example_id, snr_db, pesq_score)
    return example_id, snr_db, pesq_score
Ejemplo n.º 3
0
    def example_to_device(self, example, device=None):
        """
        Moves ``example`` to ``device`` in the form the model needs.

        This default implementation hands the complete example to the
        module-level ``example_to_device`` helper. Subclasses may override it
        to move only the parts the model actually consumes.

        For instance, an STFT-based model does not need its time-domain
        target signals on the GPU to compute the loss. Those signals are
        still worth keeping so they can be reported to tensorboard.

        Args:
            example: The example to move.
            device: Target device; forwarded unchanged to the helper.

        Returns:
            The example, moved to ``device`` either completely or in part.
        """
        # Inside a method, this bare name resolves to the module-level
        # function, not to this method.
        moved = example_to_device(example, device)
        return moved