Example 1
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train? (the default here is 'tabla'; upstream open-unmix uses 'vocals')
    parser.add_argument('--target',
                        type=str,
                        default='tabla',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_final/')
    parser.add_argument('--output',
                        type=str,
                        default="../new_models/model_tabla_mtl_ourmix_1",
                        help='provide output path base folder name')
    # (earlier runs pointed --model at various pretrain/finetune checkpoint folders or at 'umxhq';
    # no default checkpoint is set here)
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    parser.add_argument(
        '--onset-model',
        type=str,
        help='Path to onset detection model weights',
        default=
        "/media/Sharedata/rohit/cnn-onset-det/models/apr4/saved_model_0_80mel-0-16000_1ch_44100.pt"
    )

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weighting of different loss components')
    parser.add_argument(
        '--finetune',
        type=int,
        default=0,
        help='if 1, reset the optimiser and scheduler states from the checkpoint '
        '(required for BCE fine-tuning); if 0, resume training from where it left off')
    parser.add_argument('--onset-thresh',
                        type=float,
                        default=0.3,
                        help='Threshold above which onset is said to occur')
    parser.add_argument(
        '--binarise',
        type=int,
        default=0,
        help='if 1, binarise the target novelty function; if 0, leave it as it is')
    parser.add_argument(
        '--onset-trainable',
        type=int,
        default=0,
        help='if 1, the onsetCNN is also trained during the fine-tuning stage; '
        'if 0, its weights stay fixed')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 uses the full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')

    # (an earlier configuration used nfft=2048 and nhop=512)

    parser.add_argument('--n-mels',
                        type=int,
                        default=80,
                        help='Number of bins in mel spectrogram')

    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # seed the RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    # anomaly detection helps track down NaNs in the backward pass but slows training
    torch.autograd.set_detect_anomaly(True)

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    # input scaler statistics are only computed when training from scratch;
    # a checkpoint already carries them in its state_dict
    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model_mtl.OpenUnmix_mtl(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        n_fft=args.nfft,
        n_hop=args.nhop,
        max_bin=max_bin,
        sample_rate=train_dataset.sample_rate).to(device)

    # load the trained onset detection network (the target spectrogram is passed through it)
    detect_onset = model.onsetCNN().to(device)
    detect_onset.load_state_dict(
        torch.load(args.onset_model, map_location=device))

    # freeze the onset detector: its weights are not updated while training the separator
    for child in detect_onset.children():
        for param in child.parameters():
            param.requires_grad = False

    # (earlier experiments passed the separated output through a second onset model,
    # detect_onset_training, which could optionally be kept trainable via --onset-trainable)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])

        # (in earlier experiments, when the onset model was trainable and fine-tuning was being
        # resumed, its state_dict was also restored from the checkpoint at this point)

        if (args.finetune == 0):
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # continue for another args.epochs epochs, starting from epochs_trained
            t = tqdm.trange(results['epochs_trained'],
                            results['epochs_trained'] + args.epochs + 1,
                            disable=args.quiet)
            print("PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = results['train_loss_history']
            train_mse_losses = results['train_mse_loss_history']
            train_bce_losses = results['train_bce_loss_history']
            valid_losses = results['valid_loss_history']
            valid_mse_losses = results['valid_mse_loss_history']
            valid_bce_losses = results['valid_bce_loss_history']
            train_times = results['train_time_history']
            best_epoch = results['best_epoch']

            es.best = results['best_loss']
            es.num_bad_epochs = results['num_bad_epochs']

        else:
            t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
            train_losses = []
            train_mse_losses = []
            train_bce_losses = []
            print("NOT PICKUP WHERE LEFT OFF", args.finetune)
            valid_losses = []
            valid_mse_losses = []
            valid_bce_losses = []

            train_times = []
            best_epoch = 0

    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        train_mse_losses = []
        train_bce_losses = []

        valid_losses = []
        valid_mse_losses = []
        valid_bce_losses = []

        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss, train_mse_loss, train_bce_loss = train(
            args,
            unmix,
            device,
            train_sampler,
            optimizer,
            detect_onset=detect_onset)

        valid_loss, valid_mse_loss, valid_bce_loss = valid(
            args, unmix, device, valid_sampler, detect_onset=detect_onset)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        train_mse_losses.append(train_mse_loss)
        train_bce_losses.append(train_bce_loss)

        valid_losses.append(valid_loss)
        valid_mse_losses.append(valid_mse_loss)
        valid_bce_losses.append(valid_bce_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        # (per-epoch plotting of the loss curves was disabled here; the final curves are saved
        # once training finishes)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'onset_state_dict': detect_onset.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'train_mse_loss_history': train_mse_losses,
            'train_bce_loss_history': train_bce_losses,
            'valid_loss_history': valid_losses,
            'valid_mse_loss_history': valid_mse_losses,
            'valid_bce_loss_history': valid_bce_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break


    # (an earlier combined train/validation figure was produced here; separate plots are saved below)

    print("TRAINING DONE!!")

    from matplotlib import pyplot as plt  # ensure pyplot is available for the final plots

    plt.figure()
    plt.title("Training loss")
    plt.plot(train_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_plot.pdf"))

    plt.figure()
    plt.title("Validation loss")
    plt.plot(valid_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_plot.pdf"))

    plt.figure()
    plt.title("Training BCE loss")
    plt.plot(train_bce_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_bce_plot.pdf"))

    plt.figure()
    plt.title("Validation BCE loss")
    plt.plot(valid_bce_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_bce_plot.pdf"))

    plt.figure()
    plt.title("Training MSE loss")
    plt.plot(train_mse_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_mse_plot.pdf"))

    plt.figure()
    plt.title("Validation MSE loss")
    plt.plot(valid_mse_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_mse_plot.pdf"))
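All of these examples call utils.bandwidth_to_max_bin(sample_rate, nfft, bandwidth) to restrict the model to the STFT bins below a given bandwidth. A minimal sketch of what such a helper can look like, written here as an assumption for illustration (the actual implementation lives in the repository's utils module):

import numpy as np

def bandwidth_to_max_bin(rate, n_fft, bandwidth):
    """Number of one-sided STFT bins whose centre frequency lies at or below `bandwidth` Hz."""
    freqs = np.linspace(0, float(rate) / 2, n_fft // 2 + 1, endpoint=True)
    return np.max(np.where(freqs <= bandwidth)[0]) + 1

# with the defaults above (44.1 kHz audio, nfft=4096, bandwidth=16000) this keeps
# roughly 1487 of the 2049 bins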
Example 2
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target',
                        type=str,
                        default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_new/')
    parser.add_argument('--output',
                        type=str,
                        default="../out_unmix/model_new_data_aug_tl",
                        help='provide output path base folder name')
    # (an earlier run resumed from '../out_unmix/model_new_data'; here the pretrained 'umxhq'
    # weights are used instead)
    parser.add_argument('--model',
                        type=str,
                        help='Path to checkpoint folder',
                        default='umxhq')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate, defaults to 1e-4')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=1e-10,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 uses the full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # seed the RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    # input scaler statistics are recomputed here even when warm-starting from a pretrained model
    scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_dataset.sample_rate).to(device)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        # disable progress bar
        err = io.StringIO()
        with redirect_stderr(err):
            unmix = torch.hub.load('sigsep/open-unmix-pytorch',
                                   'umxhq',
                                   target=args.target,
                                   device=device,
                                   pretrained=True)
        # the torch.hub call above replaces the model that the optimizer and scheduler were
        # built around, so rebuild both for the newly loaded parameters
        optimizer = torch.optim.Adam(unmix.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.lr_decay_gamma,
            patience=args.lr_decay_patience,
            cooldown=10)
        # (the usual resume-from-checkpoint logic that restores the optimizer, scheduler and
        # loss histories from <target>.json / <target>.chkpnt is disabled in this example)

    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    train_losses = []
    valid_losses = []
    train_times = []
    best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        from matplotlib import pyplot as plt

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 1)
        plt.title("Training loss")
        plt.plot(train_losses, label="Training")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "train_plot.pdf"))

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 2)
        plt.title("Validation loss")
        plt.plot(valid_losses, label="Validation")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "val_plot.pdf"))

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
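Each example drives early stopping through utils.EarlyStopping, calling es.step(valid_loss) once per epoch and reading es.best and es.num_bad_epochs back when checkpointing. A minimal sketch of such a helper, assumed here for illustration (the repository's own class may differ in detail):

class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=140, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def step(self, metric):
        if self.best is None:        # first epoch: just record the value
            self.best = metric
            return False
        if metric < self.best - self.min_delta:
            self.best = metric       # improvement: reset the counter
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience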
Example 3
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # experiment tag which will determine output folder in trained models, tensorboard name, etc.
    parser.add_argument('--tag', type=str)


    # allow to pass a comment about the experiment
    parser.add_argument('--comment', type=str, help='comment about the experiment')

    args, _ = parser.parse_known_args()

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default="musdb",
                        choices=[
                            'musdb_lyrics', 'timit_music', 'blended', 'nus', 'nus_train'
                        ],
                        help='Name of the dataset.')

    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output', type=str, default="trained_models/{}/".format(args.tag),
                        help='provide output path base folder name')

    parser.add_argument('--wst-model', type=str, help='Path to checkpoint folder for warmstart')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')

    parser.add_argument('--alignment-from', type=str, default=None)
    parser.add_argument('--fake-alignment', action='store_true', default=False)


    # Model Parameters
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=0,
                        help='Number of workers for dataloader.')
    parser.add_argument('--nb-audio-encoder-layers', type=int, default=2)
    parser.add_argument('--nb-layers', type=int, default=3)
    # name of the model class in model.py that should be used
    parser.add_argument('--architecture', type=str)
    # select attention type if applicable for selected model
    parser.add_argument('--attention', type=str)

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {'num_workers': args.nb_workers, 'pin_memory': True} if use_cuda else {}

    writer = SummaryWriter(logdir=os.path.join('tensorboard', args.tag))

    # seed the RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)


    train_sampler = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=data.collate_fn, drop_last=True,
        **dataloader_kwargs
    )
    valid_sampler = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1, collate_fn=data.collate_fn, **dataloader_kwargs
    )

    if args.wst_model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(
        valid_dataset.sample_rate, args.nfft, args.bandwidth
    )

    train_args_dict = vars(args)
    train_args_dict['max_bin'] = int(max_bin)  # added to config
    train_args_dict['vocabulary_size'] = valid_dataset.vocabulary_size  # added to config

    train_params_dict = copy.deepcopy(vars(args))  # return args as dictionary with no influence on args

    # add to parameters for model loading but not to config file
    train_params_dict['scaler_mean'] = scaler_mean
    train_params_dict['scaler_std'] = scaler_std

    model_class = model_utls.ModelLoader.get_model(args.architecture)
    model_to_train = model_class.from_config(train_params_dict)
    model_to_train.to(device)

    optimizer = torch.optim.Adam(
        model_to_train.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10
    )

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.wst_model:
        model_path = Path(os.path.join('trained_models', args.wst_model)).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)


        model_to_train.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # continue for another args.epochs epochs, starting from epochs_trained
        t = tqdm.trange(
            results['epochs_trained'],
            results['epochs_trained'] + args.epochs + 1,
            disable=args.quiet
        )
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = 0

    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()

        train_loss = train(args, model_to_train, device, train_sampler, optimizer)
        #valid_loss, sdr_val, sar_val, sir_val = valid(args, model_to_train, device, valid_sampler)
        valid_loss = valid(args, model_to_train, device, valid_sampler)

        writer.add_scalar("Training_cost", train_loss, epoch)
        writer.add_scalar("Validation_cost", valid_loss, epoch)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(
            train_loss=train_loss, val_loss=valid_loss
        )

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_train.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target
        )

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs
        }

        with open(Path(target_path,  args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
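All three examples persist training state through utils.save_checkpoint(state, is_best, path, target). A minimal sketch of that pattern, assuming the common convention of keeping the full training state in <target>.chkpnt and only the best model weights in <target>.pth:

from pathlib import Path

import torch

def save_checkpoint(state, is_best, path, target):
    # always keep the latest full training state (model, optimizer, scheduler, ...)
    torch.save(state, Path(path, target + '.chkpnt'))
    # additionally keep a copy of the best-so-far model weights on their own
    if is_best:
        torch.save(state['state_dict'], Path(path, target + '.pth'))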