Example #1
def get_statistics(args, dataset):
    scaler = sklearn.preprocessing.StandardScaler()

    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=True)
    )

    dataset_scaler = copy.deepcopy(dataset)
    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None
    dataset_scaler.random_chunks = False
    dataset_scaler.random_track_mix = False
    dataset_scaler.random_interferer_mix = False
    dataset_scaler.seq_duration = None
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for ind in pbar:
        x, y = dataset_scaler[ind]
        pbar.set_description("Compute dataset statistics")
        X = spec(x[None, ...])
        scaler.partial_fit(np.squeeze(X))

    # set initial input scaler values
    std = np.maximum(
        scaler.scale_,
        1e-4*np.max(scaler.scale_)
    )
    return scaler.mean_, std
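
The spec pipeline above relies on the project's model.STFT and model.Spectrogram modules, which are not shown on this page. As a rough, hypothetical stand-in, the sketch below uses torch.stft to compute a mono magnitude spectrogram with the (nb_frames, nb_samples, nb_channels, nb_bins) layout that the squeeze-and-partial_fit step expects; the function name, shapes, and parameters here are assumptions, not the project's actual implementation.

# Hypothetical sketch, not the project's model.STFT/model.Spectrogram
import torch

def magnitude_spectrogram(audio, n_fft=4096, n_hop=1024):
    # audio: (nb_samples, nb_channels, nb_timesteps)
    nb_samples, nb_channels, _ = audio.shape
    x = audio.reshape(nb_samples * nb_channels, -1)
    window = torch.hann_window(n_fft)
    stft = torch.stft(x, n_fft=n_fft, hop_length=n_hop,
                      window=window, return_complex=True)
    # stft: (nb_samples * nb_channels, nb_bins, nb_frames)
    mag = stft.abs().reshape(nb_samples, nb_channels, n_fft // 2 + 1, -1)
    mono = mag.mean(dim=1, keepdim=True)   # average channels to mono
    return mono.permute(3, 0, 1, 2)        # (nb_frames, nb_samples, 1, nb_bins)

audio = torch.randn(1, 2, 44100)           # one second of stereo audio at 44.1 kHz
X = magnitude_spectrogram(audio)
print(X.shape)                             # torch.Size([44, 1, 1, 2049])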
Example #2
def get_statistics(args, dataset):
    scaler = sklearn.preprocessing.StandardScaler()

    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=False)
    )

    dataset_scaler = copy.deepcopy(dataset)
    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None
    dataset_scaler.random_chunks = True
    #dataset_scaler.seq_duration = args.seq_dur
    dataset_scaler.seq_duration = 0.0
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for ind in pbar:
        x, y = dataset_scaler[ind]
        pbar.set_description("Compute dataset statistics")
        X = spec(x[None, ...])
        #print("HELLO", np.squeeze(X).shape)
        p = np.squeeze(X)
        scaler.partial_fit(np.concatenate((p[:,0],p[:,1]) )) #CHANGED!!

    # set initial input scaler values
    std = np.maximum(
        scaler.scale_,
        1e-4*np.max(scaler.scale_)
    )
    return scaler.mean_, std
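
With mono=False the squeezed spectrogram keeps its channel axis, so this variant stacks both channels along the frame axis before feeding the scaler. A minimal NumPy sketch of that shape handling, with made-up shapes:

import numpy as np
import sklearn.preprocessing

p = np.random.rand(100, 2, 2049)              # (nb_frames, nb_channels, nb_bins)
stacked = np.concatenate((p[:, 0], p[:, 1]))  # (200, 2049): channels stacked over frames

scaler = sklearn.preprocessing.StandardScaler()
scaler.partial_fit(stacked)                   # statistics are still per frequency bin
print(stacked.shape, scaler.mean_.shape)      # (200, 2049) (2049,)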
Example #3
def get_statistics(args, dataset):

    # dataset is an instance of a torch.utils.data.Dataset class

    scaler = sklearn.preprocessing.StandardScaler()  # tool to compute mean and variance of data

    # define operation that computes magnitude spectrograms
    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=True)
    )
    # return a deep copy of dataset:
    # constructs a new compound object and recursively inserts copies of the objects found in the original
    dataset_scaler = copy.deepcopy(dataset)

    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None  # no scaling of sources before mixing
    dataset_scaler.random_chunks = False  # no random chunking of tracks
    dataset_scaler.random_track_mix = False  # no random accompaniments for vocals
    dataset_scaler.random_interferer_mix = False
    dataset_scaler.seq_duration = None  # if None, the original whole track from musdb is loaded

    # make a progress bar:
    # returns an iterator which acts exactly like the original iterable,
    # but prints a dynamically updating progressbar every time a value is requested.
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)

    for ind in pbar:
        out = dataset_scaler[ind]  # x is the mix, y is the target source (time domain); z is text and is ignored here
        x = out[0]
        y = out[1]
        pbar.set_description("Compute dataset statistics")
        X = spec(x[None, ...])  # X is the mono magnitude spectrogram; `...` means as many ':' as needed

        # X is spectrogram of one full track
        # at this point, X has shape (nb_frames, nb_samples, nb_channels, nb_bins) = (N, 1, 1, F)
        # nb_frames: time steps, nb_bins: frequency bands, nb_samples: batch size

        # online computation of mean and std on X for later scaling
        # after squeezing, X has shape (N, F)
        scaler.partial_fit(np.squeeze(X))  # np.squeeze: remove single-dimensional entries from the shape of an array

    # set initial input scaler values
    # scale_ and mean_ have shape (nb_bins,), standard deviation and mean are computed on each frequency band separately
    # if std of a frequency bin is smaller than m = 1e-4 * (max std of all freq. bins), set it to m
    std = np.maximum(   # np.maximum compares two arrays element-wise and returns the element-wise maximum
        scaler.scale_,
        1e-4*np.max(scaler.scale_)  # np.max (= np.amax) returns the largest element of a single array
    )
    )
    return scaler.mean_, std
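
The comments above capture the core of the routine: StandardScaler.partial_fit accumulates per-frequency-bin mean and standard deviation track by track, and the final np.maximum call floors very small standard deviations at 1e-4 times the largest one. A self-contained sketch of the same two steps on random data (shapes are made up):

import numpy as np
import sklearn.preprocessing

nb_bins = 2049
scaler = sklearn.preprocessing.StandardScaler()

for _ in range(5):                      # pretend these are five tracks
    X = np.random.rand(300, nb_bins)    # (nb_frames, nb_bins)
    scaler.partial_fit(X)               # running per-bin mean/std update

std = np.maximum(scaler.scale_, 1e-4 * np.max(scaler.scale_))
print(scaler.mean_.shape, std.shape)    # (2049,) (2049,)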
Example #4
def get_statistics(args, datasource):
    scaler = sklearn.preprocessing.StandardScaler()

    pbar = tqdm.tqdm(range(len(datasource.mus.tracks)), disable=args.quiet)

    for ind in pbar:
        x = datasource.mus.tracks[ind].audio.T
        audio = nn.Variable([1] + list(x.shape))
        audio.d = x

        target_spec = model.Spectrogram(*model.STFT(audio,
                                                    n_fft=args.nfft,
                                                    n_hop=args.nhop),
                                        mono=(args.nb_channels == 1))

        pbar.set_description("Compute dataset statistics")
        target_spec.forward()
        scaler.partial_fit(np.squeeze(target_spec.d[0]))
    # set initial input scaler values
    std = np.maximum(scaler.scale_, 1e-4 * np.max(scaler.scale_))
    return scaler.mean_, std
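
This variant uses NNabla's static graph API: an nn.Variable is allocated with a fixed shape, its .d buffer is filled with NumPy data, and forward() runs the computation. A minimal sketch of that pattern, unrelated to the project's model code:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable([1, 4])           # declare the input shape
y = F.mean(F.pow_scalar(x, 2.0))  # build the graph once

x.d = np.arange(4, dtype=np.float32).reshape(1, 4)  # feed data into the variable
y.forward()                                         # run the graph
print(y.d)                                          # mean of squares: 3.5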
Example #5
def train():
    parser, args = get_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Initialize DataIterators for the train and validation data sources.
    train_source, valid_source, args = data.load_datasources(
        parser, args, rng=RandomState(42))

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    scaler_mean, scaler_std = get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_source.sample_rate)

    # Create input variables.
    audio_shape = [args.batch_size] + list(train_source._get_data(0)[0].shape)
    mixture_audio = nn.Variable(audio_shape)
    target_audio = nn.Variable(audio_shape)

    vmixture_audio = nn.Variable(audio_shape)
    vtarget_audio = nn.Variable(audio_shape)

    # create train graph
    pred_spec = unmix(mixture_audio, test=False)
    pred_spec.persistent = True

    target_spec = model.Spectrogram(*model.STFT(target_audio,
                                                n_fft=unmix.n_fft,
                                                n_hop=unmix.n_hop),
                                    mono=(unmix.nb_channels == 1))

    loss = F.mean(F.squared_error(pred_spec, target_spec), axis=1)

    # Create Solver.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Training loop.
    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    es = utils.EarlyStopping(patience=args.patience)

    for epoch in t:
        # TRAINING
        t.set_description("Training Epoch")
        b = tqdm.trange(0,
                        train_source._size // args.batch_size,
                        disable=args.quiet)
        losses = utils.AverageMeter()
        for batch in b:
            mixture_audio.d, target_audio.d = train_iter.next()
            b.set_description("Training Batch")
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.weight_decay(args.weight_decay)
            solver.update()
            losses.update(loss.d.copy().mean())
            b.set_postfix(train_loss=losses.avg)

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(valid_source._size):
            # fill the validation input variables with the next batch
            vmixture_audio.d, vtarget_audio.d = valid_iter.next()
            # create validation graph
            vpred_spec = unmix(vmixture_audio, test=True)
            vpred_spec.persistent = True

            vtarget_spec = model.Spectrogram(*model.STFT(vtarget_audio,
                                                         n_fft=unmix.n_fft,
                                                         n_hop=unmix.n_hop),
                                             mono=(unmix.nb_channels == 1))
            vloss = F.mean(F.squared_error(vpred_spec, vtarget_spec), axis=1)

            vloss.forward(clear_buffer=True)
            vlosses.update(vloss.d.copy().mean())

        t.set_postfix(train_loss=losses.avg, val_loss=vlosses.avg)

        stop = es.step(vlosses.avg)
        is_best = vlosses.avg == es.best

        # save current model
        nn.save_parameters(
            os.path.join(args.output, 'checkpoint_%s.h5' % args.target))

        if is_best:
            best_epoch = epoch
            nn.save_parameters(os.path.join(args.output,
                                            '%s.h5' % args.target))

        if stop:
            print("Apply Early Stopping")
            break
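
utils.EarlyStopping is not shown on this page; judging from the es.step(vlosses.avg) and es.best usage above, it is a patience-based helper roughly along the lines of this assumed sketch:

# Assumed minimal implementation matching the step()/best interface used above;
# the project's actual utils.EarlyStopping may differ.
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def step(self, value):
        # returns True once the monitored value has not improved for `patience` epochs
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience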