Пример #1
0
def main():
    """Train ``unet.MultiBandUNet`` with Chainer; all settings come from the CLI.

    Steps performed here:

    * parse options and seed ``random``, NumPy and, when CUDA is available,
      CuPy;
    * build the model, optionally load ``--pretrained_model`` weights, and
      move it to GPU ``--gpu`` when that index is non-negative;
    * split mixture/instrumental files into train and validation lists and
      write the validation list to ``val_<timestamp>.json``;
    * each epoch, draw a fresh training set, optionally apply mixup,
      re-inject the previous epoch's oracle samples, and run
      ``--inner_epoch`` rounds of training followed by validation;
    * append ``[train_loss, valid_loss]`` to ``log_<timestamp>.npy`` and save
      ``models/model_iter<epoch>.npz`` whenever validation loss improves;
    * once ``epoch > 1``, multiply Adam's ``alpha`` by ``--lr_decay``
      (floored at ``--lr_min``) after ``--lr_decay_interval`` consecutive
      inner epochs without improvement.

    Losses are printed multiplied by 1000 for readability.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--mixture_dataset', '-m', required=True)
    p.add_argument('--instrumental_dataset', '-i', required=True)
    p.add_argument('--validation_rate', '-v', type=float, default=0.1)
    p.add_argument('--learning_rate', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay', type=float, default=0.9)
    p.add_argument('--lr_decay_interval', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=8)
    p.add_argument('--val_batchsize', '-b', type=int, default=8)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--cropsize', '-c', type=int, default=448)
    p.add_argument('--val_cropsize', '-C', type=int, default=896)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--epoch', '-E', type=int, default=100)
    p.add_argument('--inner_epoch', '-e', type=int, default=4)
    p.add_argument('--oracle_rate', '-O', type=float, default=0)
    p.add_argument('--oracle_drop_rate', '-o', type=float, default=0.5)
    p.add_argument('--mixup', '-M', action='store_true')
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    args = p.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    if chainer.backends.cuda.available:
        chainer.backends.cuda.cupy.random.seed(args.seed)
        chainer.backends.cuda.set_max_workspace_size(512 * 1024 * 1024)
    chainer.global_config.autotune = True
    timestamp = dt.now().strftime('%Y%m%d%H%M%S')

    # Best-model checkpoints are written into models/ during training;
    # create it now so a missing directory does not abort the run only
    # after the first (expensive) inner epoch has finished.
    os.makedirs('models', exist_ok=True)

    model = unet.MultiBandUNet()
    if args.pretrained_model is not None:
        chainer.serializers.load_npz(args.pretrained_model, model)
    if args.gpu >= 0:
        chainer.backends.cuda.check_cuda_available()
        chainer.backends.cuda.get_device(args.gpu).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam(args.learning_rate)
    optimizer.setup(model)

    train_filelist, val_filelist = train_val_split(
        mix_dir=args.mixture_dataset,
        inst_dir=args.instrumental_dataset,
        val_rate=args.validation_rate,
        val_filelist_json=args.val_filelist)

    # Persist the validation split so a later run can reuse it via -V.
    with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
        json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    X_valid, y_valid = dataset.make_validation_set(val_filelist,
                                                   args.val_cropsize,
                                                   model.offset, args.sr,
                                                   args.hop_length)

    log = []
    oracle_X = None
    oracle_y = None
    best_count = 0  # inner epochs since the last validation improvement
    best_loss = np.inf
    for epoch in range(args.epoch):
        X_train, y_train = dataset.make_training_set(train_filelist,
                                                     args.cropsize,
                                                     args.patches, args.sr,
                                                     args.hop_length)

        if args.mixup:
            X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                       args.mixup_alpha)

        if oracle_X is not None and oracle_y is not None:
            # Overwrite a random subset of this epoch's samples with the
            # oracle samples kept from the previous epoch. Indices are drawn
            # from the whole training set; permuting only
            # range(len(oracle_X)) would always replace the same leading
            # samples.
            perm = np.random.permutation(len(X_train))[:len(oracle_X)]
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y

        print('# epoch', epoch)
        # Per-sample loss buffer; train_inner_epoch is passed this array and
        # presumably accumulates into it in place (it is averaged below).
        instance_loss = np.zeros(len(X_train), dtype=np.float32)
        for inner_epoch in range(args.inner_epoch):
            print('  * inner epoch {}'.format(inner_epoch))
            train_loss = train_inner_epoch(X_train, y_train, model, optimizer,
                                           args.batchsize, instance_loss)
            valid_loss = valid_inner_epoch(X_valid, y_valid, model,
                                           args.val_batchsize)

            print('    * training loss = {:.6f}, validation loss = {:.6f}'.
                  format(train_loss * 1000, valid_loss * 1000))

            log.append([train_loss, valid_loss])
            np.save('log_{}.npy'.format(timestamp), np.asarray(log))

            best_count += 1
            if valid_loss < best_loss:
                best_count = 0
                best_loss = valid_loss
                print('    * best validation loss')
                model_path = 'models/model_iter{}.npz'.format(epoch)
                chainer.serializers.save_npz(model_path, model)

            # Plateau-based decay, disabled for the first two epochs.
            if epoch > 1 and best_count >= args.lr_decay_interval:
                best_count = 0
                optimizer.alpha *= args.lr_decay
                optimizer.alpha = max(optimizer.alpha, args.lr_min)
                print('    * learning rate decay: {:.6f}'.format(
                    optimizer.alpha))

        if args.oracle_rate > 0:
            # Average the accumulated per-sample losses over inner epochs.
            instance_loss /= args.inner_epoch
            oracle_X, oracle_y, idx = dataset.get_oracle_data(
                X_train, y_train, instance_loss, args.oracle_rate,
                args.oracle_drop_rate)
            # Same x1000 scale as the training/validation losses above.
            print('  * oracle loss = {:.6f}'.format(instance_loss[idx].mean() *
                                                    1000))

        # Free the large per-epoch arrays before building the next set.
        del X_train, y_train
        gc.collect()
Пример #2
0
        offset=model.offset)

    log = []
    oracle_X = None
    oracle_y = None
    best_count = 0
    best_loss = np.inf
    for epoch in range(args.epoch):
        X_train, y_train = dataset.create_training_set(
            filelist=train_filelist,
            cropsize=args.cropsize,
            patches=args.patches,
            sr=args.sr,
            hop_length=args.hop_length)
        if args.mixup:
            X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                       args.mixup_alpha)
        if oracle_X is not None and oracle_y is not None:
            perm = np.random.permutation(len(oracle_X))
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y
        print('# epoch', epoch)
        instance_loss = np.zeros(len(X_train), dtype=np.float32)
        for inner_epoch in range(args.inner_epoch):
            sum_loss = 0
            best_count += 1
            perm = np.random.permutation(len(X_train))
            print('  * inner epoch {}'.format(inner_epoch))
            for i in range(0, len(X_train), args.batchsize):
                local_perm = perm[i:i + args.batchsize]
                X_batch = xp.asarray(X_train[local_perm])
                y_batch = xp.asarray(y_train[local_perm])
Пример #3
0
def main():
    """Train ``nets.CascadedASPPNet`` with PyTorch; all settings come from the CLI.

    Steps performed here:

    * parse options and seed ``random``, NumPy and PyTorch;
    * build the model, optionally load ``--pretrained_model`` weights, and
      move it to GPU ``--gpu`` when that index is non-negative;
    * set up Adam plus a ``ReduceLROnPlateau`` scheduler stepped with the
      validation loss after every inner epoch;
    * split mixture/instrumental files into train and validation lists
      (one file each with ``--debug``) and write the validation list to
      ``val_<timestamp>.json``;
    * each epoch, draw a fresh training set, pass it through
      ``dataset.mixup_generator`` with ``--mixup_rate``/``--mixup_alpha``,
      re-inject the previous epoch's oracle samples, and run
      ``--inner_epoch`` rounds of training followed by validation;
    * save ``models/model_iter<epoch>.pth`` whenever validation loss improves
      and rewrite ``log_<timestamp>.json`` with all losses so far.

    Losses are printed multiplied by 1000 for readability.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--mixtures', '-m', required=True)
    p.add_argument('--instruments', '-i', required=True)
    p.add_argument('--learning_rate', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-c', type=int, default=256)
    p.add_argument('--val_rate', '-v', type=float, default=0.1)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=4)
    p.add_argument('--val_cropsize', '-C', type=int, default=512)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--epoch', '-E', type=int, default=80)
    p.add_argument('--inner_epoch', '-e', type=int, default=4)
    p.add_argument('--oracle_rate', '-O', type=float, default=0)
    p.add_argument('--oracle_drop_rate', '-o', type=float, default=0.5)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', '-d', action='store_true')
    args = p.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    timestamp = dt.now().strftime('%Y%m%d%H%M%S')

    # Best-model checkpoints are written into models/ during training;
    # create it now so a missing directory does not abort the run only
    # after the first (expensive) inner epoch has finished.
    os.makedirs('models', exist_ok=True)

    model = nets.CascadedASPPNet()
    if args.pretrained_model is not None:
        # Load onto the CPU first so checkpoints saved from a GPU also load
        # on CPU-only machines; the model is moved to the GPU below anyway.
        model.load_state_dict(
            torch.load(args.pretrained_model, map_location='cpu'))
    if args.gpu >= 0:
        # Make the requested index the current CUDA device so model.cuda()
        # (and any later bare .cuda() calls) use it instead of device 0.
        torch.cuda.set_device(args.gpu)
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        min_lr=args.lr_min,
        verbose=True)

    train_filelist, val_filelist = train_val_split(
        mix_dir=args.mixtures,
        inst_dir=args.instruments,
        val_rate=args.val_rate,
        val_filelist_json=args.val_filelist)

    if args.debug:
        print('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]

    # Persist the validation split so a later run can reuse it via -V.
    with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
        json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    val_dataset = dataset.make_validation_set(filelist=val_filelist,
                                              cropsize=args.val_cropsize,
                                              sr=args.sr,
                                              hop_length=args.hop_length,
                                              offset=model.offset)
    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.val_batchsize,
                                                 shuffle=False,
                                                 num_workers=4)

    log = []
    oracle_X = None
    oracle_y = None
    best_loss = np.inf
    for epoch in range(args.epoch):
        X_train, y_train = dataset.make_training_set(train_filelist,
                                                     args.cropsize,
                                                     args.patches, args.sr,
                                                     args.hop_length,
                                                     model.offset)

        X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                   args.mixup_rate,
                                                   args.mixup_alpha)

        if oracle_X is not None and oracle_y is not None:
            # Overwrite a random subset of this epoch's samples with the
            # oracle samples kept from the previous epoch.
            perm = np.random.permutation(len(X_train))[:len(oracle_X)]
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y

        print('# epoch', epoch)
        for inner_epoch in range(args.inner_epoch):
            print('  * inner epoch {}'.format(inner_epoch))
            train_loss, instance_loss = train_inner_epoch(
                X_train, y_train, model, optimizer, args.batchsize)
            val_loss = val_inner_epoch(val_dataloader, model)

            print('    * training loss = {:.6f}, validation loss = {:.6f}'.
                  format(train_loss * 1000, val_loss * 1000))

            # Stepped per inner epoch, so patience counts inner epochs.
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                print('    * best validation loss')
                model_path = 'models/model_iter{}.pth'.format(epoch)
                torch.save(model.state_dict(), model_path)

            log.append([train_loss, val_loss])
            with open('log_{}.json'.format(timestamp), 'w',
                      encoding='utf8') as f:
                json.dump(log, f, ensure_ascii=False)

        if args.oracle_rate > 0:
            # Uses the per-sample losses from the final inner epoch only.
            oracle_X, oracle_y, idx = dataset.get_oracle_data(
                X_train, y_train, instance_loss, args.oracle_rate,
                args.oracle_drop_rate)
            print('  * oracle loss = {:.6f}'.format(instance_loss[idx].mean() *
                                                    1000))

        # Free the large per-epoch arrays before building the next set.
        del X_train, y_train
        gc.collect()
Пример #4
0
def main():
    """Train ``nets.CascadedASPPNet`` with PyTorch; all settings come from the CLI.

    Steps performed here:

    * parse options and seed ``random``, NumPy and PyTorch;
    * build the model, optionally load ``--pretrained_model`` weights, and
      move it to GPU ``--gpu`` when that index is non-negative;
    * set up Adam plus a ``ReduceLROnPlateau`` scheduler stepped with the
      validation loss after every inner epoch;
    * split mixture/instrumental files into train and validation lists
      (one file each with ``--debug``) and write the validation list to
      ``val_<timestamp>.json``;
    * each epoch, draw a fresh training set, pass it through
      ``dataset.mixup_generator`` with ``--mixup_rate``/``--mixup_alpha``,
      re-inject the previous epoch's oracle samples, and run
      ``--inner_epoch`` rounds of training followed by validation;
    * save ``models/model_iter<epoch>.pth`` whenever validation loss improves
      and rewrite ``log_<timestamp>.json`` with all losses so far.

    Losses are printed multiplied by 1000 for readability.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--gpu", "-g", type=int, default=-1)
    p.add_argument("--seed", "-s", type=int, default=2019)
    p.add_argument("--sr", "-r", type=int, default=44100)
    p.add_argument("--hop_length", "-l", type=int, default=1024)
    p.add_argument("--mixtures", "-m", required=True)
    p.add_argument("--instruments", "-i", required=True)
    p.add_argument("--learning_rate", type=float, default=0.001)
    p.add_argument("--lr_min", type=float, default=0.0001)
    p.add_argument("--lr_decay_factor", type=float, default=0.9)
    p.add_argument("--lr_decay_patience", type=int, default=6)
    p.add_argument("--batchsize", "-B", type=int, default=4)
    p.add_argument("--cropsize", "-c", type=int, default=256)
    p.add_argument("--val_rate", "-v", type=float, default=0.1)
    p.add_argument("--val_filelist", "-V", type=str, default=None)
    p.add_argument("--val_batchsize", "-b", type=int, default=4)
    p.add_argument("--val_cropsize", "-C", type=int, default=512)
    p.add_argument("--patches", "-p", type=int, default=16)
    p.add_argument("--epoch", "-E", type=int, default=80)
    p.add_argument("--inner_epoch", "-e", type=int, default=4)
    p.add_argument("--oracle_rate", "-O", type=float, default=0)
    p.add_argument("--oracle_drop_rate", "-o", type=float, default=0.5)
    p.add_argument("--mixup_rate", "-M", type=float, default=0.0)
    p.add_argument("--mixup_alpha", "-a", type=float, default=1.0)
    p.add_argument("--pretrained_model", "-P", type=str, default=None)
    p.add_argument("--debug", "-d", action="store_true")
    args = p.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    timestamp = dt.now().strftime("%Y%m%d%H%M%S")

    # Best-model checkpoints are written into models/ during training;
    # create it now so a missing directory does not abort the run only
    # after the first (expensive) inner epoch has finished.
    os.makedirs("models", exist_ok=True)

    model = nets.CascadedASPPNet()
    if args.pretrained_model is not None:
        # Load onto the CPU first so checkpoints saved from a GPU also load
        # on CPU-only machines; the model is moved to the GPU below anyway.
        model.load_state_dict(
            torch.load(args.pretrained_model, map_location="cpu"))
    if args.gpu >= 0:
        # Make the requested index the current CUDA device so model.cuda()
        # (and any later bare .cuda() calls) use it instead of device 0.
        torch.cuda.set_device(args.gpu)
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        min_lr=args.lr_min,
        verbose=True,
    )

    train_filelist, val_filelist = train_val_split(
        mix_dir=args.mixtures,
        inst_dir=args.instruments,
        val_rate=args.val_rate,
        val_filelist_json=args.val_filelist,
    )

    if args.debug:
        print("### DEBUG MODE")
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]

    # Persist the validation split so a later run can reuse it via -V.
    with open("val_{}.json".format(timestamp), "w", encoding="utf8") as f:
        json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    val_dataset = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        offset=model.offset,
    )
    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.val_batchsize,
                                                 shuffle=False,
                                                 num_workers=4)

    log = []
    oracle_X = None
    oracle_y = None
    best_loss = np.inf
    for epoch in range(args.epoch):
        X_train, y_train = dataset.make_training_set(
            train_filelist,
            args.cropsize,
            args.patches,
            args.sr,
            args.hop_length,
            model.offset,
        )

        X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                   args.mixup_rate,
                                                   args.mixup_alpha)

        if oracle_X is not None and oracle_y is not None:
            # Overwrite a random subset of this epoch's samples with the
            # oracle samples kept from the previous epoch.
            perm = np.random.permutation(len(X_train))[:len(oracle_X)]
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y

        print("# epoch", epoch)
        for inner_epoch in range(args.inner_epoch):
            print("  * inner epoch {}".format(inner_epoch))
            train_loss, instance_loss = train_inner_epoch(
                X_train, y_train, model, optimizer, args.batchsize)
            val_loss = val_inner_epoch(val_dataloader, model)

            print("    * training loss = {:.6f}, validation loss = {:.6f}".
                  format(train_loss * 1000, val_loss * 1000))

            # Stepped per inner epoch, so patience counts inner epochs.
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                print("    * best validation loss")
                model_path = "models/model_iter{}.pth".format(epoch)
                torch.save(model.state_dict(), model_path)

            log.append([train_loss, val_loss])
            with open("log_{}.json".format(timestamp), "w",
                      encoding="utf8") as f:
                json.dump(log, f, ensure_ascii=False)

        if args.oracle_rate > 0:
            # Uses the per-sample losses from the final inner epoch only.
            oracle_X, oracle_y, idx = dataset.get_oracle_data(
                X_train, y_train, instance_loss, args.oracle_rate,
                args.oracle_drop_rate)
            print("  * oracle loss = {:.6f}".format(instance_loss[idx].mean() *
                                                    1000))

        # Free the large per-epoch arrays before building the next set.
        del X_train, y_train
        gc.collect()