Exemple #1
0
def main():
    """Train ``nets.CascadedASPPNet`` using settings from the command line.

    The script parses its options, seeds the Python/NumPy/PyTorch RNGs,
    obtains train/validation file lists from ``dataset.train_val_split`` and
    then runs ``--epoch`` outer epochs. Every outer epoch builds a new
    training set (``dataset.make_training_set`` + ``dataset.augment``) and
    runs ``--inner_epoch`` train/validate passes over it.

    Side effects:
        * ``val_<timestamp>.json`` is written when no ``--val_filelist`` was
          given and ``--debug`` is off.
        * ``models/model_iter<epoch>.pth`` is written whenever the validation
          loss improves (the ``models`` directory is created if missing).
        * ``log_<timestamp>.json`` is rewritten after every inner epoch.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--dataset', '-d', required=True)
    p.add_argument('--split_mode', '-S', type=str, choices=['random', 'subdirs'], default='random')
    p.add_argument('--learning_rate', '-l', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-c', type=int, default=256)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--val_rate', '-v', type=float, default=0.2)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=2)
    p.add_argument('--val_cropsize', '-C', type=int, default=512)
    p.add_argument('--epoch', '-E', type=int, default=60)
    p.add_argument('--inner_epoch', '-e', type=int, default=4)
    p.add_argument('--reduction_rate', '-R', type=float, default=0.0)
    p.add_argument('--reduction_level', '-L', type=float, default=0.2)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', action='store_true')
    args = p.parse_args()

    # Seed the Python, NumPy and PyTorch RNGs from --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # The timestamp tags the validation-list and log files of this run.
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

    # An optional, user-supplied validation list is forwarded to the split.
    val_filelist = []
    if args.val_filelist is not None:
        with open(args.val_filelist, 'r', encoding='utf8') as f:
            val_filelist = json.load(f)

    train_filelist, val_filelist = dataset.train_val_split(
        dataset_dir=args.dataset,
        split_mode=args.split_mode,
        val_rate=args.val_rate,
        val_filelist=val_filelist)

    if args.debug:
        # Debug mode keeps a single pair per list and writes no split file.
        print('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]
    elif args.val_filelist is None:
        # No list was supplied: persist the one returned by the split.
        with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    # Weights are loaded onto the CPU first; the model is moved to the
    # requested GPU only when CUDA is actually available.
    device = torch.device('cpu')
    model = nets.CascadedASPPNet(args.n_fft)
    if args.pretrained_model is not None:
        model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)

    # Only parameters with requires_grad set are handed to the optimizer.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.learning_rate)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        threshold=1e-6,
        min_lr=args.lr_min,
        verbose=True)

    val_dataset = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft,
        offset=model.offset)

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=4)

    # Column vector of length ``bins``: ramps 0 -> 1 over the first
    # ``unstable_bins`` rows, 1 -> 0 up to ``reduction_bins``, then stays 0;
    # the whole vector is scaled by --reduction_level.
    # NOTE(review): 160 and 16000 are presumably frequencies in Hz - confirm.
    bins = args.n_fft // 2 + 1
    freq_to_bin = 2 * bins / args.sr
    unstable_bins = int(160 * freq_to_bin)
    reduction_bins = int(16000 * freq_to_bin)
    reduction_mask = np.concatenate([
        np.linspace(0, 1, unstable_bins)[:, None],
        np.linspace(1, 0, reduction_bins - unstable_bins)[:, None],
        np.zeros((bins - reduction_bins, 1))
    ], axis=0) * args.reduction_level

    # Create the checkpoint directory up front so torch.save cannot fail on
    # a missing folder after an epoch has already been trained.
    os.makedirs('models', exist_ok=True)

    log = []
    best_loss = np.inf
    for epoch in range(args.epoch):
        # A new training set is built and augmented for every outer epoch.
        X_train, y_train = dataset.make_training_set(
            filelist=train_filelist,
            cropsize=args.cropsize,
            patches=args.patches,
            sr=args.sr,
            hop_length=args.hop_length,
            n_fft=args.n_fft,
            offset=model.offset)

        X_train, y_train = dataset.augment(
            X_train, y_train,
            reduction_rate=args.reduction_rate,
            reduction_mask=reduction_mask,
            mixup_rate=args.mixup_rate,
            mixup_alpha=args.mixup_alpha)

        print('# epoch', epoch)
        for inner_epoch in range(args.inner_epoch):
            print('  * inner epoch {}'.format(inner_epoch))

            train_loss = train_inner_epoch(
                X_train, y_train,
                model=model,
                device=device,
                optimizer=optimizer,
                batchsize=args.batchsize)

            val_loss = val_inner_epoch(val_dataloader, model, device)

            print('    * training loss = {:.6f}, validation loss = {:.6f}'
                  .format(train_loss, val_loss))

            # The plateau scheduler is driven by the validation loss.
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                print('    * best validation loss')
                # The file name carries only the outer epoch, so a later
                # improvement within the same epoch overwrites it.
                model_path = 'models/model_iter{}.pth'.format(epoch)
                torch.save(model.state_dict(), model_path)

            # Rewrite the whole log each time so it survives an interruption.
            log.append([train_loss, val_loss])
            with open('log_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
                json.dump(log, f, ensure_ascii=False)

        # Drop the epoch's training arrays before the next set is built.
        del X_train, y_train
        gc.collect()
def main():
    """Split an audio file into instrumental and vocal tracks.

    Loads a ``nets.CascadedASPPNet`` checkpoint, computes the spectrogram of
    ``--input``, predicts a mask window by window and applies the mask
    (instruments) and its complement (vocals) to the magnitude.

    Every window is predicted twice: once from the padded spectrogram and
    once from a copy shifted by half a region of interest along axis 2. The
    second mask is shifted back and the two are averaged.

    Outputs are written to the current directory:
        * ``<basename>_Instruments.wav`` and ``<basename>_Vocals.wav``
        * ``<basename>_Mask.png`` when ``--out_mask`` is given
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument(
        '--model',
        '-m',
        type=str,
        default=
        '/content/drive/My Drive/vocal-remover/models/MultiGenreModelNP.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--window_size', '-w', type=int, default=512)
    p.add_argument('--out_mask', '-M', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    args = p.parse_args()

    print('loading model...', end=' ')
    # Load onto the CPU first; move to the GPU only when CUDA is available.
    device = torch.device('cpu')
    model = nets.CascadedASPPNet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    # ``sr`` and ``mono`` must be passed by keyword: recent librosa releases
    # reject them as positional arguments.
    X, sr = librosa.load(args.input,
                         sr=args.sr,
                         mono=False,
                         dtype=np.float32,
                         res_type='kaiser_fast')
    print('done')

    print('stft of wave source...', end=' ')
    X = spec_utils.calc_spec(X, args.hop_length)
    # Keep the unit-modulus phase separately; the magnitude is scaled by its
    # maximum and ``coeff`` restores the scale before the inverse STFT.
    X, phase = np.abs(X), np.exp(1.j * np.angle(X))
    coeff = X.max()
    X /= coeff
    print('done')

    offset = model.offset
    pad_l, pad_r, roi_size = dataset.make_padding(
        X.shape[2], args.window_size, offset)
    X_pad = np.pad(X, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
    # Second view of the input, shifted by half a ROI along axis 2.
    half_roi = roi_size // 2
    X_roll = np.roll(X_pad, half_roi, axis=2)

    model.eval()
    with torch.no_grad():
        masks = []
        masks_roll = []
        for i in tqdm(range(int(np.ceil(X.shape[2] / roi_size)))):
            start = i * roi_size
            # Batch of two: the plain window and the shifted window.
            X_window = torch.from_numpy(
                np.asarray([
                    X_pad[:, :, start:start + args.window_size],
                    X_roll[:, :, start:start + args.window_size]
                ])).to(device)
            pred = model.predict(X_window)
            pred = pred.detach().cpu().numpy()
            masks.append(pred[0])
            masks_roll.append(pred[1])

        mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]
        mask_roll = np.concatenate(masks_roll, axis=2)[:, :, :X.shape[2]]
        # Undo exactly the shift applied above. ``-roi_size // 2`` would
        # floor the negated value and be off by one for an odd roi_size.
        mask = (mask + np.roll(mask_roll, -half_roi, axis=2)) / 2

    if args.postprocess:
        vocal = X * (1 - mask) * coeff
        mask = spec_utils.mask_uninformative(mask, vocal)

    inst = X * mask * coeff
    vocal = X * (1 - mask) * coeff

    basename = os.path.splitext(os.path.basename(args.input))[0]

    print('inverse stft of instruments...', end=' ')
    wav = spec_utils.spec_to_wav(inst, phase, args.hop_length)
    print('done')
    sf.write('{}_Instruments.wav'.format(basename), wav.T, sr)

    print('inverse stft of vocals...', end=' ')
    wav = spec_utils.spec_to_wav(vocal, phase, args.hop_length)
    print('done')
    sf.write('{}_Vocals.wav'.format(basename), wav.T, sr)

    if args.out_mask:
        # Complement of the mask as 8-bit values with axis 0 moved last; the
        # maximum over that axis is prepended as an extra plane and the
        # first axis is reversed before encoding.
        norm_mask = np.uint8((1 - mask) * 255).transpose(1, 2, 0)
        norm_mask = np.concatenate(
            [np.max(norm_mask, axis=2, keepdims=True), norm_mask],
            axis=2)[::-1]
        # Encode in memory and write through a Python file object.
        _, bin_mask = cv2.imencode('.png', norm_mask)
        with open('{}_Mask.png'.format(basename), mode='wb') as f:
            bin_mask.tofile(f)
Exemple #3
0
def main():
    """Split an audio file into instrumental and vocal tracks.

    Loads a ``nets.CascadedASPPNet`` checkpoint, computes the spectrogram of
    ``--input``, predicts a mask window by window and applies the mask
    (instrumental) and its complement (vocal) to the magnitude.

    Outputs are written to the current directory:
        * ``<basename>_Instrumental.wav`` and ``<basename>_Vocal.wav``
        * ``mask.png`` when ``--out_mask`` is given
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--model', '-m', type=str, default='models/baseline.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--window_size', '-w', type=int, default=512)
    p.add_argument('--out_mask', '-M', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    args = p.parse_args()

    print('loading model...', end=' ')
    # Load onto the CPU first; move to the GPU only when CUDA is available.
    device = torch.device('cpu')
    model = nets.CascadedASPPNet()
    model.load_state_dict(torch.load(args.model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    # ``sr`` and ``mono`` must be passed by keyword: recent librosa releases
    # reject them as positional arguments.
    X, sr = librosa.load(args.input,
                         sr=args.sr,
                         mono=False,
                         dtype=np.float32,
                         res_type='kaiser_fast')
    print('done')

    print('wave source stft...', end=' ')
    X = spec_utils.calc_spec(X, args.hop_length)
    # Keep the unit-modulus phase separately; the magnitude is scaled by its
    # maximum and ``coeff`` restores the scale before the inverse STFT.
    X, phase = np.abs(X), np.exp(1.j * np.angle(X))
    coeff = X.max()
    X /= coeff
    print('done')

    offset = model.offset
    pad_l, pad_r, roi_size = dataset.make_padding(
        X.shape[2], args.window_size, offset)
    X_pad = np.pad(X, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')

    masks = []
    model.eval()
    with torch.no_grad():
        for j in tqdm(range(int(np.ceil(X.shape[2] / roi_size)))):
            start = j * roi_size
            # Leading axis of size 1 makes the window a single-item batch.
            X_window = X_pad[None, :, :, start:start + args.window_size]
            pred = model.predict(torch.from_numpy(X_window).to(device))
            pred = pred.detach().cpu().numpy()
            masks.append(pred[0])

    # Stitch the per-window masks and trim back to the input length.
    mask = np.concatenate(masks, axis=2)[:, :, :X.shape[2]]
    if args.postprocess:
        vocal_pred = X * (1 - mask) * coeff
        mask = spec_utils.mask_uninformative(mask, vocal_pred)
    inst_pred = X * mask * coeff
    vocal_pred = X * (1 - mask) * coeff

    if args.out_mask:
        norm_mask = np.uint8((1 - mask) * 255)
        # Allocate the canvas as 8-bit: every value stored in it is already
        # uint8, so no float image is handed to the PNG writer.
        canvas = np.zeros((norm_mask.shape[1], norm_mask.shape[2], 3),
                          dtype=np.uint8)
        canvas[:, :, 1] = norm_mask[0]
        canvas[:, :, 2] = norm_mask[1]
        canvas[:, :, 0] = np.max(norm_mask, axis=0)
        # The first axis is reversed before writing.
        cv2.imwrite('mask.png', canvas[::-1])

    basename = os.path.splitext(os.path.basename(args.input))[0]

    print('instrumental inverse stft...', end=' ')
    wav = spec_utils.spec_to_wav(inst_pred, phase, args.hop_length)
    print('done')
    sf.write('{}_Instrumental.wav'.format(basename), wav.T, sr)

    print('vocal inverse stft...', end=' ')
    wav = spec_utils.spec_to_wav(vocal_pred, phase, args.hop_length)
    print('done')
    sf.write('{}_Vocal.wav'.format(basename), wav.T, sr)
Exemple #4
0
def main():
    """Train ``nets.CascadedASPPNet`` with optional "oracle" sample reuse.

    Parses the command line, seeds the Python/NumPy/PyTorch RNGs, obtains
    train/validation file lists from ``train_val_split`` and runs ``--epoch``
    outer epochs. Every outer epoch builds a new training set, applies
    ``dataset.mixup_generator`` and runs ``--inner_epoch`` train/validate
    passes. When ``--oracle_rate`` is positive, samples chosen by
    ``dataset.get_oracle_data`` from the last inner epoch's per-instance
    losses overwrite random entries of the next epoch's training set.

    Side effects:
        * ``val_<timestamp>.json`` is always written.
        * ``models/model_iter<epoch>.pth`` is written whenever the validation
          loss improves (the ``models`` directory is created if missing).
        * ``log_<timestamp>.json`` is rewritten after every inner epoch.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--mixtures', '-m', required=True)
    p.add_argument('--instruments', '-i', required=True)
    p.add_argument('--learning_rate', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-c', type=int, default=256)
    p.add_argument('--val_rate', '-v', type=float, default=0.1)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=4)
    p.add_argument('--val_cropsize', '-C', type=int, default=512)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--epoch', '-E', type=int, default=80)
    p.add_argument('--inner_epoch', '-e', type=int, default=4)
    p.add_argument('--oracle_rate', '-O', type=float, default=0)
    p.add_argument('--oracle_drop_rate', '-o', type=float, default=0.5)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', '-d', action='store_true')
    args = p.parse_args()

    # Seed the Python, NumPy and PyTorch RNGs from --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # The timestamp tags the validation-list and log files of this run.
    timestamp = dt.now().strftime('%Y%m%d%H%M%S')

    model = nets.CascadedASPPNet()
    if args.pretrained_model is not None:
        # Map tensors to the CPU while loading so a checkpoint saved on a
        # GPU can also be read on a host without CUDA; the model is moved
        # to the GPU below when one is requested.
        model.load_state_dict(
            torch.load(args.pretrained_model,
                       map_location=torch.device('cpu')))
    if args.gpu >= 0:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        min_lr=args.lr_min,
        verbose=True)

    train_filelist, val_filelist = train_val_split(
        mix_dir=args.mixtures,
        inst_dir=args.instruments,
        val_rate=args.val_rate,
        val_filelist_json=args.val_filelist)

    if args.debug:
        # Debug mode keeps a single pair per list.
        print('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]

    with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
        json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    val_dataset = dataset.make_validation_set(filelist=val_filelist,
                                              cropsize=args.val_cropsize,
                                              sr=args.sr,
                                              hop_length=args.hop_length,
                                              offset=model.offset)
    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.val_batchsize,
                                                 shuffle=False,
                                                 num_workers=4)

    # Create the checkpoint directory up front so torch.save cannot fail on
    # a missing folder after an epoch has already been trained.
    os.makedirs('models', exist_ok=True)

    log = []
    oracle_X = None
    oracle_y = None
    best_loss = np.inf
    for epoch in range(args.epoch):
        # A new training set is built for every outer epoch.
        X_train, y_train = dataset.make_training_set(train_filelist,
                                                     args.cropsize,
                                                     args.patches, args.sr,
                                                     args.hop_length,
                                                     model.offset)

        X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                   args.mixup_rate,
                                                   args.mixup_alpha)

        if oracle_X is not None and oracle_y is not None:
            # Overwrite randomly chosen entries with the samples kept from
            # the previous outer epoch.
            perm = np.random.permutation(len(X_train))[:len(oracle_X)]
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y

        print('# epoch', epoch)
        for inner_epoch in range(args.inner_epoch):
            print('  * inner epoch {}'.format(inner_epoch))
            train_loss, instance_loss = train_inner_epoch(
                X_train, y_train, model, optimizer, args.batchsize)
            val_loss = val_inner_epoch(val_dataloader, model)

            # Losses are printed multiplied by 1000; the raw values are what
            # get logged and compared.
            print('    * training loss = {:.6f}, validation loss = {:.6f}'.
                  format(train_loss * 1000, val_loss * 1000))

            # The plateau scheduler is driven by the validation loss.
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                print('    * best validation loss')
                # The file name carries only the outer epoch, so a later
                # improvement within the same epoch overwrites it.
                model_path = 'models/model_iter{}.pth'.format(epoch)
                torch.save(model.state_dict(), model_path)

            # Rewrite the whole log each time so it survives an interruption.
            log.append([train_loss, val_loss])
            with open('log_{}.json'.format(timestamp), 'w',
                      encoding='utf8') as f:
                json.dump(log, f, ensure_ascii=False)

        if args.oracle_rate > 0:
            # Uses the per-instance losses of the last inner epoch.
            oracle_X, oracle_y, idx = dataset.get_oracle_data(
                X_train, y_train, instance_loss, args.oracle_rate,
                args.oracle_drop_rate)
            print('  * oracle loss = {:.6f}'.format(instance_loss[idx].mean() *
                                                    1000))

        # Drop the epoch's training arrays before the next set is built.
        del X_train, y_train
        gc.collect()
Exemple #5
0
def main():
    """Split an audio file into instrumental and vocal tracks.

    Loads a ``nets.CascadedASPPNet`` checkpoint, converts ``--input`` to a
    spectrogram and runs it through ``VocalRemover`` (``inference_tta`` when
    ``--tta`` is given, ``inference`` otherwise). The prediction is used as
    the instrumental magnitude; the vocal magnitude is the input magnitude
    minus the prediction, clipped at zero.

    Outputs are written to the current directory:
        * ``<basename>_Instruments.wav`` and ``<basename>_Vocals.wav``
        * ``<basename>_Instruments.jpg`` and ``<basename>_Vocals.jpg`` when
          ``--output_image`` is given
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--pretrained_model',
                   '-P',
                   type=str,
                   default='models/baseline.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--hop_length', '-l', type=int, default=1024)
    p.add_argument('--window_size', '-w', type=int, default=512)
    p.add_argument('--output_image', '-I', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    p.add_argument('--tta', '-t', action='store_true')
    args = p.parse_args()

    print('loading model...', end=' ')
    # Load onto the CPU first; move to the GPU only when CUDA is available.
    device = torch.device('cpu')
    model = nets.CascadedASPPNet(args.n_fft)
    model.load_state_dict(
        torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    # ``sr`` and ``mono`` must be passed by keyword: recent librosa releases
    # reject them as positional arguments.
    X, sr = librosa.load(args.input,
                         sr=args.sr,
                         mono=False,
                         dtype=np.float32,
                         res_type='kaiser_fast')
    basename = os.path.splitext(os.path.basename(args.input))[0]
    print('done')

    if X.ndim == 1:
        # A one-dimensional signal is duplicated into two identical rows.
        X = np.asarray([X, X])

    print('stft of wave source...', end=' ')
    X = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
    print('done')

    vr = VocalRemover(model, device, args.window_size)

    if args.tta:
        pred, X_mag, X_phase = vr.inference_tta(X)
    else:
        pred, X_mag, X_phase = vr.inference(X)

    if args.postprocess:
        print('post processing...', end=' ')
        # Residual of the input magnitude after removing the prediction,
        # clipped so it never goes negative.
        pred_inv = np.clip(X_mag - pred, 0, np.inf)
        pred = spec_utils.mask_silence(pred, pred_inv)
        print('done')

    print('inverse stft of instruments...', end=' ')
    y_spec = pred * X_phase
    wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
    print('done')
    sf.write('{}_Instruments.wav'.format(basename), wave.T, sr)

    print('inverse stft of vocals...', end=' ')
    v_spec = np.clip(X_mag - pred, 0, np.inf) * X_phase
    wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
    print('done')
    sf.write('{}_Vocals.wav'.format(basename), wave.T, sr)

    if args.output_image:
        # Images are encoded in memory and written through a Python file
        # object.
        with open('{}_Instruments.jpg'.format(basename), mode='wb') as f:
            image = spec_utils.spectrogram_to_image(y_spec)
            _, bin_image = cv2.imencode('.jpg', image)
            bin_image.tofile(f)
        with open('{}_Vocals.jpg'.format(basename), mode='wb') as f:
            image = spec_utils.spectrogram_to_image(v_spec)
            _, bin_image = cv2.imencode('.jpg', image)
            bin_image.tofile(f)
Exemple #6
0
def main():
    """Train ``nets.CascadedASPPNet`` with optional "oracle" sample reuse.

    Parses the command line, seeds the Python/NumPy/PyTorch RNGs, obtains
    train/validation file lists from ``train_val_split`` and runs ``--epoch``
    outer epochs. Every outer epoch builds a new training set, applies
    ``dataset.mixup_generator`` and runs ``--inner_epoch`` train/validate
    passes. When ``--oracle_rate`` is positive, samples chosen by
    ``dataset.get_oracle_data`` from the last inner epoch's per-instance
    losses overwrite random entries of the next epoch's training set.

    Side effects:
        * ``val_<timestamp>.json`` is always written.
        * ``models/model_iter<epoch>.pth`` is written whenever the validation
          loss improves (the ``models`` directory is created if missing).
        * ``log_<timestamp>.json`` is rewritten after every inner epoch.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--gpu", "-g", type=int, default=-1)
    p.add_argument("--seed", "-s", type=int, default=2019)
    p.add_argument("--sr", "-r", type=int, default=44100)
    p.add_argument("--hop_length", "-l", type=int, default=1024)
    p.add_argument("--mixtures", "-m", required=True)
    p.add_argument("--instruments", "-i", required=True)
    p.add_argument("--learning_rate", type=float, default=0.001)
    p.add_argument("--lr_min", type=float, default=0.0001)
    p.add_argument("--lr_decay_factor", type=float, default=0.9)
    p.add_argument("--lr_decay_patience", type=int, default=6)
    p.add_argument("--batchsize", "-B", type=int, default=4)
    p.add_argument("--cropsize", "-c", type=int, default=256)
    p.add_argument("--val_rate", "-v", type=float, default=0.1)
    p.add_argument("--val_filelist", "-V", type=str, default=None)
    p.add_argument("--val_batchsize", "-b", type=int, default=4)
    p.add_argument("--val_cropsize", "-C", type=int, default=512)
    p.add_argument("--patches", "-p", type=int, default=16)
    p.add_argument("--epoch", "-E", type=int, default=80)
    p.add_argument("--inner_epoch", "-e", type=int, default=4)
    p.add_argument("--oracle_rate", "-O", type=float, default=0)
    p.add_argument("--oracle_drop_rate", "-o", type=float, default=0.5)
    p.add_argument("--mixup_rate", "-M", type=float, default=0.0)
    p.add_argument("--mixup_alpha", "-a", type=float, default=1.0)
    p.add_argument("--pretrained_model", "-P", type=str, default=None)
    p.add_argument("--debug", "-d", action="store_true")
    args = p.parse_args()

    # Seed the Python, NumPy and PyTorch RNGs from --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # The timestamp tags the validation-list and log files of this run.
    timestamp = dt.now().strftime("%Y%m%d%H%M%S")

    model = nets.CascadedASPPNet()
    if args.pretrained_model is not None:
        # Map tensors to the CPU while loading so a checkpoint saved on a
        # GPU can also be read on a host without CUDA; the model is moved
        # to the GPU below when one is requested.
        model.load_state_dict(
            torch.load(args.pretrained_model,
                       map_location=torch.device("cpu")))
    if args.gpu >= 0:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        min_lr=args.lr_min,
        verbose=True,
    )

    train_filelist, val_filelist = train_val_split(
        mix_dir=args.mixtures,
        inst_dir=args.instruments,
        val_rate=args.val_rate,
        val_filelist_json=args.val_filelist,
    )

    if args.debug:
        # Debug mode keeps a single pair per list.
        print("### DEBUG MODE")
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]

    with open("val_{}.json".format(timestamp), "w", encoding="utf8") as f:
        json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        print(i + 1, os.path.basename(X_fname), os.path.basename(y_fname))

    val_dataset = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        offset=model.offset,
    )
    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.val_batchsize,
                                                 shuffle=False,
                                                 num_workers=4)

    # Create the checkpoint directory up front so torch.save cannot fail on
    # a missing folder after an epoch has already been trained.
    os.makedirs("models", exist_ok=True)

    log = []
    oracle_X = None
    oracle_y = None
    best_loss = np.inf
    for epoch in range(args.epoch):
        # A new training set is built for every outer epoch.
        X_train, y_train = dataset.make_training_set(
            train_filelist,
            args.cropsize,
            args.patches,
            args.sr,
            args.hop_length,
            model.offset,
        )

        X_train, y_train = dataset.mixup_generator(X_train, y_train,
                                                   args.mixup_rate,
                                                   args.mixup_alpha)

        if oracle_X is not None and oracle_y is not None:
            # Overwrite randomly chosen entries with the samples kept from
            # the previous outer epoch.
            perm = np.random.permutation(len(X_train))[:len(oracle_X)]
            X_train[perm] = oracle_X
            y_train[perm] = oracle_y

        print("# epoch", epoch)
        for inner_epoch in range(args.inner_epoch):
            print("  * inner epoch {}".format(inner_epoch))
            train_loss, instance_loss = train_inner_epoch(
                X_train, y_train, model, optimizer, args.batchsize)
            val_loss = val_inner_epoch(val_dataloader, model)

            # Losses are printed multiplied by 1000; the raw values are what
            # get logged and compared.
            print("    * training loss = {:.6f}, validation loss = {:.6f}".
                  format(train_loss * 1000, val_loss * 1000))

            # The plateau scheduler is driven by the validation loss.
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                print("    * best validation loss")
                # The file name carries only the outer epoch, so a later
                # improvement within the same epoch overwrites it.
                model_path = "models/model_iter{}.pth".format(epoch)
                torch.save(model.state_dict(), model_path)

            # Rewrite the whole log each time so it survives an interruption.
            log.append([train_loss, val_loss])
            with open("log_{}.json".format(timestamp), "w",
                      encoding="utf8") as f:
                json.dump(log, f, ensure_ascii=False)

        if args.oracle_rate > 0:
            # Uses the per-instance losses of the last inner epoch.
            oracle_X, oracle_y, idx = dataset.get_oracle_data(
                X_train, y_train, instance_loss, args.oracle_rate,
                args.oracle_drop_rate)
            print("  * oracle loss = {:.6f}".format(instance_loss[idx].mean() *
                                                    1000))

        # Drop the epoch's training arrays before the next set is built.
        del X_train, y_train
        gc.collect()