Beispiel #1
0
def main(args):
    """Run a pretrained Wave-U-Net on one audio file and write one WAV per source.

    The network is built from the hyperparameters in ``args``. When
    ``args.cuda`` is set it is wrapped in ``utils.DataParallel`` and moved to
    the GPU. Weights are restored from ``args.load_model``.

    Every entry returned by ``predict_song`` is written to
    ``<output folder>/<input basename>_<instrument>.wav``. The output folder
    is ``args.output``, or the input file's directory when that is None.
    """
    # MODEL
    if args.feature_growth == "add":
        # Linear growth: features, 2*features, ..., levels*features
        num_features = [args.features * level for level in range(1, args.levels + 1)]
    else:
        # Geometric growth: features * 2**level for every level
        num_features = [args.features * 2 ** level for level in range(args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(
        args.channels,
        num_features,
        args.channels,
        args.instruments,
        kernel_size=args.kernel_size,
        target_output_size=target_outputs,
        depth=args.depth,
        strides=args.strides,
        conv_type=args.conv_type,
        res=args.res,
        separate=args.separate,
    )

    if args.cuda:
        model = utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = utils.load_model(model, None, args.load_model)
    print('Step', state['step'])

    preds = predict_song(args, args.input, model)

    # Default to writing next to the input file.
    if args.output is None:
        output_folder = os.path.dirname(args.input)
    else:
        output_folder = args.output
    song_name = os.path.basename(args.input)
    for inst in preds.keys():
        wav_path = os.path.join(output_folder, song_name + "_" + inst + ".wav")
        utils.write_wav(wav_path, preds[inst], args.sr)
Beispiel #2
0
def main(args):
    """Embed three songs with a Wave-U-Net checkpoint and save a 2-D t-SNE scatter plot.

    The model is rebuilt from ``args`` and its weights are loaded from
    ``args.load_model``. ``tsne_predict_song`` is then run on
    ``args.input1/2/3`` for the instrument ``args.tsne_inst``. The outputs are
    stacked along axis 0 and projected with ``TSNE(learning_rate=100)``. The
    result is drawn as a scatter plot coloured red / green / blue per song and
    saved to ``args.output``.
    """
    # MODEL
    if args.feature_growth == "add":
        num_features = [args.features * level for level in range(1, args.levels + 1)]
    else:
        num_features = [args.features * 2 ** level for level in range(args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(
        args.channels,
        num_features,
        args.channels,
        args.instruments,
        kernel_size=args.kernel_size,
        target_output_size=target_outputs,
        depth=args.depth,
        strides=args.strides,
        conv_type=args.conv_type,
        res=args.res,
        separate=args.separate,
    )

    if args.cuda:
        model = utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print("Loading model from checkpoint " + str(args.load_model))
    state = utils.load_model(model, None, args.load_model, args.cuda)
    print('Step', state['step'])

    # One output per song, in input order; squeeze(0) drops the size-1
    # leading axis (presumably the batch axis -- TODO confirm).
    output1, output2, output3 = [
        tsne_predict_song(args, song_path, model, args.tsne_inst).squeeze(0)
        for song_path in (args.input1, args.input2, args.input3)
    ]

    output = np.concatenate((output1, output2, output3), axis=0)

    print("debug1: " + str(output1.shape) + ", " + str(output.shape))

    tsne_output = TSNE(learning_rate=100).fit_transform(output)

    # One colour per song: every row coming from a song gets that song's colour.
    labels = (['#ff0000'] * len(output1)
              + ['#00ff00'] * len(output2)
              + ['#0000ff'] * len(output3))
    plt.scatter(tsne_output[:, 0], tsne_output[:, 1], c=labels, s=5)
    plt.savefig(args.output)
Beispiel #3
0
def main(args):
    """Train a Wave-U-Net separator on MUSDB with early stopping, then test it.

    Training runs until the validation loss has failed to improve for
    ``args.patience`` consecutive epochs. A checkpoint is written after every
    epoch. The best one (lowest validation loss) is reloaded for the final
    test loss and the mir_eval metrics. The metrics are pickled to
    ``<args.checkpoint_dir>/results.pkl`` and the SDR/SIR averages are logged
    to TensorBoard.

    Args:
        args: parsed command-line namespace holding the model, data,
            optimiser and logging hyperparameters.

    Raises:
        NotImplementedError: if ``args.loss`` is neither "L1" nor "L2".
    """
    # NOTE: torch.backends.cudnn.benchmark = True can make dilated convolutions
    # much faster on CuDNN 7.5; it is left disabled here.

    # MODEL
    if args.feature_growth == "add":
        num_features = [args.features * i for i in range(1, args.levels + 1)]
    else:
        num_features = [args.features * 2 ** i for i in range(args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate)

    if args.cuda:
        model = utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))

    writer = SummaryWriter(args.log_dir)

    ### DATASET
    musdb = get_musdb_folds(args.dataset_dir)
    # Without data augmentation, at least crop targets to fit the model output shape
    crop_func = partial(crop, shapes=model.shapes)
    # Data augmentation for training (random_amplify with min=0.7, max=1.0)
    augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels, model.shapes, True, args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, model.shapes, False, args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=utils.worker_init_fn)

    ##### TRAINING ####

    # Set up the loss function
    if args.loss == "L1":
        criterion = nn.L1Loss()
    elif args.loss == "L2":
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError("Couldn't find this loss!")

    # Set up optimiser
    optimizer = Adam(params=model.parameters(), lr=args.lr)

    # Training state dict that is also saved into checkpoints.
    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
    state = {"step": 0,
             "worse_epochs": 0,
             "epochs": 0,
             "best_loss": np.inf}

    # LOAD MODEL CHECKPOINT IF DESIRED
    if args.load_model is not None:
        print("Continuing training full model from checkpoint " + str(args.load_model))
        state = utils.load_model(model, optimizer, args.load_model)

    # Period (in iterations) passed to the cyclic learning-rate schedule.
    lr_cycle_len = len(train_data) // args.batch_size

    print('TRAINING START')
    while state["worse_epochs"] < args.patience:
        print("Training one epoch from iteration " + str(state["step"]))
        model.train()
        # len(dataloader) also counts the final partial batch, so the bar
        # ends exactly when the epoch does.
        with tqdm(total=len(dataloader)) as pbar:
            # Reseed NumPy's global RNG from fresh OS entropy every epoch.
            np.random.seed()
            for example_num, (x, targets) in enumerate(dataloader):
                if args.cuda:
                    x = x.cuda()
                    for k in list(targets.keys()):
                        targets[k] = targets[k].cuda()

                # Set LR for this iteration
                utils.set_cyclic_lr(optimizer, example_num, lr_cycle_len, args.cycles, args.min_lr, args.lr)
                writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])

                # Compute loss for each instrument/model (gradients computed inside)
                optimizer.zero_grad()
                outputs, avg_loss = utils.compute_loss(model, x, targets, criterion, compute_grad=True)
                optimizer.step()

                state["step"] += 1
                writer.add_scalar("train_loss", avg_loss, state["step"])

                if example_num % args.example_freq == 0:
                    input_centre = torch.mean(x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]], 0)  # Stereo not supported for logs yet
                    writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)

                    for inst in outputs.keys():
                        writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0), state["step"], sample_rate=args.sr)
                        writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0), state["step"], sample_rate=args.sr)

                pbar.update(1)

        # VALIDATE
        val_loss = validate(args, model, criterion, val_data)
        print("VALIDATION FINISHED: LOSS: " + str(val_loss))
        writer.add_scalar("val_loss", val_loss, state["step"])

        # EARLY STOPPING CHECK
        checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
        # Testing with "<" makes a NaN loss count as a worse epoch. Otherwise
        # NaN would become the best loss, no later epoch could ever be "worse",
        # and the patience loop would never terminate.
        if val_loss < state["best_loss"]:
            print("MODEL IMPROVED ON VALIDATION SET!")
            state["worse_epochs"] = 0
            state["best_loss"] = val_loss
            state["best_checkpoint"] = checkpoint_path
        else:
            state["worse_epochs"] += 1

        # CHECKPOINT
        print("Saving model...")
        utils.save_model(model, optimizer, state, checkpoint_path)

        state["epochs"] += 1

    #### TESTING ####
    # Test loss
    print("TESTING")

    # Load the best model based on validation loss. The key is absent if no
    # epoch ever improved (e.g. patience <= 0, or a resumed state without it).
    best_checkpoint = state.get("best_checkpoint")
    if best_checkpoint is not None:
        state = utils.load_model(model, None, best_checkpoint)
    else:
        print("No best checkpoint recorded, testing the current model")
    test_loss = validate(args, model, criterion, test_data)
    print("TEST FINISHED: LOSS: " + str(test_loss))
    writer.add_scalar("test_loss", test_loss, state["step"])

    # Mir_eval metrics
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics results into pickle file for later analysis if needed
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    def mean_metric(metric):
        """Per instrument: mean over songs of each song's NaN-ignoring mean of ``metric``."""
        return {inst: np.mean([np.nanmean(song[inst][metric]) for song in test_metrics])
                for inst in args.instruments}

    # Write most important metrics into Tensorboard log
    avg_SDRs = mean_metric("SDR")
    avg_SIRs = mean_metric("SIR")
    for inst in args.instruments:
        writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
        writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
    overall_SDR = np.mean(list(avg_SDRs.values()))
    # Logged at the same global step as every other test scalar.
    writer.add_scalar("test_SDR", overall_SDR, state["step"])
    print("SDR: " + str(overall_SDR))

    writer.close()
Beispiel #4
0
def main(args):
    """Train (or only test) a Wave-U-Net student, optionally distilled from a teacher.

    Steps visible in this function:

    * Build the student ``Waveunet`` and its Adam optimiser. Optionally wrap
      the student in ``utils.DataParallel`` and move it to the GPU. Resume
      weights, optimiser and training state from ``args.load_model`` if given.
    * Unless ``args.test`` is set: dump ``args`` to ``<log_dir>/config.json``.
      If ``args.teacher_model`` is set, also build a 24-base-filter teacher
      and load its weights.
    * Then train until ``state["epochs"]`` reaches 100. After every epoch,
      validate, log to TensorBoard and checkpoint, tracking the best
      validation PESQ (``val_metrics[0]``).
    * Finally reload the best checkpoint (if one is recorded) and evaluate on
      ``dataset["test"]``. Save the pesq / stoi / SISDR / noise results under
      ``result_dir``.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if test mode is requested but ``args.load_model`` is None.
        NotImplementedError: when training and ``args.loss`` is neither
            "L1" nor "L2".
    """
    os.environ['KMP_WARNINGS'] = '0'
    torch.cuda.manual_seed_all(1)
    np.random.seed(0)
    print(args.model_name)
    print(args.alpha)

    # Test mode names its result folder after the checkpoint path, so it needs one.
    if args.test is not False and args.load_model is None:
        raise ValueError("test mode requires args.load_model to point to a checkpoint")

    # Filters per level: features, 2*features, 3*features, ...
    num_features = [
        args.features * i
        for i in range(1, args.levels + 2 + args.levels_without_sample)
    ]

    # Determine the output size (in samples)
    target_outputs = int(args.output_size * args.sr)

    def build_waveunet(features):
        """Return a Waveunet with per-level filter counts ``features`` and the shared args."""
        return Waveunet(args.channels,
                        features,
                        args.channels,
                        levels=args.levels,
                        encoder_kernel_size=args.encoder_kernel_size,
                        decoder_kernel_size=args.decoder_kernel_size,
                        target_output_size=target_outputs,
                        depth=args.depth,
                        strides=args.strides,
                        conv_type=args.conv_type,
                        res=args.res)

    # Student model; the optimiser is created before the DataParallel wrap.
    student_KD = build_waveunet(num_features)
    KD_optimizer = Adam(params=student_KD.parameters(), lr=args.lr)
    print(25 * '=' + 'model setting' + 25 * '=')
    print('student_KD: ', student_KD.shapes)
    if args.cuda:
        student_KD = utils.DataParallel(student_KD)
        print("move student_KD to gpu\n")
        student_KD.cuda()

    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
    state = {"step": 0, "worse_epochs": 0, "epochs": 0, "best_pesq": -np.inf}
    if args.load_model is not None:
        print("Continuing full model from checkpoint " + str(args.load_model))
        state = utils.load_model(student_KD, KD_optimizer, args.load_model,
                                 args.cuda)
    dataset = get_folds(args.dataset_dir, args.outside_test)
    log_dir, checkpoint_dir, result_dir = utils.mkdir_and_get_path(args)

    if args.test is False:
        writer = SummaryWriter(log_dir)
        print(25 * '=' + 'printing hypeparameters info' + 25 * '=')

        # The model configuration is saved only when training.
        with open(os.path.join(log_dir, 'config.json'), 'w') as f:
            json.dump(args.__dict__, f, indent=5)
        print('saving commandline_args')
        student_size = sum(p.numel() for p in student_KD.parameters())
        print('student_parameter count: ', str(student_size))

        teacher_model = None
        if args.teacher_model is not None:
            # Same architecture as the student, but with a fixed base width of 24 filters.
            teacher_num_features = [
                24 * i
                for i in range(1, args.levels + 2 + args.levels_without_sample)
            ]
            teacher_model = build_waveunet(teacher_num_features)
            if args.cuda:
                teacher_model = utils.DataParallel(teacher_model)
                teacher_model.cuda()
            teacher_size = sum(p.numel() for p in teacher_model.parameters())
            print('teacher_model_parameter count: ', str(teacher_size))
            print(f'compression raito :{100*(student_size/teacher_size)}%')
            print("load teacher model" + str(args.teacher_model))
            _ = utils.load_model(teacher_model, None, args.teacher_model,
                                 args.cuda)
            teacher_model.eval()

        # If not data augmentation, at least crop targets to fit model output shape
        crop_func = partial(crop, shapes=student_KD.shapes)
        ### DATASET
        train_data = SeparationDataset(dataset,
                                       "train",
                                       args.sr,
                                       args.channels,
                                       student_KD.shapes,
                                       False,
                                       args.hdf_dir,
                                       audio_transform=crop_func)
        # NOTE(review): validation is run on the "test" split.
        val_data = SeparationDataset(dataset,
                                     "test",
                                     args.sr,
                                     args.channels,
                                     student_KD.shapes,
                                     False,
                                     args.hdf_dir,
                                     audio_transform=crop_func)
        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=utils.worker_init_fn,
            pin_memory=True)

        # Set up the loss function
        if args.loss == "L1":
            criterion = nn.L1Loss()
        elif args.loss == "L2":
            criterion = nn.MSELoss()
        else:
            raise NotImplementedError("Couldn't find this loss!")
        My_criterion = customLoss()

        ### TRAINING START
        print('TRAINING START')
        # Period (in iterations) passed to the cyclic learning-rate schedule.
        lr_cycle_len = len(train_data) // args.batch_size
        # Number of batches actually yielded per epoch, used to average the loss.
        # len(train_data) // batch_size would miss the last partial batch and
        # be 0 for a dataset smaller than one batch.
        batch_num = len(dataloader)
        while state["epochs"] < 100:
            # Never appended to in this variant; still dumped every epoch so the
            # alpha_<epoch> files keep being written.
            memory_alpha = []
            print("epoch:" + str(state["epochs"]))
            student_KD.train()
            # Mean (over batches) of the student loss measured before each update.
            avg_origin_loss = 0
            with tqdm(total=len(dataloader)) as pbar:
                for example_num, (x, targets) in enumerate(dataloader):
                    if args.cuda:
                        x = x.cuda()
                        targets = targets.cuda()
                    # Set LR for this iteration
                    utils.set_cyclic_lr(KD_optimizer, example_num, lr_cycle_len,
                                        args.cycles, args.min_lr, args.lr)
                    if teacher_model is not None:
                        # Student loss without gradients, used only for monitoring.
                        _, avg_student_KD_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            criterion,
                            compute_grad=False)

                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss, KD_loss, KD_soft_loss = utils.KD_compute_loss(
                            student_KD,
                            teacher_model,
                            x,
                            targets,
                            My_criterion,
                            alpha=args.alpha,
                            compute_grad=True,
                            KD_method=args.KD_method)
                        KD_optimizer.step()

                        avg_origin_loss += avg_student_KD_loss / batch_num

                        # add to tensorboard
                        writer.add_scalar("KD_loss", KD_loss, state["step"])
                        writer.add_scalar("KD_hard_loss", KD_hard_loss,
                                          state["step"])
                        writer.add_scalar("KD_soft_loss", KD_soft_loss,
                                          state["step"])
                    else:  # no KD training
                        KD_optimizer.zero_grad()
                        # NOTE(review): always MSE here, independent of args.loss.
                        KD_outputs, KD_hard_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            nn.MSELoss(),
                            compute_grad=True)
                        KD_optimizer.step()
                        avg_origin_loss += KD_hard_loss / batch_num
                        writer.add_scalar("student_KD_loss", KD_hard_loss,
                                          state["step"])

                    ### save wav ####
                    if example_num % args.example_freq == 0:
                        input_centre = torch.mean(
                            x[0, :, student_KD.shapes["output_start_frame"]:
                              student_KD.shapes["output_end_frame"]],
                            0)  # Stereo not supported for logs yet

                        writer.add_audio("input:",
                                         input_centre,
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("pred:",
                                         torch.mean(KD_outputs[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("target",
                                         torch.mean(targets[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)

                    state["step"] += 1
                    pbar.update(1)
            # VALIDATE
            val_loss, val_metrics = validate(args, student_KD, criterion,
                                             val_data)
            print("ori VALIDATION FINISHED: LOSS: " + str(val_loss))

            writer.add_scalar("avg_origin_loss", avg_origin_loss,
                              state["epochs"])
            # val_metrics[0..5], in this order.
            val_metric_names = ("val_enhance_pesq", "val_improve_pesq",
                                "val_enhance_stoi", "val_improve_stoi",
                                "val_enhance_SISDR", "val_improve_SISDR")
            for index, name in enumerate(val_metric_names):
                writer.add_scalar(name, val_metrics[index], state["epochs"])
            writer.add_scalar("val_loss", val_loss, state["epochs"])

            checkpoint_path = os.path.join(
                checkpoint_dir, "checkpoint_" + str(state["epochs"]))
            # Ties still count as an improvement. Using ">=" means a NaN PESQ
            # is never recorded as the best score.
            if val_metrics[0] >= state["best_pesq"]:
                print("MODEL IMPROVED ON VALIDATION SET!")
                state["worse_epochs"] = 0
                state["best_pesq"] = val_metrics[0]
                state["best_checkpoint"] = checkpoint_path
            else:
                state["worse_epochs"] += 1

            # CHECKPOINT
            print("Saving model...")
            utils.save_model(student_KD, KD_optimizer, state, checkpoint_path)
            print('dump alpha_memory')
            with open(os.path.join(log_dir, 'alpha_' + str(state["epochs"])),
                      "wb") as fp:  # Pickling
                pickle.dump(memory_alpha, fp)
            state["epochs"] += 1
        writer.close()
        info = args.model_name
        path = os.path.join(result_dir, info)
    else:
        # Result name: third-from-last path component + "_" + checkpoint file name.
        path_parts = args.load_model.split("/")
        info = path_parts[-3] + "_" + path_parts[-1]
        if args.outside_test:
            info += "_outside_test"
        print(info)
        path = os.path.join(result_dir, info)

    #### TESTING ####
    print("TESTING")
    # The key is absent if no epoch was ever recorded as best (or the resumed
    # state lacks it); then the weights already in memory are evaluated.
    best_checkpoint = state.get("best_checkpoint")
    if best_checkpoint is not None:
        _ = utils.load_model(student_KD, KD_optimizer, best_checkpoint,
                             args.cuda)
    else:
        print("No best checkpoint recorded, evaluating the current weights")
    test_metrics = evaluate(args, dataset["test"], student_KD)
    results = {name: test_metrics[name]
               for name in ("pesq", "stoi", "SISDR", "noise")}

    os.makedirs(path, exist_ok=True)
    for name, values in results.items():
        utils.save_result(values, path, name)
               [args.features*2**i for i in range(0, args.levels)]
# Script-level inference: rebuild the Wave-U-Net, load a checkpoint, separate
# args.input and write one WAV per predicted source.
# NOTE(review): `args`, `INSTRUMENTS` and `num_features` are defined outside
# this visible fragment.

# Model output length in samples (output_size presumably in seconds -- confirm).
target_outputs = int(args.output_size * args.sr)
model = Waveunet(args.channels,
                 num_features,
                 args.channels,
                 INSTRUMENTS,
                 kernel_size=args.kernel_size,
                 target_output_size=target_outputs,
                 depth=args.depth,
                 strides=args.strides,
                 conv_type=args.conv_type,
                 res=args.res,
                 separate=args.separate)

# Optionally wrap the model in utils.DataParallel and move it to the GPU.
if args.cuda:
    model = utils.DataParallel(model)
    print("move model to gpu")
    model.cuda()

print("Loading model from checkpoint " + str(args.load_model))
# No optimiser is passed; the returned state is not used further here.
state = utils.load_model(model, None, args.load_model)

preds = predict_song(args, args.input, model)

# Write next to the input file unless an explicit output folder is given.
output_folder = os.path.dirname(
    args.input) if args.output is None else args.output
# Keys of preds are used as strings in the file name: "<input basename>_<key>.wav".
for inst in preds.keys():
    utils.write_wav(
        os.path.join(output_folder,
                     os.path.basename(args.input) + "_" + inst + ".wav"),
        preds[inst], args.sr)
Beispiel #6
0
def main(args):
    os.environ['KMP_WARNINGS'] = '0'
    torch.cuda.manual_seed_all(1)
    np.random.seed(0)

    # filter array
    num_features = [
        args.features * i
        for i in range(1, args.levels + 2 + args.levels_without_sample)
    ]

    # 確定 輸出大小
    target_outputs = int(args.output_size * args.sr)
    # 訓練才保存模型設定參數

    # 設定teacher and student and student_for_backward 超參數

    student_KD = Waveunet(args.channels,
                          num_features,
                          args.channels,
                          levels=args.levels,
                          encoder_kernel_size=args.encoder_kernel_size,
                          decoder_kernel_size=args.decoder_kernel_size,
                          target_output_size=target_outputs,
                          depth=args.depth,
                          strides=args.strides,
                          conv_type=args.conv_type,
                          res=args.res)
    KD_optimizer = Adam(params=student_KD.parameters(), lr=args.lr)
    print(25 * '=' + 'model setting' + 25 * '=')
    print('student_KD: ', student_KD.shapes)
    if args.cuda:
        student_KD = utils.DataParallel(student_KD)
        print("move student_KD to gpu\n")
        student_KD.cuda()

    state = {"step": 0, "worse_epochs": 0, "epochs": 0, "best_pesq": -np.Inf}
    if args.load_model is not None:
        print("Continuing full model from checkpoint " + str(args.load_model))
        state = utils.load_model(student_KD, KD_optimizer, args.load_model,
                                 args.cuda)
    dataset = get_folds(args.dataset_dir, args.outside_test)
    log_dir, checkpoint_dir, result_dir = utils.mkdir_and_get_path(args)
    # print(model)
    if args.test is False:
        writer = SummaryWriter(log_dir)
        # set hypeparameter
        # printing hypeparameters info

        with open(os.path.join(log_dir, 'config.json'), 'w') as f:
            json.dump(args.__dict__, f, indent=5)
        print('saving commandline_args')

        if args.teacher_model is not None:
            print(25 * '=' + 'printing hypeparameters info' + 25 * '=')
            print(f'KD_method  = {args.KD_method}')
            teacher_num_features = [
                24 * i
                for i in range(1, args.levels + 2 + args.levels_without_sample)
            ]
            teacher_model = Waveunet(
                args.channels,
                teacher_num_features,
                args.channels,
                levels=args.levels,
                encoder_kernel_size=args.encoder_kernel_size,
                decoder_kernel_size=args.decoder_kernel_size,
                target_output_size=target_outputs,
                depth=args.depth,
                strides=args.strides,
                conv_type=args.conv_type,
                res=args.res)

            student_copy = Waveunet(
                args.channels,
                num_features,
                args.channels,
                levels=args.levels,
                encoder_kernel_size=args.encoder_kernel_size,
                decoder_kernel_size=args.decoder_kernel_size,
                target_output_size=target_outputs,
                depth=args.depth,
                strides=args.strides,
                conv_type=args.conv_type,
                res=args.res)
            copy_optimizer = Adam(params=student_copy.parameters(), lr=args.lr)

            student_copy2 = Waveunet(
                args.channels,
                num_features,
                args.channels,
                levels=args.levels,
                encoder_kernel_size=args.encoder_kernel_size,
                decoder_kernel_size=args.decoder_kernel_size,
                target_output_size=target_outputs,
                depth=args.depth,
                strides=args.strides,
                conv_type=args.conv_type,
                res=args.res)
            copy2_optimizer = Adam(params=student_copy2.parameters(),
                                   lr=args.lr)

            policy_network = RL(n_inputs=2,
                                kernel_size=6,
                                stride=1,
                                conv_type=args.conv_type,
                                pool_size=4)
            PG_optimizer = Adam(params=policy_network.parameters(),
                                lr=args.RL_lr)
            if args.cuda:
                teacher_model = utils.DataParallel(teacher_model)
                policy_network = utils.DataParallel(policy_network)
                student_copy = utils.DataParallel(student_copy)
                student_copy2 = utils.DataParallel(student_copy2)
                # print("move teacher to gpu\n")
                teacher_model.cuda()
                # print("student_copy  to gpu\n")
                student_copy.cuda()
                # print("student_copy2  to gpu\n")
                student_copy2.cuda()
                # print("move policy_network to gpu\n")
                policy_network.cuda()
            student_size = sum(p.numel() for p in student_KD.parameters())
            teacher_size = sum(p.numel() for p in teacher_model.parameters())
            print('student_parameter count: ', str(student_size))
            print('teacher_model_parameter count: ', str(teacher_size))
            print('RL_parameter count: ',
                  str(sum(p.numel() for p in policy_network.parameters())))
            print(f'compression raito :{100*(student_size/teacher_size)}%')
            if args.teacher_model is not None:
                print("load teacher model" + str(args.teacher_model))
                _ = utils.load_model(teacher_model, None, args.teacher_model,
                                     args.cuda)
                teacher_model.eval()

            if args.load_RL_model is not None:
                print("Continuing full RL_model from checkpoint " +
                      str(args.load_RL_model))
                _ = utils.load_model(policy_network, PG_optimizer,
                                     args.load_RL_model, args.cuda)

        # If not data augmentation, at least crop targets to fit model output shape
        crop_func = partial(crop, shapes=student_KD.shapes)
        ### DATASET
        train_data = SeparationDataset(dataset,
                                       "train",
                                       args.sr,
                                       args.channels,
                                       student_KD.shapes,
                                       False,
                                       args.hdf_dir,
                                       audio_transform=crop_func)
        val_data = SeparationDataset(dataset,
                                     "test",
                                     args.sr,
                                     args.channels,
                                     student_KD.shapes,
                                     False,
                                     args.hdf_dir,
                                     audio_transform=crop_func)
        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=utils.worker_init_fn,
            pin_memory=True)

        # Set up the loss function
        if args.loss == "L1":
            criterion = nn.L1Loss()
        elif args.loss == "L2":
            criterion = nn.MSELoss()
        else:
            raise NotImplementedError("Couldn't find this loss!")
        My_criterion = customLoss()

        ### TRAINING START
        print('TRAINING START')
        if state["epochs"] > 0:
            state["epochs"] = state["epochs"] + 1
        batch_num = (len(train_data) // args.batch_size)

        if args.teacher_model is not None:
            counting = 0
            PG_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer=PG_optimizer, gamma=args.decayRate)

            while counting < state["epochs"]:
                PG_optimizer.zero_grad()
                PG_optimizer.step()
                counting += 1
                PG_lr_scheduler.step()
            # print(f'modify lr RL rate : {counting} , until : {state["epochs"]}')
        while state["epochs"] < 100:
            memory_alpha = []
            print("epoch:" + str(state["epochs"]))

            # monitor_value
            total_avg_reward = 0
            total_avg_scalar_reward = 0
            avg_origin_loss = 0
            all_avg_KD_rate = 0
            same = 0
            with tqdm(total=len(dataloader)) as pbar:
                for example_num, (x, targets) in enumerate(dataloader):
                    # if example_num==20:
                    #     break
                    student_KD.train()

                    if args.cuda:
                        x = x.cuda()
                        targets = targets.cuda()
                    if args.teacher_model is not None:
                        student_copy.train()
                        student_copy2.train()
                        # Set LR for this iteration
                        temp = {'state_dict': None, 'optim_dict': None}

                        temp['state_dict'] = copy.deepcopy(
                            student_KD.state_dict())
                        temp['optim_dict'] = copy.deepcopy(
                            KD_optimizer.state_dict())
                        #print('base_model from KD')

                        student_KD.load_state_dict(temp['state_dict'])
                        KD_optimizer.load_state_dict(temp['optim_dict'])

                        student_copy.load_state_dict(temp['state_dict'])
                        copy_optimizer.load_state_dict(temp['optim_dict'])

                        student_copy2.load_state_dict(temp['state_dict'])
                        copy2_optimizer.load_state_dict(temp['optim_dict'])

                        utils.set_cyclic_lr(KD_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        utils.set_cyclic_lr(copy_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        utils.set_cyclic_lr(copy2_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        # forward student and teacher  get output
                        student_KD_output, avg_student_KD_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            criterion,
                            compute_grad=False)
                        teacher_output, _ = utils.compute_loss(
                            teacher_model,
                            x,
                            targets,
                            criterion,
                            compute_grad=False)
                        # PG_state
                        diff_from_target = targets.detach(
                        ) - student_KD_output.detach()
                        diff_from_teacher = teacher_output.detach(
                        ) - student_KD_output.detach()
                        PG_state = torch.cat(
                            (diff_from_target, diff_from_teacher), 1)

                        # forward RL get alpha
                        alpha = policy_network(PG_state)
                        nograd_alpha = alpha.detach()

                        avg_KD_rate = torch.mean(nograd_alpha).item()
                        all_avg_KD_rate += avg_KD_rate / batch_num

                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss, KD_loss, KD_soft_loss = utils.KD_compute_loss(
                            student_KD,
                            teacher_model,
                            x,
                            targets,
                            My_criterion,
                            alpha=nograd_alpha,
                            compute_grad=True,
                            KD_method=args.KD_method)
                        KD_optimizer.step()

                        copy_optimizer.zero_grad()
                        _, _, _, _ = utils.KD_compute_loss(
                            student_copy,
                            teacher_model,
                            x,
                            targets,
                            My_criterion,
                            alpha=1,
                            compute_grad=True,
                            KD_method=args.KD_method)
                        copy_optimizer.step()

                        copy2_optimizer.zero_grad()
                        _, _, _, _ = utils.KD_compute_loss(
                            student_copy2,
                            teacher_model,
                            x,
                            targets,
                            My_criterion,
                            alpha=0,
                            compute_grad=True,
                            KD_method=args.KD_method)
                        copy2_optimizer.step()

                        # calculate backwarded model MSE
                        backward_KD_loss = utils.loss_for_sample(
                            student_KD, x, targets)
                        backward_copy_loss = utils.loss_for_sample(
                            student_copy, x, targets)
                        backward_copy2_loss = utils.loss_for_sample(
                            student_copy2, x, targets)

                        # calculate rewards
                        rewards, same_num, before_decay = utils.get_rewards(
                            backward_KD_loss.detach(),
                            backward_copy_loss.detach(),
                            backward_copy2_loss.detach(),
                            backward_KD_loss.detach(), len(train_data),
                            state["epochs"] + 1)
                        same += same_num
                        rewards = rewards.detach()
                        avg_origin_loss += avg_student_KD_loss / batch_num

                        # avg_reward
                        avg_reward = torch.mean(rewards)
                        avg_scalar_reward = torch.mean(torch.abs(rewards))
                        total_avg_reward += avg_reward.item() / batch_num
                        total_avg_scalar_reward += avg_scalar_reward.item(
                        ) / batch_num
                        # append to memory_alpha
                        nograd_alpha = nograd_alpha.detach().cpu()
                        memory_alpha.append(nograd_alpha.numpy())

                        PG_optimizer.zero_grad()
                        _ = utils.RL_compute_loss(alpha, rewards, nn.MSELoss())
                        PG_optimizer.step()
                        # print info
                        # print(f'avg_KD_rate                 = {avg_KD_rate} ')
                        # print(f'student_KD_loss             = {avg_student_KD_loss}')
                        # print(f'backward_student_copy_loss  = {np.mean(backward_copy_loss.detach().cpu().numpy())}')
                        # print(f'backward_student_KD_loss    = {np.mean(backward_KD_loss.detach().cpu().numpy())}')
                        # print(f'backward_student_copy2_loss = {np.mean(backward_copy2_loss.detach().cpu().numpy())}')
                        # print(f'avg_reward                  = {avg_reward}')
                        # print(f'total_avg_reward            = {total_avg_reward}')
                        # print(f'same                        = {same}')

                        # add to tensorboard
                        writer.add_scalar("student_KD_loss",
                                          avg_student_KD_loss, state["step"])
                        writer.add_scalar(
                            "backward_student_KD_loss",
                            np.mean(backward_KD_loss.detach().cpu().numpy()),
                            state["step"])
                        writer.add_scalar("KD_loss", KD_loss, state["step"])
                        writer.add_scalar("KD_hard_loss", KD_hard_loss,
                                          state["step"])
                        writer.add_scalar("KD_soft_loss", KD_soft_loss,
                                          state["step"])
                        writer.add_scalar("avg_KD_rate", avg_KD_rate,
                                          state["step"])
                        writer.add_scalar("rewards", avg_reward, state["step"])
                        writer.add_scalar("scalar_rewards", avg_scalar_reward,
                                          state["step"])
                        writer.add_scalar("before_decay", before_decay,
                                          state["step"])
                    else:  # no KD training
                        utils.set_cyclic_lr(KD_optimizer, example_num,
                                            len(train_data) // args.batch_size,
                                            args.cycles, args.min_lr, args.lr)
                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss = utils.compute_loss(
                            student_KD,
                            x,
                            targets,
                            nn.MSELoss(),
                            compute_grad=True)
                        KD_optimizer.step()
                        avg_origin_loss += KD_hard_loss / batch_num
                        writer.add_scalar("student_KD_loss", KD_hard_loss,
                                          state["step"])

                    ### save wav ####
                    if example_num % args.example_freq == 0:
                        input_centre = torch.mean(
                            x[0, :, student_KD.shapes["output_start_frame"]:
                              student_KD.shapes["output_end_frame"]],
                            0)  # Stereo not supported for logs yet

                        # target=torch.mean(targets[0], 0).cpu().numpy()
                        # pred=torch.mean(KD_outputs[0], 0).detach().cpu().numpy()
                        # inputs=input_centre.cpu().numpy()

                        writer.add_audio("input:",
                                         input_centre,
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("pred:",
                                         torch.mean(KD_outputs[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("target",
                                         torch.mean(targets[0], 0),
                                         state["step"],
                                         sample_rate=args.sr)

                    state["step"] += 1
                    pbar.update(1)
            # VALIDATE
            val_loss, val_metrics = validate(args, student_KD, criterion,
                                             val_data)
            print("ori VALIDATION FINISHED: LOSS: " + str(val_loss))

            choose_val = val_metrics
            if args.teacher_model is not None:
                for i in range(len(nograd_alpha)):
                    writer.add_scalar("KD_rate_" + str(i), nograd_alpha[i],
                                      state["epochs"])
                print(f'all_avg_KD_rate = {all_avg_KD_rate}')
                writer.add_scalar("all_avg_KD_rate", all_avg_KD_rate,
                                  state["epochs"])
                # writer.add_scalar("val_loss_copy", val_loss_copy, state["epochs"])
                writer.add_scalar("total_avg_reward", total_avg_reward,
                                  state["epochs"])
                writer.add_scalar("total_avg_scalar_reward",
                                  total_avg_scalar_reward, state["epochs"])

                RL_checkpoint_path = os.path.join(
                    checkpoint_dir, "RL_checkpoint_" + str(state["epochs"]))
                utils.save_model(policy_network, PG_optimizer, state,
                                 RL_checkpoint_path)
                PG_lr_scheduler.step()

            writer.add_scalar("same", same, state["epochs"])
            writer.add_scalar("avg_origin_loss", avg_origin_loss,
                              state["epochs"])
            writer.add_scalar("val_enhance_pesq", choose_val[0],
                              state["epochs"])
            writer.add_scalar("val_improve_pesq", choose_val[1],
                              state["epochs"])
            writer.add_scalar("val_enhance_stoi", choose_val[2],
                              state["epochs"])
            writer.add_scalar("val_improve_stoi", choose_val[3],
                              state["epochs"])
            writer.add_scalar("val_enhance_SISDR", choose_val[4],
                              state["epochs"])
            writer.add_scalar("val_improve_SISDR", choose_val[5],
                              state["epochs"])
            # writer.add_scalar("val_COPY_pesq",val_metrics_copy[0], state["epochs"])
            writer.add_scalar("val_loss", val_loss, state["epochs"])

            # Set up training state dict that will also be saved into checkpoints
            checkpoint_path = os.path.join(
                checkpoint_dir, "checkpoint_" + str(state["epochs"]))
            if choose_val[0] < state["best_pesq"]:
                state["worse_epochs"] += 1
            else:
                print("MODEL IMPROVED ON VALIDATION SET!")
                state["worse_epochs"] = 0
                state["best_pesq"] = choose_val[0]
                state["best_checkpoint"] = checkpoint_path

            # CHECKPOINT
            print("Saving model...")
            utils.save_model(student_KD, KD_optimizer, state, checkpoint_path)
            print('dump alpha_memory')
            with open(os.path.join(log_dir, 'alpha_' + str(state["epochs"])),
                      "wb") as fp:  #Pickling
                pickle.dump(memory_alpha, fp)

            state["epochs"] += 1
        writer.close()
        info = args.model_name
        path = os.path.join(result_dir, info)
    else:
        PATH = args.load_model.split("/")
        info = PATH[-3] + "_" + PATH[-1]
        if (args.outside_test == True):
            info += "_outside_test"
        print(info)
        path = os.path.join(result_dir, info)

    # test_data = SeparationDataset(dataset, "test", args.sr, args.channels, student_KD.shapes, False, args.hdf_dir, audio_transform=crop_func)

    #### TESTING ####
    # Test loss
    print("TESTING")
    # eval metrics
    #ling_data=get_ling_data_list('/media/hd03/sutsaiwei_data/data/mydata/ling_data')
    #validate(args, student_KD, criterion, test_data)
    #test_metrics = ling_evaluate(args, ling_data['noisy'], student_KD)
    #test_metrics = evaluate_without_noisy(args, dataset["test"], student_KD)
    test_metrics = evaluate(args, dataset["test"], student_KD)
    test_pesq = test_metrics['pesq']
    test_stoi = test_metrics['stoi']
    test_SISDR = test_metrics['SISDR']
    test_noise = test_metrics['noise']

    if not os.path.exists(path):
        os.makedirs(path)
    utils.save_result(test_pesq, path, "pesq")
    utils.save_result(test_stoi, path, "stoi")
    utils.save_result(test_SISDR, path, "SISDR")
    utils.save_result(test_noise, path, "noise")