Example 1
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json,
                              sample_rate=args.sample_rate,
                              segment_length=args.segment_length)

    cv_dataset = AudioDataset(
        args.valid_json,
        sample_rate=args.sample_rate,
        segment_length=args.segment_length,
    )

    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=args.batch_size,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)

    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=args.batch_size,
                                num_workers=0)

    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    print(model)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer (learning rate is scaled by 1 / batch_per_step)
    lr = args.lr / args.batch_per_step
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer:", args.optimizer)
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
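
Example 1 is the only variant that rescales the learning rate with lr = args.lr / args.batch_per_step, which suggests the Solver accumulates gradients over batch_per_step mini-batches before each parameter update. The following is a minimal, self-contained sketch of that accumulation pattern in plain PyTorch; the training function, the MSE criterion and the default of 4 are placeholders, not part of the project's Solver.

import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, batch_per_step=4):
    # Gradients from batch_per_step mini-batches are summed into the same
    # .grad buffers before a single optimizer.step(); dividing the learning
    # rate by batch_per_step compensates for summed (not averaged) gradients.
    criterion = nn.MSELoss()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader, start=1):
        loss = criterion(model(x), y)
        loss.backward()
        if step % batch_per_step == 0:
            optimizer.step()
            optimizer.zero_grad()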
Example 2
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              segment=args.segment)
    cv_dataset = AudioDataset(
        args.valid_dir,
        batch_size=1,  # 1 -> use less GPU memory to do cv
        sample_rate=args.sample_rate,
        segment=-1,
        cv_maxlen=args.cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=4)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=4,
                                pin_memory=True)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer:", args.optimizer)
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
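
Note the batching convention here: AudioDataset receives args.batch_size while both loaders are constructed with batch_size=1, i.e. every dataset item is already a complete mini-batch and the DataLoader only shuffles and iterates those pre-built batches. A rough, self-contained sketch of that pattern with generic names (PreBatchedDataset is hypothetical, not the project's AudioDataset):

import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    # Groups utterances into mini-batches when the dataset is built,
    # so the DataLoader can run with batch_size=1.
    def __init__(self, utterances, batch_size):
        self.batches = [utterances[i:i + batch_size]
                        for i in range(0, len(utterances), batch_size)]

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]  # one item == one whole mini-batch

def unwrap(batch_of_one):
    # batch_size=1 in the loader, so just unwrap the single pre-built batch.
    return batch_of_one[0]

utterances = [torch.randn(16000) for _ in range(10)]  # dummy 2-second clips at 8 kHz
dataset = PreBatchedDataset(utterances, batch_size=3)
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=unwrap)
for batch in loader:
    print(len(batch))  # 3, 3, 3 and 1, in shuffled order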
Example 3
def main(train_dir, batch_size, sample_rate, segment, valid_dir, cv_maxlen,
         shuffle, num_workers, N, L, B, H, P, X, R, C,
         norm_type, causal, mask_nonlinear, use_cuda, optimizer, lr, momentum, l2,
         epochs, half_lr, early_stop, max_norm, save_folder, checkpoint,
         continue_from, model_path, print_freq, visdom, visdom_epoch, visdom_id):
    # (the trailing parameters are forwarded unchanged to Solver)
    # Construct Solver
    # data
    tr_dataset = AudioDataset(train_dir, batch_size,
                              sample_rate=sample_rate, segment=segment)
    cv_dataset = AudioDataset(valid_dir, batch_size=1,  # 1 -> use less GPU memory to do cv
                              sample_rate=sample_rate,
                              segment=-1, cv_maxlen=cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset, batch_size=1,
                                shuffle=shuffle,
                                num_workers=num_workers)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1,
                                num_workers=0)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(N, L, B, H, P, X, R, C, 
                       norm_type=norm_type, causal=causal,
                       mask_nonlinear=mask_nonlinear)
    print(model)
    if use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    weight_decay=l2)
    elif optimizer == 'adam':
        # fatemeh: variable renamed from 'optimizier' to 'optimizer'
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=l2)
    else:
        print("Unsupported optimizer:", optimizer)
        return

    # solver
    solver = Solver(data, model, optimizer, use_cuda, epochs, half_lr,
                    early_stop, max_norm, save_folder, checkpoint,
                    continue_from, model_path, print_freq,
                    visdom, visdom_epoch, visdom_id)
    solver.train()
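
Because this variant takes every setting as an explicit keyword parameter, a call site has to spell them all out. The illustration below is a hedged sketch: the eight network sizes mirror the values hard-coded in Example 5, while every other value (paths, rates, training hyperparameters) is a placeholder, not a recommended or confirmed default.

main(train_dir='data/tr', batch_size=3, sample_rate=8000, segment=4.0,
     valid_dir='data/cv', cv_maxlen=8.0, shuffle=True, num_workers=4,
     N=256, L=20, B=256, H=512, P=3, X=8, R=4, C=2,  # as in Example 5
     norm_type='gLN', causal=0, mask_nonlinear='relu',
     use_cuda=True, optimizer='adam', lr=1e-3, momentum=0.0, l2=0.0,
     epochs=100, half_lr=True, early_stop=True, max_norm=5.0,
     save_folder='exp/tmp', checkpoint=0, continue_from='',
     model_path='final.pth.tar', print_freq=10,
     visdom=False, visdom_epoch=False, visdom_id='Conv-TasNet')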
Example 4
def main(args):
    # Construct Solver
    # data

    if args.continue_from == '':
        return
    ev_dataset = EvalAllDataset(args.train_dir,
                                args.mix_json,
                                args.batch_size,
                                sample_rate=args.sample_rate)

    ev_loader = EvalAllDataLoader(ev_dataset,
                                  batch_size=1,
                                  num_workers=args.num_workers)

    data = {'tr_loader': None, 'ev_loader': ev_loader}
    # SEP model
    sep_model = ConvTasNet(args.N,
                           args.L,
                           args.B,
                           args.H,
                           args.P,
                           args.X,
                           args.R,
                           args.C,
                           norm_type=args.norm_type,
                           causal=args.causal,
                           mask_nonlinear=args.mask_nonlinear)

    # ASR model
    asr_model = AttentionModel(args.NUM_HIDDEN_NODES, args.NUM_ENC_LAYERS,
                               args.NUM_CLASSES)
    #print(model)
    if args.use_cuda:
        sep_model = torch.nn.DataParallel(sep_model)
        asr_model = torch.nn.DataParallel(asr_model)
        sep_model.cuda()
        asr_model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        sep_optimizer = torch.optim.SGD(sep_model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.l2)
        asr_optimizer = torch.optim.SGD(asr_model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.l2)
    elif args.optimizer == 'adam':
        sep_optimizer = torch.optim.Adam(sep_model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.l2)
        asr_optimizer = torch.optim.Adam(asr_model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.l2)
    else:
        print("Unsupported optimizer:", args.optimizer)
        return

    # solver
    solver = Solver(data,
                    sep_model,
                    asr_model,
                    sep_optimizer,
                    asr_optimizer,
                    args,
                    DEVICE,
                    ev=True)
    solver.eval(args.EOS_ID)
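
This evaluation entry point leaves checkpoint handling to the Solver (it simply returns when args.continue_from is empty), while Example 5 below expects a checkpoint whose 'sep_state_dict' entry holds the separator weights. The sketch below shows one way such a joint checkpoint could be written and read; the 'asr_state_dict' key and the DataParallel unwrapping are assumptions, not the project's confirmed format.

import torch

def save_joint_checkpoint(path, sep_model, asr_model):
    # Unwrap DataParallel (if used) so keys carry no 'module.' prefix and the
    # separator weights load into a plain ConvTasNet, as Example 5 expects.
    sep = sep_model.module if hasattr(sep_model, 'module') else sep_model
    asr = asr_model.module if hasattr(asr_model, 'module') else asr_model
    torch.save({'sep_state_dict': sep.state_dict(),   # key read in Example 5
                'asr_state_dict': asr.state_dict()},  # assumed key for the ASR model
               path)

def load_joint_checkpoint(path, sep_model, asr_model, device='cpu'):
    ckpt = torch.load(path, map_location=device)
    sep_model.load_state_dict(ckpt['sep_state_dict'])
    asr_model.load_state_dict(ckpt['asr_state_dict'])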
Example 5
def separate(args):
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        return

    # Load model (hyperparameters are hard-coded and must match the checkpoint)
    model = ConvTasNet(256,
                       20,
                       256,
                       512,
                       3,
                       8,
                       4,
                       2,
                       norm_type="gLN",
                       causal=0,
                       mask_nonlinear="relu")
    if args.use_cuda:
        model.cuda()
    state = torch.load(args.model_path,
                       map_location=None if args.use_cuda else 'cpu')
    model.load_state_dict(state['sep_state_dict'])
    print(model)
    model.eval()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir,
                               args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)
    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # Peak-normalize to avoid clipping, then write with soundfile
        # (librosa.output.write_wav is no longer available in recent librosa).
        inputs = inputs / np.max(np.abs(inputs))
        sf.write(filename, inputs, sr)

    with torch.no_grad():
        for (i, data) in enumerate(eval_loader):
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flat
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                # os.path.splitext (not str.strip) safely removes the extension
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c],
                          filename + '_s{}.wav'.format(c + 1))
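
separate() only reads mix_dir, mix_json, model_path, out_dir, batch_size, sample_rate and use_cuda from args. A hedged command-line wrapper that supplies exactly those fields could look like the sketch below; the flag names and defaults are illustrative, not the project's official CLI.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Separate speech with a trained Conv-TasNet')
    parser.add_argument('--mix_dir', default=None,
                        help='Directory of mixture wav files')
    parser.add_argument('--mix_json', default=None,
                        help='JSON list of mixtures (ignored when --mix_dir is given)')
    parser.add_argument('--model_path', required=True,
                        help="Checkpoint containing a 'sep_state_dict' entry")
    parser.add_argument('--out_dir', default='separated',
                        help='Where the mixture and estimated sources are written')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--sample_rate', type=int, default=8000)
    parser.add_argument('--use_cuda', type=int, default=1)
    separate(parser.parse_args())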