Example #1
def load_model(targets, model_name='umxhq', device='cpu'):
    """
    target model path can be either <target>.pth, or <target>-sha256.pth
    (as used on torchhub)
    """
    model_path = Path(model_name).expanduser()
    if not model_path.exists():
        raise FileNotFoundError('model path %s does not exist' % model_path)
    else:
        # load model from disk
        with open(Path(model_path,
                       str(len(targets)) + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path) / "model.pth"
        state = torch.load(target_model_path, map_location=device)

        max_bin = utils.bandwidth_to_max_bin(44100, results['args']['nfft'],
                                             results['args']['bandwidth'])

        unmix = model.OpenUnmixSingle(
            n_fft=results['args']['nfft'],
            n_hop=results['args']['nhop'],
            nb_channels=results['args']['nb_channels'],
            hidden_size=results['args']['hidden_size'],
            max_bin=max_bin)

        unmix.load_state_dict(state)
        unmix.stft.center = True
        unmix.eval()
        unmix.to(device)
        print('loadmodel function done')
        return unmix
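
A minimal usage sketch for the loader above. The directory name is hypothetical; it only needs to contain the <n_targets>.json config and model.pth files the function reads, and the input shape follows the usual open-unmix convention:

import torch

# hypothetical directory holding "2.json" and "model.pth"
unmix = load_model(targets=['vocals', 'drums'],
                   model_name='~/models/umx_single_run1',
                   device='cpu')

# open-unmix style models take a waveform batch of shape
# (nb_samples, nb_channels, nb_timesteps)
mixture = torch.rand(1, 2, 44100)
with torch.no_grad():
    estimate = unmix(mixture)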
Example #2
def load_model(target, model_name='umxhq', device='cpu'):
    """
    target model path can be either <target>.pth, or <target>-sha256.pth
    (as used on torchhub)
    """
    model_path = Path(model_name).expanduser()
    if not model_path.exists():
        # model path does not exist, use hubconf model
        try:
            # disable progress bar
            err = io.StringIO()
            with redirect_stderr(err):
                return torch.hub.load(
                    'sigsep/open-unmix-pytorch',
                    model_name,
                    target=target,
                    device=device,
                    pretrained=True
                )
        except AttributeError:
            raise NameError('Model does not exist on torchhub')
    else:
        # assume model_name is a path to a local model directory
        # and load the model from disk
        with open(Path(model_path, target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = next(Path(model_path).glob("%s*.pth" % target))
        state = torch.load(
            target_model_path,
            map_location=device
        )

        max_bin = utils.bandwidth_to_max_bin(
            state['sample_rate'],
            results['args']['nfft'],
            results['args']['bandwidth']
        )

        unmix = model.OpenUnmix(
            n_fft=results['args']['nfft'],
            n_hop=results['args']['nhop'],
            nb_channels=results['args']['nb_channels'],
            hidden_size=results['args']['hidden_size'],
            max_bin=max_bin
        )

        unmix.load_state_dict(state)
        unmix.stft.center = True
        unmix.eval()
        unmix.to(device)
        return unmix
Example #3
def umxhq(target='vocals', device='cpu', pretrained=True, *args, **kwargs):
    """
    Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18-HQ

    Args:
        target (str): select the target for the source to be separated.
                      Supported targets are
                        ['vocals', 'drums', 'bass', 'other']
        pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ
        device (str): selects device to be used for inference
    """
    # set urls for weights
    target_urls = {
        'bass':
        'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/bass-8d85a5bd.pth',
        'drums':
        'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/drums-9619578f.pth',
        'other':
        'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/other-b52fbbf7.pth',
        'vocals':
        'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/vocals-b62c91ce.pth'
    }

    from model import OpenUnmix

    # determine the maximum bin count for a 16 kHz bandwidth model
    max_bin = utils.bandwidth_to_max_bin(rate=44100,
                                         n_fft=4096,
                                         bandwidth=16000)

    # load open unmix model
    unmix = OpenUnmix(n_fft=4096,
                      n_hop=1024,
                      nb_channels=2,
                      hidden_size=512,
                      max_bin=max_bin)

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(target_urls[target],
                                                        map_location=device)
        unmix.load_state_dict(state_dict)
        # enable centering of stft to minimize reconstruction error
        unmix.stft.center = True
        unmix.eval()

    return unmix.to(device)
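
A short usage sketch for this hub entry point. Calling it with pretrained=True downloads the vocals weights from the Zenodo URL above; the dummy input just illustrates the expected (nb_samples, nb_channels, nb_timesteps) layout:

import torch

separator = umxhq(target='vocals', device='cpu', pretrained=True)
with torch.no_grad():
    vocals_spectrogram = separator(torch.rand(1, 2, 44100))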
Example #4
def umx(target='vocals', device='cpu', pretrained=True, *args, **kwargs):
    """
    Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18

    Args:
        target (str): select the target for the source to be separated.
                      Supported targets are
                        ['vocals', 'drums', 'bass', 'other']
        pretrained (bool): If True, returns a model pre-trained on MUSDB18
        device (str): selects device to be used for inference
    """
    # set urls for weights
    target_urls = {
        'bass':
        'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/bass-646024d3.pth',
        'drums':
        'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/drums-5a48008b.pth',
        'other':
        'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/other-f8e132cc.pth',
        'vocals':
        'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/vocals-c8df74a5.pth'
    }

    from model import OpenUnmix

    # determine the maximum bin count for a 16 kHz bandwidth model
    max_bin = utils.bandwidth_to_max_bin(rate=44100,
                                         n_fft=4096,
                                         bandwidth=16000)

    # load open unmix model
    unmix = OpenUnmix(n_fft=4096,
                      n_hop=1024,
                      nb_channels=2,
                      hidden_size=512,
                      max_bin=max_bin)

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(target_urls[target],
                                                        map_location=device)
        unmix.load_state_dict(state_dict)
        # enable centering of stft to minimize reconstruction error
        unmix.stft.center = True
        unmix.eval()

    return unmix.to(device)
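
Every example in this collection converts a bandwidth in Hz into an STFT bin index via utils.bandwidth_to_max_bin. A minimal sketch of what such a helper computes, assuming it returns the number of bins whose centre frequency lies at or below the requested bandwidth:

import numpy as np

def bandwidth_to_max_bin(rate, n_fft, bandwidth):
    # centre frequencies of the onesided STFT bins
    freqs = np.linspace(0, float(rate) / 2, n_fft // 2 + 1, endpoint=True)
    # number of bins still within the requested bandwidth
    return np.max(np.where(freqs <= bandwidth)[0]) + 1

# a 16 kHz bandwidth at 44.1 kHz with a 4096-point FFT keeps 1487 bins
print(bandwidth_to_max_bin(rate=44100, n_fft=4096, bandwidth=16000))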
Example #5
def load_model(target, model_name='umxhq', device='cpu', chkpnt=False):
    """
    target model path can be either <target>.pth, or <target>-sha256.pth
    (as used on torchhub)
    """
    model_path = Path(model_name).expanduser()
    if not model_path.exists():
        raise FileNotFoundError("Can't find model! Please check model_path.")

    else:
        # load model from disk
        with open(Path(model_path, target + '.json'), 'r') as stream:
            results = json.load(stream)
        if not chkpnt:
            target_model_path = next(Path(model_path).glob("%s*.pth" % target))
            state = torch.load(target_model_path, map_location=device)
        else:  # load from a training checkpoint instead of exported weights
            target_model_path = next(
                Path(model_path).glob("%s*.chkpnt" % target))
            state = torch.load(target_model_path,
                               map_location=device)['state_dict']

        max_bin = utils.bandwidth_to_max_bin(state['sample_rate'],
                                             results['args']['nfft'],
                                             results['args']['bandwidth'])

        unmix = model.OpenUnmix(n_fft=results['args']['nfft'],
                                n_hop=results['args']['nhop'],
                                nb_channels=results['args']['nb_channels'],
                                hidden_size=results['args']['hidden_size'],
                                max_bin=max_bin)

        unmix.load_state_dict(state)
        unmix.stft.center = True
        unmix.eval()
        unmix.to(device)
        return unmix
Example #6
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later: the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor,
                                          interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
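
The objective assembled above is the multi-domain loss printed at startup, MDL = mcoef*TD-Loss + FD-Loss. A framework-agnostic sketch of that combination; the real mse_loss and sdr_loss live in the x-umx codebase, and the SDR sign convention here is an assumption:

import numpy as np

def multi_domain_loss(pred_spec, target_spec, pred_wave, target_wave, mcoef):
    eps = 1e-8
    # frequency-domain term: MSE between predicted and target spectrograms
    loss_f = np.mean((pred_spec - target_spec) ** 2)
    # time-domain term: negative SDR of the waveform estimate,
    # so that lower is better
    sdr = 10 * np.log10(np.sum(target_wave ** 2) /
                        (np.sum((target_wave - pred_wave) ** 2) + eps) + eps)
    return mcoef * (-sdr) + loss_f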
Example #7
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target',
                        type=str,
                        default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_new/')
    parser.add_argument('--output',
                        type=str,
                        default="../out_unmix/model_new_data_aug_tl",
                        help='provide output path base folder name')
    parser.add_argument('--model',
                        type=str,
                        help='Path to checkpoint folder',
                        default='umxhq')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate, defaults to 1e-4')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0000000001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 uses the full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # set random seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_dataset.sample_rate).to(device)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: warm-start from the pretrained umxhq hub model
    if args.model:
        # disable progress bar
        err = io.StringIO()
        with redirect_stderr(err):
            unmix = torch.hub.load('sigsep/open-unmix-pytorch',
                                   'umxhq',
                                   target=args.target,
                                   device=device,
                                   pretrained=True)

    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    train_losses = []
    valid_losses = []
    train_times = []
    best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        from matplotlib import pyplot as plt

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 1)
        plt.title("Training loss")
        plt.plot(train_losses, label="Training")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "train_plot.pdf"))

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 2)
        plt.title("Validation loss")
        plt.plot(valid_losses, label="Validation")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "val_plot.pdf"))

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
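
The loop above relies on utils.save_checkpoint. A plausible sketch of that helper, with the signature inferred from the call site: the full training state is always written to <target>.chkpnt, and the bare weights are exported to <target>.pth whenever the epoch is the best so far:

import torch
from pathlib import Path

def save_checkpoint(state, is_best, path, target):
    # always keep the latest training state (model, optimizer, scheduler)
    torch.save(state, Path(path, target + '.chkpnt'))
    if is_best:
        # export only the model weights for inference-time loading
        torch.save(state['state_dict'], Path(path, target + '.pth'))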
Example #8
def train():
    parser, args = get_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Initialize DataIterators for MUSDB.
    train_source, valid_source, args = data.load_datasources(
        parser, args, rng=RandomState(42))

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    scaler_mean, scaler_std = get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_source.sample_rate)

    # Create input variables.
    audio_shape = [args.batch_size] + list(train_source._get_data(0)[0].shape)
    mixture_audio = nn.Variable(audio_shape)
    target_audio = nn.Variable(audio_shape)

    vmixture_audio = nn.Variable(audio_shape)
    vtarget_audio = nn.Variable(audio_shape)

    # create train graph
    pred_spec = unmix(mixture_audio, test=False)
    pred_spec.persistent = True

    target_spec = model.Spectrogram(*model.STFT(target_audio,
                                                n_fft=unmix.n_fft,
                                                n_hop=unmix.n_hop),
                                    mono=(unmix.nb_channels == 1))

    loss = F.mean(F.squared_error(pred_spec, target_spec), axis=1)

    # Create Solver.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Training loop.
    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    es = utils.EarlyStopping(patience=args.patience)

    for epoch in t:
        # TRAINING
        t.set_description("Training Epoch")
        b = tqdm.trange(0,
                        train_source._size // args.batch_size,
                        disable=args.quiet)
        losses = utils.AverageMeter()
        for batch in b:
            mixture_audio.d, target_audio.d = train_iter.next()
            b.set_description("Training Batch")
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.weight_decay(args.weight_decay)
            solver.update()
            losses.update(loss.d.copy().mean())
            b.set_postfix(train_loss=losses.avg)

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(valid_source._size):
            # Create new validation input variables for every batch
            vmixture_audio.d, vtarget_audio.d = valid_iter.next()
            # create validation graph
            vpred_spec = unmix(vmixture_audio, test=True)
            vpred_spec.persistent = True

            vtarget_spec = model.Spectrogram(*model.STFT(vtarget_audio,
                                                         n_fft=unmix.n_fft,
                                                         n_hop=unmix.n_hop),
                                             mono=(unmix.nb_channels == 1))
            vloss = F.mean(F.squared_error(vpred_spec, vtarget_spec), axis=1)

            vloss.forward(clear_buffer=True)
            vlosses.update(vloss.d.copy().mean())

        t.set_postfix(train_loss=losses.avg, val_loss=vlosses.avg)

        stop = es.step(vlosses.avg)
        is_best = vlosses.avg == es.best

        # save current model
        nn.save_parameters(
            os.path.join(args.output, 'checkpoint_%s.h5' % args.target))

        if is_best:
            best_epoch = epoch
            nn.save_parameters(os.path.join(args.output,
                                            '%s.h5' % args.target))

        if stop:
            print("Apply Early Stopping")
            break
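
The training loops in this collection depend on utils.EarlyStopping through es.step(), es.best and es.num_bad_epochs. A minimal sketch consistent with those call sites; the real helper may also support a minimum-delta threshold or a max mode:

class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0

    def step(self, metric):
        # returns True once the metric has failed to improve
        # for more than `patience` consecutive calls
        if self.best is None or metric < self.best:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience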
Example #9
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')
    # Loss parameters
    parser.add_argument('--loss',
                        type=str,
                        default="L2freq",
                        choices=[
                            'L2freq', 'L1freq', 'L2time', 'L1time', 'L2mask',
                            'L1mask', 'SISDRtime', 'SISDRfreq', 'MinSNRsdsdr',
                            'CrossEntropy', 'BinaryCrossEntropy', 'LogL2time',
                            'LogL1time', 'LogL2freq', 'LogL1freq', 'PSA',
                            'SNRPSA', 'Dissimilarity'
                        ],
                        help='kind of loss used during training')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="musdb",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')

    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output',
                        type=str,
                        default="open-unmix",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--reduce-samples',
                        type=int,
                        default=1,
                        help="reduce training samples by factor n")

    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 uses the full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=0,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # set random seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    num_train = len(train_dataset)
    indices = list(range(num_train))

    # shuffle train indices once and for all
    np.random.seed(args.seed)
    np.random.shuffle(indices)

    if args.reduce_samples > 1:
        split = int(np.floor(num_train / args.reduce_samples))
        train_idx = indices[:split]
    else:
        train_idx = indices
    sampler = SubsetRandomSampler(train_idx)
    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                sampler=sampler,
                                                **dataloader_kwargs)

    stats_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=1,
                                                sampler=sampler,
                                                **dataloader_kwargs)

    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, stats_sampler)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)
    # SNRPSA: de-compress the scaler (power of 2) to avoid an exploding
    # gradient from the uncompressed initial statistics
    if args.loss == 'SNRPSA':
        power = 2
    else:
        power = 1

    unmix = model.OpenUnmixSingle(
        n_fft=4096,
        n_hop=1024,
        input_is_spectrogram=False,
        hidden_size=args.hidden_size,
        nb_channels=args.nb_channels,
        sample_rate=train_dataset.sample_rate,
        nb_layers=3,
        input_mean=scaler_mean,
        input_scale=scaler_std,
        max_bin=max_bin,
        unidirectional=args.unidirectional,
        power=power,
    ).to(device)
    print('learning rate:', args.lr)
    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        print('LOADING MODEL')
        model_path = Path(args.model).expanduser()
        with open(Path(model_path,
                       str(len(args.targets)) + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, "model.chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # train for another epochs_trained
        t = tqdm.trange(results['epochs_trained'],
                        results['epochs_trained'] + args.epochs + 1,
                        disable=args.quiet)
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = results['best_epoch']
        es.best = results['best_loss']
        es.num_bad_epochs = results['num_bad_epochs']
        print('Model loaded')
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
        )

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path,
                       str(len(args.targets)) + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
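
get_statistics(args, stats_sampler) supplies the input_mean/input_scale values that initialise the model's input normalisation. An illustrative stand-in, assuming it aggregates per-frequency statistics of the mixture spectrograms over the training set:

import numpy as np

def get_statistics_sketch(spectrograms):
    # spectrograms: iterable of arrays with frequency bins on the last axis
    frames = np.concatenate(
        [np.reshape(s, (-1, s.shape[-1])) for s in spectrograms])
    # per-bin mean and a floored standard deviation, so that
    # dividing by input_scale never blows up on silent bins
    return frames.mean(axis=0), np.maximum(frames.std(axis=0), 1e-4)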
Example #10
def separate(audio,
             model_path='models/x-umx.h5',
             niter=1,
             softmask=False,
             alpha=1.0,
             residual_model=False):
    """
    Perform the separation on the audio input
    Parameters
    ----------
    audio: np.ndarray [shape=(nb_timesteps, nb_channels)]
        mixture audio
    model_path: str
        path to the model file, defaults to `models/x-umx.h5`
    niter: int
         Number of EM steps for refining initial estimates in a
         post-processing stage, defaults to 1.
    softmask: boolean
        if activated, then the initial estimates for the sources will
        be obtained through a ratio mask of the mixture STFT, and not
        by using the default behavior of reconstructing waveforms
        by using the mixture phase, defaults to False
    alpha: float
        changes the exponent to use for building ratio masks, defaults to 1.0
    residual_model: boolean
        computes a residual target, for custom separation scenarios
        when not all targets are available, defaults to False
    Returns
    -------
    estimates: `dict` [`str`, `np.ndarray`]
        dictionary of all estimates as performed by the separation model.
    """
    # convert numpy audio to NNabla Variable
    audio_nn = nn.Variable.from_numpy_array(audio.T[None, ...])
    source_names = []
    V = []
    max_bin = bandwidth_to_max_bin(sample_rate=44100,
                                   n_fft=4096,
                                   bandwidth=16000)

    sources = ['bass', 'drums', 'vocals', 'other']
    for j, target in enumerate(sources):
        if j == 0:
            unmix_target = model.OpenUnmix_CrossNet(max_bin=max_bin)
            unmix_target.is_predict = True
            nn.load_parameters(model_path)
            mix_spec, msk, _ = unmix_target(audio_nn, test=True)
        Vj = msk[Ellipsis, j * 2:j * 2 + 2, :] * mix_spec
        if softmask:
            # only exponentiate the model if we use softmask
            Vj = Vj**alpha
        # output is nb_frames, nb_samples, nb_channels, nb_bins
        V.append(Vj.d[:, 0, ...])  # remove sample dim
        source_names += [target]
    V = np.transpose(np.array(V), (1, 3, 2, 0))

    real, imag = model.STFT(audio_nn, center=True)

    # convert to complex numpy type
    X = real.d + imag.d * 1j
    X = X[0].transpose(2, 1, 0)

    if residual_model or len(sources) == 1:
        V = norbert.residual_model(V, X, alpha if softmask else 1)
        source_names += (['residual']
                         if len(sources) > 1 else ['accompaniment'])

    Y = norbert.wiener(V,
                       X.astype(np.complex128),
                       niter,
                       use_softmask=softmask)

    estimates = {}
    for j, name in enumerate(source_names):
        audio_hat = istft(Y[..., j].T,
                          n_fft=unmix_target.n_fft,
                          n_hopsize=unmix_target.n_hop)
        estimates[name] = audio_hat.T

    return estimates
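
A usage sketch for separate(). The file names are hypothetical, and soundfile is only one way to obtain the (nb_timesteps, nb_channels) array the function expects:

import soundfile as sf

audio, rate = sf.read('mixture.wav', always_2d=True)
assert rate == 44100  # the pretrained x-umx model assumes 44.1 kHz input

estimates = separate(audio, model_path='models/x-umx.h5', niter=1)
for name, waveform in estimates.items():
    sf.write('estimate_%s.wav' % name, waveform, rate)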
Example #11
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target',
                        type=str,
                        default='tabla',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_final/')
    parser.add_argument('--output',
                        type=str,
                        default="../new_models/model_tabla_mtl_ourmix_1",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    parser.add_argument(
        '--onset-model',
        type=str,
        help='Path to onset detection model weights',
        default=
        "/media/Sharedata/rohit/cnn-onset-det/models/apr4/saved_model_0_80mel-0-16000_1ch_44100.pt"
    )

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weighting of different loss components')
    parser.add_argument(
        '--finetune',
        type=int,
        default=0,
        help='If 1 (true), the optimiser state from the checkpoint model is '
        'reset (required for BCE fine-tuning); if 0, training resumes from '
        'where it left off')
    parser.add_argument('--onset-thresh',
                        type=float,
                        default=0.3,
                        help='Threshold above which an onset is said to occur')
    parser.add_argument(
        '--binarise',
        type=int,
        default=0,
        help='If 1 (true), the target novelty function is binarised; '
        'if 0 (false), it is left as-is')
    parser.add_argument(
        '--onset-trainable',
        type=int,
        default=0,
        help='If 1 (true), the onsetCNN is also trained during the '
        'fine-tuning stage; if 0 (false), its weights are kept fixed')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 uses the full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')

    parser.add_argument('--n-mels',
                        type=int,
                        default=80,
                        help='Number of bins in mel spectrogram')

    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # set random seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.autograd.set_detect_anomaly(True)

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model_mtl.OpenUnmix_mtl(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        n_fft=args.nfft,
        n_hop=args.nhop,
        max_bin=max_bin,
        sample_rate=train_dataset.sample_rate).to(device)

    # Read the trained onset detection network (the model the target
    # spectrogram is passed through)
    detect_onset = model.onsetCNN().to(device)
    detect_onset.load_state_dict(
        torch.load(args.onset_model, map_location=device))

    # freeze the onset detection network's weights
    for child in detect_onset.children():
        for param in child.parameters():
            param.requires_grad = False

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])

        if (args.finetune == 0):
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # train for another epochs_trained
            t = tqdm.trange(results['epochs_trained'],
                            results['epochs_trained'] + args.epochs + 1,
                            disable=args.quiet)
            print("PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = results['train_loss_history']
            train_mse_losses = results['train_mse_loss_history']
            train_bce_losses = results['train_bce_loss_history']
            valid_losses = results['valid_loss_history']
            valid_mse_losses = results['valid_mse_loss_history']
            valid_bce_losses = results['valid_bce_loss_history']
            train_times = results['train_time_history']
            best_epoch = results['best_epoch']

            es.best = results['best_loss']
            es.num_bad_epochs = results['num_bad_epochs']

        else:
            t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
            train_losses = []
            train_mse_losses = []
            train_bce_losses = []
            print("NOT PICKUP WHERE LEFT OFF", args.finetune)
            valid_losses = []
            valid_mse_losses = []
            valid_bce_losses = []

            train_times = []
            best_epoch = 0

    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        train_mse_losses = []
        train_bce_losses = []

        valid_losses = []
        valid_mse_losses = []
        valid_bce_losses = []

        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss, train_mse_loss, train_bce_loss = train(
            args,
            unmix,
            device,
            train_sampler,
            optimizer,
            detect_onset=detect_onset)

        valid_loss, valid_mse_loss, valid_bce_loss = valid(
            args, unmix, device, valid_sampler, detect_onset=detect_onset)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        train_mse_losses.append(train_mse_loss)
        train_bce_losses.append(train_bce_loss)

        valid_losses.append(valid_loss)
        valid_mse_losses.append(valid_mse_loss)
        valid_bce_losses.append(valid_bce_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'onset_state_dict': detect_onset.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'train_mse_loss_history': train_mse_losses,
            'train_bce_loss_history': train_bce_losses,
            'valid_loss_history': valid_losses,
            'valid_mse_loss_history': valid_mse_losses,
            'valid_bce_loss_history': valid_bce_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break


    print("TRAINING DONE!!")

    # Save the loss curves (one figure per tracked quantity).
    loss_curves = [
        ("Training loss", train_losses, "Training", "train_plot.pdf"),
        ("Validation loss", valid_losses, "Validation", "val_plot.pdf"),
        ("Training BCE loss", train_bce_losses, "Training", "train_bce_plot.pdf"),
        ("Validation BCE loss", valid_bce_losses, "Validation", "val_bce_plot.pdf"),
        ("Training MSE loss", train_mse_losses, "Training", "train_mse_plot.pdf"),
        ("Validation MSE loss", valid_mse_losses, "Validation", "val_mse_plot.pdf"),
    ]
    for title, history, label, filename in loss_curves:
        plt.figure()
        plt.title(title)
        plt.plot(history, label=label)
        plt.xlabel("Epochs")  # one entry is appended per epoch
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(Path(target_path, filename))
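
The trainer above depends on utils.EarlyStopping, whose definition is not shown here. A minimal sketch consistent with how it is used (es.step(loss) returns a stop flag; es.best and es.num_bad_epochs are read, and restored when resuming) might look like the following; the actual implementation may differ:

class EarlyStopping:
    """Stop training once the monitored validation loss stops improving."""

    def __init__(self, patience=140):
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, loss):
        # A new best value resets the bad-epoch counter; otherwise count up
        # and signal a stop once patience is exceeded.
        if loss < self.best:
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience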
Example #12
def separate(audio, args):
    """
    Perform source separation on the input mixture audio.
    Parameters
    ----------
    audio: np.ndarray [shape=(nb_timesteps, nb_channels)]
        mixture audio
    args : argparse.Namespace
        parsed arguments for OpenUnmix_CrossNet (X-UMX) / OpenUnmix (UMX) inference

    Returns
    -------
    estimates: `dict` [`str`, `np.ndarray`]
        dictionary of all estimates as performed by the separation model.
    """

    # convert numpy audio to NNabla Variable
    audio_nn = nn.Variable.from_numpy_array(audio.T[None, ...])
    source_names = []
    V = []
    max_bin = bandwidth_to_max_bin(sample_rate=44100,
                                   n_fft=4096,
                                   bandwidth=16000)

    if not args.umx_infer:
        # Run X-UMX Inference
        nn.load_parameters(args.model)
        for j, target in enumerate(args.targets):
            if j == 0:
                unmix_target = model.OpenUnmix_CrossNet(max_bin=max_bin,
                                                        is_predict=True)
                mix_spec, msk, _ = unmix_target(audio_nn, test=True)
                # Network output is (nb_frames, nb_samples, nb_channels, nb_bins)
            V.append(
                (msk[Ellipsis, j * 2:j * 2 + 2, :] * mix_spec).d[:, 0, ...])
            source_names += [target]
    else:
        # Run UMX Inference
        for j, target in enumerate(args.targets):
            with nn.parameter_scope(target):
                unmix_target = model.OpenUnmix(max_bin=max_bin)
                nn.load_parameters(f"{os.path.join(args.model, target)}.h5")
                # Network output is (nb_frames, nb_samples, nb_channels, nb_bins)
                V.append(unmix_target(audio_nn, test=True).d[:, 0, ...])
            source_names += [target]

    V = np.transpose(np.array(V), (1, 3, 2, 0))
    if args.softmask:
        # only exponentiate the model if we use softmask
        V = V**args.alpha

    real, imag = model.get_stft(audio_nn, center=True)

    # convert to complex numpy type
    X = real.d + imag.d * 1j
    X = X[0].transpose(2, 1, 0)

    if args.residual_model or len(args.targets) == 1:
        V = norbert.residual_model(V, X, args.alpha if args.softmask else 1)
        source_names += (['residual']
                         if len(args.targets) > 1 else ['accompaniment'])

    Y = norbert.wiener(V,
                       X.astype(np.complex128),
                       args.niter,
                       use_softmask=args.softmask)

    estimates = {}
    for j, name in enumerate(source_names):
        audio_hat = istft(Y[..., j].T,
                          n_fft=unmix_target.n_fft,
                          n_hopsize=unmix_target.n_hop)
        estimates[name] = audio_hat.T

    return estimates
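
A minimal usage sketch for separate(), assuming a 44.1 kHz stereo mixture on disk (the function hard-codes sample_rate=44100 for max_bin) and using types.SimpleNamespace in place of the real parsed arguments; the file names and flag values below are illustrative only:

import soundfile as sf
from types import SimpleNamespace

# Load the mixture as (nb_timesteps, nb_channels), as separate() expects.
audio, rate = sf.read('mixture.wav', always_2d=True)

args = SimpleNamespace(
    umx_infer=False,                # False -> run X-UMX inference
    model='x-umx.h5',               # trained NNabla parameter file (illustrative path)
    targets=['bass', 'drums', 'vocals', 'other'],
    softmask=False, alpha=1.0,      # soft-mask exponent (unused when softmask=False)
    residual_model=False, niter=1,  # Wiener-filter EM iterations
)

estimates = separate(audio, args)
for name, est in estimates.items():
    sf.write(name + '.wav', est, rate)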
Example #13
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update NNabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Set up monitors for logging.
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB18.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(
        train_source,
        args.batch_size,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    valid_iter = data_iterator(
        valid_source,
        1,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate max_iter per GPU device, and scale the learning rate and
    # weight decay with the number of GPU devices for multi-GPU training.
    default_batch_size = 16
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor
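    # Example: 4 GPUs at the default batch_size of 16 give
    # train_scale_factor = (4 * 16) / 16 = 4, so lr and weight_decay are
    # scaled 4x (the linear scaling rule for larger effective batches).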

    # Calculate the statistics (mean and standard deviation) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    # Get X-UMX/UMX computation graph and variables as namedtuple
    model = get_model(args, scaler_mean, scaler_std, max_bin=max_bin)

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # AverageMeter for mean loss calculation over the epoch
    losses = utils.AverageMeter()

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            model.mixture_audio.d, model.target_audio.d = train_iter.next()
            solver.zero_grad()
            model.loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                model.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_callback)
            else:
                model.loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(model.loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        losses.reset()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
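            # Validate long tracks in chunks of args.valid_dur seconds,
            # averaging the loss over chunks to bound memory usage.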
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                model.vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                model.vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                model.vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += model.vloss.data
                if (x[Ellipsis, sp:sp + dur].shape[-1] < dur
                        or x.shape[-1] == cnt * dur):
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            losses.update(loss_tmp.data.copy(), 1)
        validation_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                best_epoch = epoch
                # save best model
                if args.umx_train:
                    nn.save_parameters(os.path.join(args.output,
                                                    'best_umx.h5'))
                else:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_xumx.h5'))

        if args.umx_train:
            # Early stopping for UMX after args.patience (default: 140)
            # epochs without improvement
            if stop:
                print("Apply Early Stopping")
                break
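
Since train() parses its own arguments via get_train_args(), a minimal entry point for this script (assuming it is run directly) is simply:

if __name__ == '__main__':
    train()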
Example #14
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # experiment tag; determines the output folder under trained_models,
    # the tensorboard run name, etc.
    parser.add_argument('--tag', type=str)

    # allow to pass a comment about the experiment
    parser.add_argument('--comment', type=str, help='comment about the experiment')

    args, _ = parser.parse_known_args()

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default="musdb",
                        choices=[
                            'musdb_lyrics', 'timit_music', 'blended', 'nus', 'nus_train'
                        ],
                        help='Name of the dataset.')

    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output', type=str, default="trained_models/{}/".format(args.tag),
                        help='provide output path base folder name')

    parser.add_argument('--wst-model', type=str, help='Path to checkpoint folder for warmstart')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')

    parser.add_argument('--alignment-from', type=str, default=None)
    parser.add_argument('--fake-alignment', action='store_true', default=False)


    # Model Parameters
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=0,
                        help='Number of workers for dataloader.')
    parser.add_argument('--nb-audio-encoder-layers', type=int, default=2)
    parser.add_argument('--nb-layers', type=int, default=3)
    # name of the model class in model.py that should be used
    parser.add_argument('--architecture', type=str)
    # select attention type if applicable for selected model
    parser.add_argument('--attention', type=str)

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {'num_workers': args.nb_workers, 'pin_memory': True} if use_cuda else {}

    writer = SummaryWriter(logdir=os.path.join('tensorboard', args.tag))

    # set random seeds for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)


    train_sampler = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=data.collate_fn, drop_last=True,
        **dataloader_kwargs
    )
    valid_sampler = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1, collate_fn=data.collate_fn, **dataloader_kwargs
    )

    if args.wst_model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(
        valid_dataset.sample_rate, args.nfft, args.bandwidth
    )

    train_args_dict = vars(args)
    train_args_dict['max_bin'] = int(max_bin)  # added to config
    train_args_dict['vocabulary_size'] = valid_dataset.vocabulary_size  # added to config
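    # Note: vars(args) returns the live namespace __dict__, so these two keys
    # are also persisted later via params['args'].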

    train_params_dict = copy.deepcopy(vars(args))  # independent copy; later additions do not affect args

    # add to parameters for model loading but not to config file
    train_params_dict['scaler_mean'] = scaler_mean
    train_params_dict['scaler_std'] = scaler_std

    model_class = model_utls.ModelLoader.get_model(args.architecture)
    model_to_train = model_class.from_config(train_params_dict)
    model_to_train.to(device)

    optimizer = torch.optim.Adam(
        model_to_train.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10
    )

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.wst_model:
        model_path = Path(os.path.join('trained_models', args.wst_model)).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)


        model_to_train.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # resume: train for another args.epochs epochs
        t = tqdm.trange(
            results['epochs_trained'],
            results['epochs_trained'] + args.epochs + 1,
            disable=args.quiet
        )
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = results.get('best_epoch', 0)  # restore if it was saved

    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()

        train_loss = train(args, model_to_train, device, train_sampler, optimizer)
        valid_loss = valid(args, model_to_train, device, valid_sampler)

        writer.add_scalar("Training_cost", train_loss, epoch)
        writer.add_scalar("Validation_cost", valid_loss, epoch)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(
            train_loss=train_loss, val_loss=valid_loss
        )

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_train.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target
        )

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
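
As with the other trainers, main() parses its own arguments, so the script can close with a standard entry point; the invocation shown in the comment is illustrative (paths and the architecture name are placeholders):

# python train.py --tag my_experiment --dataset musdb_lyrics \
#     --root /data/musdb_lyrics --target vocals --architecture MyModel
if __name__ == '__main__':
    main()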