Esempio n. 1
0
def test_stft(audio, nb_channels, nfft, hop):
    """Check that the model's STFT can be inverted back to the input audio.

    The waveform is transformed with the ``stft`` module of a freshly built
    ``OpenUnmix`` model, turned into a complex array and inverted with the
    reference ``test.istft``. The round trip must match the input with a
    root-mean-square error below ``1e-6``.

    ``nfft`` and ``hop`` are accepted but not used by this check.
    """
    separator = model.OpenUnmix(nb_channels=nb_channels)
    separator.stft.center = True

    # The last axis of the STFT output carries the real part at index 0 and
    # the imaginary part at index 1; merge them into one complex array.
    stft_out = separator.stft(audio).detach().numpy()
    real_part = stft_out[..., 0]
    imag_part = stft_out[..., 1]
    spectrum = real_part + imag_part * 1j

    reconstruction = test.istft(spectrum)
    residual = audio.detach().numpy() - reconstruction
    rmse = np.sqrt(np.mean(residual ** 2))
    assert rmse < 1e-6
Esempio n. 2
0
def load_model(target, model_name='umxhq', device='cpu'):
    """Load the separation model for ``target``.

    ``model_name`` is first interpreted as a filesystem path. If that path
    does not exist, the model is fetched through ``torch.hub`` from
    ``sigsep/open-unmix-pytorch`` using ``model_name`` as the entry point.
    Otherwise the path is treated as a local model directory containing
    ``<target>.json`` (training parameters) and a weight file matching
    ``<target>*.pth`` (i.e. ``<target>.pth`` or ``<target>-<sha256>.pth`` as
    used on torch hub).

    Args:
        target: name of the target source; selects the JSON/weight files.
        model_name: local model directory or torch hub entry point name.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The loaded model. For a local directory it is an ``OpenUnmix``
        instance in eval mode with ``stft.center`` enabled.

    Raises:
        NameError: ``torch.hub.load`` raised ``AttributeError`` (the
            requested model is not provided by the hub repository).
        FileNotFoundError: the local directory lacks ``<target>.json`` or
            any ``<target>*.pth`` weight file.
    """
    model_path = Path(model_name).expanduser()
    if not model_path.exists():
        # model path does not exist, use hubconf model
        try:
            # stderr is captured (and dropped) to hide the download
            # progress bar
            with redirect_stderr(io.StringIO()):
                return torch.hub.load(
                    'sigsep/open-unmix-pytorch',
                    model_name,
                    target=target,
                    device=device,
                    pretrained=True
                )
        except AttributeError as exc:
            raise NameError('Model does not exist on torchhub') from exc

    # model_name is a path to a local model directory: load from disk
    with open(Path(model_path, target + '.json'), 'r') as stream:
        results = json.load(stream)

    # Default of None avoids leaking a bare StopIteration to the caller when
    # no weight file is present.
    target_model_path = next(model_path.glob("%s*.pth" % target), None)
    if target_model_path is None:
        raise FileNotFoundError(
            "No '%s*.pth' weights found in %s" % (target, model_path)
        )

    state = torch.load(
        target_model_path,
        map_location=device
    )

    training_args = results['args']
    max_bin = utils.bandwidth_to_max_bin(
        state['sample_rate'],
        training_args['nfft'],
        training_args['bandwidth']
    )

    unmix = model.OpenUnmix(
        n_fft=training_args['nfft'],
        n_hop=training_args['nhop'],
        nb_channels=training_args['nb_channels'],
        hidden_size=training_args['hidden_size'],
        max_bin=max_bin
    )

    unmix.load_state_dict(state)
    unmix.stft.center = True
    unmix.eval()
    unmix.to(device)
    return unmix
Esempio n. 3
0
def test_shape(audio, nb_channels, unidirectional, nb_layers, hidden_size,
               n_fft, n_hop):
    """Check that the model output has the same shape as its spectrogram input.

    A model configured with ``input_is_spectrogram=True`` is fed the
    spectrogram produced by chaining its own ``stft`` and ``spec`` modules on
    ``audio``; the result must have exactly the shape of that spectrogram.
    """
    model_kwargs = dict(
        n_fft=n_fft,
        n_hop=n_hop,
        nb_channels=nb_channels,
        input_is_spectrogram=True,
        unidirectional=unidirectional,
        nb_layers=nb_layers,
        hidden_size=hidden_size,
    )
    separator = model.OpenUnmix(**model_kwargs)
    separator.eval()

    # Reuse the model's own front end to build the spectrogram input.
    to_spectrogram = torch.nn.Sequential(separator.stft, separator.spec)
    mix_spec = to_spectrogram(audio)
    estimate = separator(mix_spec)

    assert mix_spec.shape == estimate.shape
Esempio n. 4
0
def load_model(target, model_name='umxhq', device='cpu', chkpnt=False):
    """Load a locally stored separation model for ``target``.

    ``model_name`` must be a directory containing ``<target>.json``
    (training parameters) and either a weight file matching
    ``<target>*.pth`` or, with ``chkpnt=True``, a training checkpoint
    matching ``<target>*.chkpnt`` whose ``'state_dict'`` entry holds the
    weights.

    Args:
        target: name of the target source; selects the JSON/weight files.
        model_name: path of the local model directory.
        device: device the weights are mapped to and the model is moved to.
        chkpnt: load weights from a ``.chkpnt`` checkpoint instead of a
            ``.pth`` file.

    Returns:
        An ``OpenUnmix`` instance in eval mode with ``stft.center`` enabled,
        or ``None`` (after printing a message) when ``model_name`` does not
        exist.

    Raises:
        FileNotFoundError: the directory lacks ``<target>.json`` or a
            matching weight/checkpoint file.
    """
    model_path = Path(model_name).expanduser()
    if not model_path.exists():
        print("Can't find model! Please check model_path.")
        return None

    # load model from disk
    with open(Path(model_path, target + '.json'), 'r') as stream:
        results = json.load(stream)

    pattern = ("%s*.chkpnt" if chkpnt else "%s*.pth") % target
    # Default of None avoids leaking a bare StopIteration to the caller when
    # no matching file is present.
    target_model_path = next(model_path.glob(pattern), None)
    if target_model_path is None:
        raise FileNotFoundError(
            "No '%s' file found in %s" % (pattern, model_path))

    # map_location is applied for checkpoints as well, so that a checkpoint
    # saved on GPU can be loaded on a CPU-only machine.
    state = torch.load(target_model_path, map_location=device)
    if chkpnt:
        # a checkpoint wraps the weights together with other training state
        state = state['state_dict']

    training_args = results['args']
    max_bin = utils.bandwidth_to_max_bin(state['sample_rate'],
                                         training_args['nfft'],
                                         training_args['bandwidth'])

    unmix = model.OpenUnmix(n_fft=training_args['nfft'],
                            n_hop=training_args['nhop'],
                            nb_channels=training_args['nb_channels'],
                            hidden_size=training_args['hidden_size'],
                            max_bin=max_bin)

    unmix.load_state_dict(state)
    unmix.stft.center = True
    unmix.eval()
    unmix.to(device)
    return unmix
Esempio n. 5
0
def main():
    """Command-line entry point: train (or fine-tune) an Open-Unmix model.

    Parses the command-line options, builds the datasets and data loaders,
    creates the model, then runs the epoch loop with plateau learning-rate
    scheduling and early stopping. After every epoch the loss curves are
    shown, a checkpoint is saved through ``utils.save_checkpoint`` and the
    training parameters/history are written to ``<output>/<target>.json``.

    When ``--model`` is set (the default is ``'umxhq'``), the pretrained
    ``umxhq`` model is loaded from torch hub and fine-tuned; otherwise a new
    model is built from the training set statistics.
    """
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target',
                        type=str,
                        default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_new/')
    parser.add_argument('--output',
                        type=str,
                        default="../out_unmix/model_new_data_aug_tl",
                        help='provide output path base folder name')
    parser.add_argument('--model',
                        type=str,
                        help='Path to checkpoint folder',
                        default='umxhq')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate, defaults to 1e-4')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0000000001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds, '
                        'value of <=0.0 will use full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    # Unknown options are tolerated here; the parser is handed on to
    # data.load_datasets below, which returns an updated ``args``.
    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    # Short hash of the current git commit, stored with the results.
    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # Seed the random number generators.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        # Fine-tune the pretrained hub model. stderr is captured (and
        # dropped) to hide the download progress bar.
        # NOTE(review): the hub entry point is hard-coded to 'umxhq'; the
        # value of --model only switches this branch on - confirm intended.
        with redirect_stderr(io.StringIO()):
            unmix = torch.hub.load('sigsep/open-unmix-pytorch',
                                   'umxhq',
                                   target=args.target,
                                   device=device,
                                   pretrained=True)
    else:
        # Train from scratch: the input scaler is initialised from the
        # training set statistics.
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

        max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate,
                                             args.nfft, args.bandwidth)

        unmix = model.OpenUnmix(
            input_mean=scaler_mean,
            input_scale=scaler_std,
            nb_channels=args.nb_channels,
            hidden_size=args.hidden_size,
            n_fft=args.nfft,
            n_hop=args.nhop,
            max_bin=max_bin,
            sample_rate=train_dataset.sample_rate).to(device)

    # The optimizer must be created from the parameters of the model that is
    # actually trained, i.e. only after ``unmix`` has its final value.
    # Building it earlier left the pretrained model without any updates.
    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    train_losses = []
    valid_losses = []
    train_times = []
    best_epoch = 0

    # Imported lazily so matplotlib is only needed once training starts.
    from matplotlib import pyplot as plt

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)
        is_best = valid_loss == es.best

        # Show the loss curves, one figure per curve. Each figure is closed
        # afterwards so figures do not pile up over the epochs.
        curves = (
            ("Training loss", "Training", train_losses),
            ("Validation loss", "Validation", valid_losses),
        )
        for position, (title, label, history) in enumerate(curves, start=1):
            fig = plt.figure(figsize=(16, 12))
            plt.subplot(2, 2, position)
            plt.title(title)
            plt.plot(history, label=label)
            plt.xlabel("Iterations")
            plt.ylabel("Loss")
            plt.legend()
            plt.show()
            plt.close(fig)

        if is_best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=is_best,
            path=target_path,
            target=args.target)

        # Record the epoch duration before writing the results so that the
        # JSON file contains the timing of the current epoch as well.
        train_times.append(time.time() - end)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        if stop:
            print("Apply Early Stopping")
            break
Esempio n. 6
0
def train():
    """Train an Open-Unmix model with NNabla.

    Reads the configuration from ``get_args()``, builds the data iterators
    and the model, constructs the training graph (mean squared error between
    the predicted and the target spectrogram) and runs the epoch loop with
    Adam, weight decay and early stopping. After every epoch the current
    parameters are saved to ``checkpoint_<target>.h5`` in ``args.output``;
    when the validation loss is the best so far they are also saved to
    ``<target>.h5``.
    """
    parser, args = get_args()

    # Select the extension context and make it the default.
    nn.set_default_context(
        get_extension_context(args.context, device_id=args.device_id))

    # Data sources and their batch iterators (caching disabled).
    train_source, valid_source, args = data.load_datasources(
        parser, args, rng=RandomState(42))

    iterator_options = dict(with_memory_cache=False, with_file_cache=False)
    train_iter = data_iterator(train_source, args.batch_size,
                               RandomState(args.seed), **iterator_options)
    valid_iter = data_iterator(valid_source, args.batch_size,
                               RandomState(args.seed), **iterator_options)

    scaler_mean, scaler_std = get_statistics(args, train_source)

    sample_rate = train_source.sample_rate
    max_bin = utils.bandwidth_to_max_bin(sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=sample_rate)

    def spectrogram_of(waveform):
        # Spectrogram of a waveform variable, using the model's STFT setup.
        return model.Spectrogram(*model.STFT(waveform,
                                             n_fft=unmix.n_fft,
                                             n_hop=unmix.n_hop),
                                 mono=(unmix.nb_channels == 1))

    # Input variables: one batch of items shaped like the first training
    # example's first element.
    audio_shape = [args.batch_size] + list(train_source._get_data(0)[0].shape)
    mixture_audio = nn.Variable(audio_shape)
    target_audio = nn.Variable(audio_shape)

    vmixture_audio = nn.Variable(audio_shape)
    vtarget_audio = nn.Variable(audio_shape)

    # Training graph.
    pred_spec = unmix(mixture_audio, test=False)
    pred_spec.persistent = True
    target_spec = spectrogram_of(target_audio)
    loss = F.mean(F.squared_error(pred_spec, target_spec), axis=1)

    # Solver over all registered parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    epoch_bar = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    es = utils.EarlyStopping(patience=args.patience)

    checkpoint_file = os.path.join(args.output,
                                   'checkpoint_%s.h5' % args.target)
    best_file = os.path.join(args.output, '%s.h5' % args.target)

    for epoch in epoch_bar:
        # ---- training ----
        epoch_bar.set_description("Training Epoch")
        batch_bar = tqdm.trange(0,
                                train_source._size // args.batch_size,
                                disable=args.quiet)
        train_meter = utils.AverageMeter()
        for _ in batch_bar:
            mixture_audio.d, target_audio.d = train_iter.next()
            batch_bar.set_description("Training Batch")
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.weight_decay(args.weight_decay)
            solver.update()
            train_meter.update(loss.d.copy().mean())
            batch_bar.set_postfix(train_loss=train_meter.avg)

        # ---- validation ----
        valid_meter = utils.AverageMeter()
        for _ in range(valid_source._size):
            vmixture_audio.d, vtarget_audio.d = valid_iter.next()

            # The validation graph is rebuilt for every batch.
            vpred_spec = unmix(vmixture_audio, test=True)
            vpred_spec.persistent = True
            vtarget_spec = spectrogram_of(vtarget_audio)
            vloss = F.mean(F.squared_error(vpred_spec, vtarget_spec), axis=1)

            vloss.forward(clear_buffer=True)
            valid_meter.update(vloss.d.copy().mean())

        epoch_bar.set_postfix(train_loss=train_meter.avg,
                              val_loss=valid_meter.avg)

        stop = es.step(valid_meter.avg)
        is_best = valid_meter.avg == es.best

        # Always keep the latest parameters ...
        nn.save_parameters(checkpoint_file)

        # ... and a separate copy of the best ones.
        if is_best:
            best_epoch = epoch
            nn.save_parameters(best_file)

        if stop:
            print("Apply Early Stopping")
            break
Esempio n. 7
0
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    # =============================================================================
    #     parser.add_argument('--target', type=str, default='vocals',
    #                         help='target source (will be passed to the dataset)')
    #
    # =============================================================================
    parser.add_argument('--target',
                        type=str,
                        default='tabla',
                        help='target source (will be passed to the dataset)')

    # Dataset paramaters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_final/')
    parser.add_argument('--output',
                        type=str,
                        default="../new_models/model_tabla_mse_pretrain1",
                        help='provide output path base folder name')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='../out_unmix/model_new_data_aug_tabla_mse_pretrain1')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default="../out_unmix/model_new_data_aug_tabla_mse_pretrain8" )
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='../out_unmix/model_new_data_aug_tabla_bce_finetune2')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='umxhq')
    parser.add_argument(
        '--onset-model',
        type=str,
        help='Path to onset detection model weights',
        default=
        "/media/Sharedata/rohit/cnn-onset-det/models/apr4/saved_model_0_80mel-0-16000_1ch_44100.pt"
    )

    # Trainig Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='maximum number of epochs to train (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weighting of different loss components')
    parser.add_argument(
        '--finetune',
        type=int,
        default=0,
        help=
        'If true(1), then optimiser states from checkpoint model are reset (required for bce finetuning), false if aim is to resume training from where it was left off'
    )
    parser.add_argument('--onset-thresh',
                        type=float,
                        default=0.3,
                        help='Threshold above which onset is said to occur')
    parser.add_argument(
        '--binarise',
        type=int,
        default=0,
        help=
        'If=1(true), then target novelty function is made binary, if=0(false), then left as it is'
    )
    parser.add_argument(
        '--onset-trainable',
        type=int,
        default=0,
        help=
        'If=1(true), then onsetCNN will also get trained in finetuning stage, if=0(false) then kept fixed'
    )

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds'
                        'value of <=0.0 will use full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')

    # =============================================================================
    #     parser.add_argument('--nfft', type=int, default=2048,
    #                         help='STFT fft size and window size')
    #     parser.add_argument('--nhop', type=int, default=512,
    #                         help='STFT hop size')
    # =============================================================================

    parser.add_argument('--n-mels',
                        type=int,
                        default=80,
                        help='Number of bins in mel spectrogram')

    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in herz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # use jpg or npy
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.autograd.set_detect_anomaly(True)

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_dataset.sample_rate).to(device)

    #Read trained onset detection network (Model through which target spectrogra is passed)
    detect_onset = model.onsetCNN().to(device)
    detect_onset.load_state_dict(
        torch.load(args.onset_model, map_location='cuda:0'))

    #Model through which separated output is passed
    detect_onset_training = model.onsetCNN().to(device)
    detect_onset_training.load_state_dict(
        torch.load(args.onset_model, map_location='cuda:0'))

    for child in detect_onset.children():
        for param in child.parameters():
            param.requires_grad = False

    #If onset trainable is false, then we want to keep the weights of this model fixed
    if (args.onset_trainable == 0):
        for child in detect_onset_training.children():
            for param in child.parameters():
                param.requires_grad = False

    #FOR CHECKING, REMOVE LATER
    for child in detect_onset_training.children():
        for param in child.parameters():
            print(param.requires_grad)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])

        #Only when onse is trainable and when that finetuning is being resumed from a point where it is left off, then read the onset state_dict
        if ((args.onset_trainable == 1) and (args.finetune == 0)):
            detect_onset_training.load_state_dict(
                checkpoint['onset_state_dict'])
            print("Reading saved onset model")
        else:
            print("Not reading saved onset model")

        if (args.finetune == 0):
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # train for another epochs_trained
            t = tqdm.trange(results['epochs_trained'],
                            results['epochs_trained'] + args.epochs + 1,
                            disable=args.quiet)
            print("PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = results['train_loss_history']
            train_mse_losses = results['train_mse_loss_history']
            train_bce_losses = results['train_bce_loss_history']
            valid_losses = results['valid_loss_history']
            valid_mse_losses = results['valid_mse_loss_history']
            valid_bce_losses = results['valid_bce_loss_history']
            train_times = results['train_time_history']
            best_epoch = results['best_epoch']

            es.best = results['best_loss']
            es.num_bad_epochs = results['num_bad_epochs']

        else:
            t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
            train_losses = []
            train_mse_losses = []
            train_bce_losses = []
            print("NOTTTTPICKUP WHERE LEFT OFF", args.finetune)
            valid_losses = []
            valid_mse_losses = []
            valid_bce_losses = []

            train_times = []
            best_epoch = 0

        #es.best = results['best_loss']
        #es.num_bad_epochs = results['num_bad_epochs']
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        train_mse_losses = []
        train_bce_losses = []

        valid_losses = []
        valid_mse_losses = []
        valid_bce_losses = []

        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss, train_mse_loss, train_bce_loss = train(
            args,
            unmix,
            device,
            train_sampler,
            optimizer,
            detect_onset=detect_onset,
            detect_onset_training=detect_onset_training)
        #train_mse_loss = train(args, unmix, device, train_sampler, optimizer, detect_onset=detect_onset)[1]
        #train_bce_loss = train(args, unmix, device, train_sampler, optimizer, detect_onset=detect_onset)[2]

        valid_loss, valid_mse_loss, valid_bce_loss = valid(
            args,
            unmix,
            device,
            valid_sampler,
            detect_onset=detect_onset,
            detect_onset_training=detect_onset_training)
        #valid_mse_loss = valid(args, unmix, device, valid_sampler, detect_onset=detect_onset)[1]
        #valid_bce_loss = valid(args, unmix, device, valid_sampler, detect_onset=detect_onset)[2]

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        train_mse_losses.append(train_mse_loss)
        train_bce_losses.append(train_bce_loss)

        valid_losses.append(valid_loss)
        valid_mse_losses.append(valid_mse_loss)
        valid_bce_losses.append(valid_bce_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        #from matplotlib import pyplot as plt

        # =============================================================================
        #         plt.figure(figsize=(16,12))
        #         plt.subplot(2, 2, 1)
        #         plt.title("Training loss")
        #         plt.plot(train_losses,label="Training")
        #         plt.xlabel("Iterations")
        #         plt.ylabel("Loss")
        #         plt.legend()
        #         plt.show()
        #         #plt.savefig(Path(target_path, "train_plot.pdf"))
        #
        #         plt.figure(figsize=(16,12))
        #         plt.subplot(2, 2, 2)
        #         plt.title("Validation loss")
        #         plt.plot(valid_losses,label="Validation")
        #         plt.xlabel("Iterations")
        #         plt.ylabel("Loss")
        #         plt.legend()
        #         plt.show()
        #         #plt.savefig(Path(target_path, "val_plot.pdf"))
        # =============================================================================

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'onset_state_dict': detect_onset_training.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'train_mse_loss_history': train_mse_losses,
            'train_bce_loss_history': train_bce_losses,
            'valid_loss_history': valid_losses,
            'valid_mse_loss_history': valid_mse_losses,
            'valid_bce_loss_history': valid_bce_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break


# =============================================================================
#     plt.figure(figsize=(16,12))
#     plt.subplot(2, 2, 1)
#     plt.title("Training loss")
#     #plt.plot(train_losses,label="Training")
#     plt.plot(train_losses,label="Training")
#     plt.xlabel("Iterations")
#     plt.ylabel("Loss")
#     plt.legend()
#     #plt.show()
#
#     plt.figure(figsize=(16,12))
#     plt.subplot(2, 2, 2)
#     plt.title("Validation loss")
#     plt.plot(valid_losses,label="Validation")
#     plt.xlabel("Iterations")
#     plt.ylabel("Loss")
#     plt.legend()
#     plt.show()
#     plt.savefig(Path(target_path, "train_val_plot.pdf"))
#     #plt.savefig(Path(target_path, "train_plot.pdf"))
# =============================================================================

    print("TRAINING DONE!!")

    plt.figure()
    plt.title("Training loss")
    plt.plot(train_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_plot.pdf"))

    plt.figure()
    plt.title("Validation loss")
    plt.plot(valid_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_plot.pdf"))

    plt.figure()
    plt.title("Training BCE loss")
    plt.plot(train_bce_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_bce_plot.pdf"))

    plt.figure()
    plt.title("Validation BCE loss")
    plt.plot(valid_bce_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_bce_plot.pdf"))

    plt.figure()
    plt.title("Training MSE loss")
    plt.plot(train_mse_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_mse_plot.pdf"))

    plt.figure()
    plt.title("Validation MSE loss")
    plt.plot(valid_mse_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_mse_plot.pdf"))
Esempio n. 8
0
def separate(audio, args):
    """
    Separate a mixture into its sources with X-UMX or UMX.

    The selected network produces one estimate per target; the estimates
    are then passed together with the mixture STFT to norbert's Wiener
    filter, and every filtered source is inverted back to audio.

    Parameters
    ----------
    audio: np.ndarray [shape=(nb_timesteps, nb_channels)]
        mixture audio
    args : ArgumentParser
        inference options; the fields read here are `umx_infer`, `model`,
        `targets`, `softmask`, `alpha`, `residual_model` and `niter`

    Returns
    -------
    estimates: `dict` [`str`, `np.ndarray`]
        one entry per name in `args.targets`, plus 'residual' (several
        targets) or 'accompaniment' (single target) when the residual
        model is applied.
    """

    # wrap the transposed mixture (with a new leading axis) in an
    # NNabla Variable
    mixture = nn.Variable.from_numpy_array(audio.T[None, ...])
    max_bin = bandwidth_to_max_bin(sample_rate=44100,
                                   n_fft=4096,
                                   bandwidth=16000)
    source_names = []
    target_specs = []

    if args.umx_infer:
        # UMX: one independent network per target, each with its own
        # parameter scope and its own `<target>.h5` parameter file
        for target in args.targets:
            with nn.parameter_scope(target):
                unmix_target = model.OpenUnmix(max_bin=max_bin)
                nn.load_parameters(f"{os.path.join(args.model, target)}.h5")
                # Network output is
                # (nb_frames, nb_samples, nb_channels, nb_bins)
                prediction = unmix_target(mixture, test=True)
                target_specs.append(prediction.d[:, 0, ...])
            source_names.append(target)
    else:
        # X-UMX: a single network, evaluated once, yields the masks of
        # all targets
        nn.load_parameters(args.model)
        for j, target in enumerate(args.targets):
            if j == 0:
                unmix_target = model.OpenUnmix_CrossNet(max_bin=max_bin,
                                                        is_predict=True)
                # Network output is
                # (nb_frames, nb_samples, nb_channels, nb_bins)
                mix_spec, msk, _ = unmix_target(mixture, test=True)
            # target j owns entries [2j, 2j + 2) of the second-to-last
            # mask axis
            first = 2 * j
            masked = msk[..., first:first + 2, :] * mix_spec
            target_specs.append(masked.d[:, 0, ...])
            source_names.append(target)

    # stack the targets and move them to the last axis
    V = np.array(target_specs).transpose(1, 3, 2, 0)
    if args.softmask:
        # only exponentiate the model if we use softmask
        V = V ** args.alpha

    real, imag = model.get_stft(mixture, center=True)

    # complex numpy mixture STFT, first sample only, axes reversed
    X = (real.d + imag.d * 1j)[0].transpose(2, 1, 0)

    single_target = len(args.targets) == 1
    if args.residual_model or single_target:
        exponent = args.alpha if args.softmask else 1
        V = norbert.residual_model(V, X, exponent)
        if len(args.targets) > 1:
            source_names.append('residual')
        else:
            source_names.append('accompaniment')

    Y = norbert.wiener(V,
                       X.astype(np.complex128),
                       args.niter,
                       use_softmask=args.softmask)

    # invert each filtered source back to the time domain
    return {
        name: istft(Y[..., j].T,
                    n_fft=unmix_target.n_fft,
                    n_hopsize=unmix_target.n_hop).T
        for j, name in enumerate(source_names)
    }