def get_statistics(args, dataset):
    scaler = sklearn.preprocessing.StandardScaler()
    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=True)
    )
    dataset_scaler = copy.deepcopy(dataset)
    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None
    dataset_scaler.random_chunks = False
    dataset_scaler.random_track_mix = False
    dataset_scaler.random_interferer_mix = False
    dataset_scaler.seq_duration = None
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for ind in pbar:
        x, y = dataset_scaler[ind]
        pbar.set_description("Compute dataset statistics")
        X = spec(x[None, ...])
        scaler.partial_fit(np.squeeze(X))

    # set initial input scaler values
    std = np.maximum(
        scaler.scale_,
        1e-4 * np.max(scaler.scale_)
    )
    return scaler.mean_, std
def get_statistics(args, dataset):
    scaler = sklearn.preprocessing.StandardScaler()
    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=False)
    )
    dataset_scaler = copy.deepcopy(dataset)
    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None
    dataset_scaler.random_chunks = True
    # dataset_scaler.seq_duration = args.seq_dur
    dataset_scaler.seq_duration = 0.0
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for ind in pbar:
        x, y = dataset_scaler[ind]
        pbar.set_description("Compute dataset statistics")
        X = spec(x[None, ...])
        # after squeezing, the stereo spectrogram has shape (nb_frames, nb_channels, nb_bins)
        p = np.squeeze(X)
        # CHANGED: stack both channel spectrograms along the frame axis so the
        # scaler accumulates one set of per-frequency statistics over both channels
        scaler.partial_fit(np.concatenate((p[:, 0], p[:, 1])))

    # set initial input scaler values
    std = np.maximum(
        scaler.scale_,
        1e-4 * np.max(scaler.scale_)
    )
    return scaler.mean_, std
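
# Hypothetical sketch, not part of the code above: it only illustrates the shape
# effect of the stereo concatenation passed to partial_fit. The spectrogram
# dimensions and the dummy data are assumptions for illustration.
import numpy as np
import sklearn.preprocessing

nb_frames, nb_channels, nb_bins = 100, 2, 2049        # assumed squeezed spectrogram shape
p = np.abs(np.random.randn(nb_frames, nb_channels, nb_bins))

scaler = sklearn.preprocessing.StandardScaler()
# concatenating p[:, 0] and p[:, 1] along the frame axis gives (2 * nb_frames, nb_bins),
# so both channels contribute to the same per-frequency mean and std
scaler.partial_fit(np.concatenate((p[:, 0], p[:, 1])))
assert scaler.mean_.shape == (nb_bins,)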
def get_statistics(args, dataset):
    # dataset is an instance of a torch.utils.data.Dataset class

    # tool to compute the mean and variance of the data
    scaler = sklearn.preprocessing.StandardScaler()

    # define the operation that computes magnitude spectrograms
    spec = torch.nn.Sequential(
        model.STFT(n_fft=args.nfft, n_hop=args.nhop),
        model.Spectrogram(mono=True)
    )

    # return a deep copy of dataset:
    # constructs a new compound object and recursively inserts copies of the
    # objects found in the original
    dataset_scaler = copy.deepcopy(dataset)
    dataset_scaler.samples_per_track = 1
    dataset_scaler.augmentations = None           # no scaling of sources before mixing
    dataset_scaler.random_chunks = False          # no random chunking of tracks
    dataset_scaler.random_track_mix = False       # no random accompaniments for vocals
    dataset_scaler.random_interferer_mix = False
    dataset_scaler.seq_duration = None            # if None, the whole original musdb track is loaded

    # make a progress bar:
    # tqdm returns an iterator that acts exactly like the original iterable,
    # but prints a dynamically updating progress bar every time a value is requested
    pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for ind in pbar:
        out = dataset_scaler[ind]
        # x is the mix and y is the target source in the time domain;
        # a possible third element (text) is ignored here
        x = out[0]
        y = out[1]
        pbar.set_description("Compute dataset statistics")

        # X is the mono magnitude spectrogram of one full track
        # ('...' stands for as many ':' as needed)
        # at this point, X has shape (nb_frames, nb_samples, nb_channels, nb_bins) = (N, 1, 1, F)
        # nb_frames: time steps, nb_bins: frequency bands, nb_samples: batch size
        X = spec(x[None, ...])

        # online computation of mean and std on X for later scaling;
        # np.squeeze removes single-dimensional entries from the shape of an array,
        # so the scaler sees X with shape (N, F)
        scaler.partial_fit(np.squeeze(X))

    # set initial input scaler values
    # scale_ and mean_ have shape (nb_bins,): the standard deviation and mean are
    # computed on each frequency band separately;
    # if the std of a frequency bin is smaller than m = 1e-4 * (max std over all bins), set it to m
    std = np.maximum(                     # element-wise maximum of two arrays
        scaler.scale_,
        1e-4 * np.max(scaler.scale_)      # np.max (= np.amax) returns the largest element of the array
    )
    return scaler.mean_, std
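
# Hedged sketch of how the returned statistics are typically consumed: open-unmix
# style models standardize the magnitude spectrogram per frequency bin before the
# first layer. The function name, the shapes, and the dummy data below are
# assumptions for illustration; the exact application inside model.OpenUnmix may differ.
import numpy as np

def normalize_input(X, input_mean, input_scale):
    # X: magnitude spectrogram with shape (nb_frames, nb_samples, nb_channels, nb_bins);
    # input_mean / input_scale have shape (nb_bins,) and broadcast over the other axes
    return (X - input_mean) / input_scale

# usage with values as returned by get_statistics (dummy data, shapes assumed):
mean, std = np.zeros(2049), np.ones(2049)
X = np.abs(np.random.randn(10, 1, 1, 2049))
X_norm = normalize_input(X, mean, std)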
def get_statistics(args, datasource):
    scaler = sklearn.preprocessing.StandardScaler()
    pbar = tqdm.tqdm(range(len(datasource.mus.tracks)), disable=args.quiet)
    for ind in pbar:
        x = datasource.mus.tracks[ind].audio.T
        audio = nn.Variable([1] + list(x.shape))
        audio.d = x
        target_spec = model.Spectrogram(
            *model.STFT(audio, n_fft=args.nfft, n_hop=args.nhop),
            mono=(args.nb_channels == 1))
        pbar.set_description("Compute dataset statistics")
        target_spec.forward()
        scaler.partial_fit(np.squeeze(target_spec.d[0]))

    # set initial input scaler values
    std = np.maximum(scaler.scale_, 1e-4 * np.max(scaler.scale_))
    return scaler.mean_, std
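
# Not part of the code above: a minimal, self-contained illustration of the NNabla
# compute pattern this variant relies on (define a Variable, assign data to .d,
# build a static graph, call .forward(), read the result from .d). The dummy
# shape is an assumption for illustration only.
import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable([1, 2, 44100])            # (batch, channels, samples), assumed shape
x.d = np.random.randn(*x.shape)           # feed data into the variable
energy = F.mean(F.pow_scalar(x, 2.0))     # build a small computation graph
energy.forward()                          # execute the graph
print(energy.d)                           # read back the computed value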
def train():
    parser, args = get_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Initialize DataIterators for the training and validation data sources.
    train_source, valid_source, args = data.load_datasources(
        parser, args, rng=RandomState(42))
    train_iter = data_iterator(train_source, args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False, with_file_cache=False)
    valid_iter = data_iterator(valid_source, args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False, with_file_cache=False)

    scaler_mean, scaler_std = get_statistics(args, train_source)
    max_bin = utils.bandwidth_to_max_bin(
        train_source.sample_rate, args.nfft, args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean, input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft, n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_source.sample_rate)

    # Create input variables.
    audio_shape = [args.batch_size] + list(train_source._get_data(0)[0].shape)
    mixture_audio = nn.Variable(audio_shape)
    target_audio = nn.Variable(audio_shape)
    vmixture_audio = nn.Variable(audio_shape)
    vtarget_audio = nn.Variable(audio_shape)

    # create train graph
    pred_spec = unmix(mixture_audio, test=False)
    pred_spec.persistent = True
    target_spec = model.Spectrogram(
        *model.STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
        mono=(unmix.nb_channels == 1))
    loss = F.mean(F.squared_error(pred_spec, target_spec), axis=1)

    # Create Solver.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Training loop.
    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    es = utils.EarlyStopping(patience=args.patience)

    for epoch in t:
        # TRAINING
        t.set_description("Training Epoch")
        b = tqdm.trange(0, train_source._size // args.batch_size,
                        disable=args.quiet)
        losses = utils.AverageMeter()
        for batch in b:
            mixture_audio.d, target_audio.d = train_iter.next()
            b.set_description("Training Batch")
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.weight_decay(args.weight_decay)
            solver.update()
            losses.update(loss.d.copy().mean())
            b.set_postfix(train_loss=losses.avg)

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(valid_source._size):
            # Feed new validation input data for every batch
            vmixture_audio.d, vtarget_audio.d = valid_iter.next()
            # create validation graph
            vpred_spec = unmix(vmixture_audio, test=True)
            vpred_spec.persistent = True
            vtarget_spec = model.Spectrogram(
                *model.STFT(vtarget_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                mono=(unmix.nb_channels == 1))
            vloss = F.mean(F.squared_error(vpred_spec, vtarget_spec), axis=1)
            vloss.forward(clear_buffer=True)
            vlosses.update(vloss.d.copy().mean())

        t.set_postfix(train_loss=losses.avg, val_loss=vlosses.avg)

        stop = es.step(vlosses.avg)
        is_best = vlosses.avg == es.best

        # save current model
        nn.save_parameters(
            os.path.join(args.output, 'checkpoint_%s.h5' % args.target))
        if is_best:
            best_epoch = epoch
            nn.save_parameters(
                os.path.join(args.output, '%s.h5' % args.target))
        if stop:
            print("Apply Early Stopping")
            break
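
# Minimal sketch of an early-stopping helper that is consistent with how es.step()
# and es.best are used in the loop above (step() updates the best value, then returns
# a stop flag). The real utils.EarlyStopping is not shown in this section and may
# differ; this sketch assumes "lower validation loss is better" and a plain patience counter.
class EarlyStoppingSketch:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def step(self, metric):
        # returns True when training should stop
        if self.best is None or metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience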