Code Example #1
def main():

    os.system('cp -r ../ConvTasNet "{0}"'.format(config.basePath +
                                                 '/savedCode'))

    model = DataParallel(ConvTasNet(C=2))
    dataloader = AVSpeech('test')
    dataloader = DataLoader(dataloader,
                            batch_size=config.batchsize['test'],
                            num_workers=config.num_workers['test'],
                            worker_init_fn=init_fn)
    loss_func = SISNRPIT()

    if config.use_cuda:
        model = model.cuda()

    config.pretrained_test = [
        '/home/SharedData/Pragya/ModelsToUse/AudioOnlyConvTasNet.pth',
    ]

    for cur_test in config.pretrained_test:

        print('Currently working on: ', cur_test.split('/')[-1])

        model.load_state_dict(torch.load(cur_test)['model_state_dict'])

        total_loss = test(
            cur_test.split('/')[-1].split('.')[0], model, dataloader,
            loss_func)

        torch.cuda.empty_cache()

        print('Average Loss for ',
              cur_test.split('/')[-1], 'is: ', np.mean(total_loss))
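Note: SISNRPIT above is defined in the surrounding project. As a minimal sketch (not the project's implementation), a scale-invariant SNR loss with permutation-invariant training over C speakers can be written as:

import itertools

import torch


def si_snr(est, ref, eps=1e-8):
    """Scale-invariant SNR in dB for zero-meaned signals of shape [B, T]."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    # project the estimate onto the reference: s_target = <est, ref> ref / ||ref||^2
    dot = (est * ref).sum(dim=-1, keepdim=True)
    s_target = dot * ref / (ref.pow(2).sum(dim=-1, keepdim=True) + eps)
    e_noise = est - s_target
    return 10 * torch.log10(s_target.pow(2).sum(dim=-1) /
                            (e_noise.pow(2).sum(dim=-1) + eps))


def si_snr_pit_loss(est, ref):
    """Negative of the best-permutation mean SI-SNR. est, ref: [B, C, T]."""
    C = est.shape[1]
    scores = []
    for perm in itertools.permutations(range(C)):
        scores.append(torch.stack([si_snr(est[:, p], ref[:, i])
                                   for i, p in enumerate(perm)], dim=1).mean(dim=1))
    return -torch.stack(scores, dim=1).max(dim=1).values.mean()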
Code Example #2
def main():

    os.system('cp -r ../ConvTasNet "{0}"'.format(config.basePath +
                                                 '/savedCode'))

    model = DataParallel(ConvTasNet(C=2))

    print('Total Parameters: ', sum(p.numel() for p in model.parameters()))

    dataloader = AVSpeech('train')

    loss_func = SISNRPIT()

    if config.use_cuda:
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr[1])

    if config.pretrained:
        saved_model = torch.load(config.pretrained_train)
        model.load_state_dict(saved_model['model_state_dict'])
        optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        saved_loss = np.load(config.loss_path).tolist()
    else:
        saved_loss = None

    dataloader = DataLoader(dataloader,
                            batch_size=config.batchsize['train'],
                            num_workers=config.num_workers['train'],
                            worker_init_fn=init_fn)

    train(model, dataloader, optimizer, loss_func, saved_loss)
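worker_init_fn=init_fn in these examples refers to a helper defined elsewhere in the project; its usual job is to give each DataLoader worker its own NumPy seed, since NumPy does not inherit the per-worker torch seed. A minimal sketch under that assumption:

import numpy as np
import torch


def init_fn(worker_id):
    # derive a distinct, in-range NumPy seed for each DataLoader worker
    np.random.seed((torch.initial_seed() + worker_id) % 2 ** 32)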
Code Example #3
def main():

    os.system('cp -r ../Oracle "{0}"'.format(config.basePath + '/savedCode'))

    model = DataParallel(ConvTasNet(C=2, test_with_asr=True))
    dataloader = AVSpeech('test')
    dataloader = DataLoader(dataloader,
                            batch_size=config.batchsize['test'],
                            num_workers=config.num_workers['test'],
                            worker_init_fn=init_fn)

    if config.use_cuda:
        model = model.cuda()

    config.pretrained_test = [
        '/home/SharedData/Pragya/Experiments/Oracle/2020-05-20 15:23:34.411560/116662.pth'
    ]

    for cur_test in config.pretrained_test:

        print('Currently working on: ', cur_test.split('/')[-1])

        model.load_state_dict(torch.load(cur_test)['model_state_dict'])

        total_loss = test(model, dataloader)

        torch.cuda.empty_cache()

        print('Average Loss for ',
              cur_test.split('/')[-1], 'is: ', np.mean(total_loss))
Code Example #4
def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    print(model)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # Load data
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
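cal_SISNRi is imported from the project here; conceptually, SI-SNR improvement is the SI-SNR of the estimate against the reference minus the SI-SNR of the unprocessed mixture against the same reference. A rough single-utterance sketch in NumPy:

import numpy as np


def si_snr_np(est, ref, eps=1e-8):
    est, ref = est - est.mean(), ref - ref.mean()
    s_target = np.dot(est, ref) * ref / (np.dot(ref, ref) + eps)
    e_noise = est - s_target
    return 10 * np.log10(np.sum(s_target ** 2) / (np.sum(e_noise ** 2) + eps))


def si_snri(ref, est, mix):
    # improvement of the separated estimate over the raw mixture, in dB
    return si_snr_np(est, ref) - si_snr_np(mix, ref)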
Code Example #5
def main():

    os.system('cp -r ../Oracle "{0}"'.format(config.basePath + '/savedCode'))

    convtasnet_audio_with_asr_model = DataParallel(
        ConvTasNet(C=2, test_with_asr=True)).cuda()

    convtasnet_audio_without_asr_model = DataParallel(
        ConvTasNet(C=2, asr_addition=False)).cuda()
    dataloader = AVSpeech('test')
    dataloader = DataLoader(dataloader,
                            batch_size=config.batchsize['test'],
                            num_workers=config.num_workers['test'],
                            worker_init_fn=init_fn)
    loss_func = SISNRPIT()

    convtasnet_model = config.convtasnet_audio_model
    convtasnet_asr_model = [
        '/home/SharedData/Pragya/Experiments/Oracle/2020-05-20 15:23:34.411560/116662.pth'
    ]

    for conv_asr_test in convtasnet_asr_model:
        print('Currently working ConvTasNet on: ',
              convtasnet_model.split('/')[-1])
        print('Currently working E2ESpeechSeparation on: ',
              conv_asr_test.split('/')[-1])

        convtasnet_audio_without_asr_model.load_state_dict(
            torch.load(convtasnet_model)['model_state_dict'])
        convtasnet_audio_with_asr_model.load_state_dict(
            torch.load(conv_asr_test)['model_state_dict'])

        total_loss = test(convtasnet_audio_without_asr_model,
                          convtasnet_audio_with_asr_model, dataloader,
                          loss_func)

        torch.cuda.empty_cache()

        print('Average Loss for ',
              conv_asr_test.split('/')[-1], 'is: ', np.mean(total_loss))
Code Example #6
def separate(args):
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        return

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    print(model)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir,
                               args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)
    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # peak-normalize to [-0.5, 0.5] to leave headroom and avoid clipping
        inputs = inputs / (2 * max(np.abs(inputs)))
        sf.write(filename, inputs, sr)

    with torch.no_grad():
        for i, data in enumerate(eval_loader):
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flat
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                # splitext drops the extension; strip('.wav') would instead
                # remove any of the characters '.', 'w', 'a', 'v' from both ends
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c],
                          filename + '_s{}.wav'.format(c + 1))
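The halving in write bounds the output to [-0.5, 0.5], i.e. roughly 6 dB of headroom below full scale before writing. A standalone version of that normalize-then-write step (helper name hypothetical):

import numpy as np
import soundfile as sf


def write_normalized(wav, out_path, sr, headroom=0.5):
    # peak-normalize to +/-headroom so the written file cannot clip
    peak = np.max(np.abs(wav))
    if peak > 0:
        wav = wav * (headroom / peak)
    sf.write(out_path, wav, sr)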
Code Example #7
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json,
                              sample_rate=args.sample_rate,
                              segment_length=args.segment_length)

    cv_dataset = AudioDataset(
        args.valid_json,
        sample_rate=args.sample_rate,
        segment_length=args.segment_length,
    )

    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=args.batch_size,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)

    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=args.batch_size,
                                num_workers=0)

    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    print(model)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    lr = args.lr / args.batch_per_step
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer: " + args.optimizer)
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
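The lr = args.lr / args.batch_per_step line suggests the Solver sums gradients over batch_per_step mini-batches before each optimizer step; dividing the learning rate makes the summed gradient act like an average. A sketch of that accumulation pattern, assuming the Solver works this way (function name hypothetical):

def train_accumulated(model, criterion, loader, optimizer, batch_per_step):
    optimizer.zero_grad()
    for step, (mixture, source) in enumerate(loader, start=1):
        loss = criterion(model(mixture), source)
        loss.backward()            # gradients accumulate (sum) across iterations
        if step % batch_per_step == 0:
            optimizer.step()       # lr was pre-divided by batch_per_step,
            optimizer.zero_grad()  # so the summed gradients act as an average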
Code Example #8
File: train.py Project: b06502162Lu/HW3-1_Lu
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              segment=args.segment)
    cv_dataset = AudioDataset(
        args.valid_dir,
        batch_size=1,  # 1 -> use less GPU memory to do cv
        sample_rate=args.sample_rate,
        segment=-1,
        cv_maxlen=args.cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=4)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=4,
                                pin_memory=True)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer: " + args.optimizer)
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code Example #9
def main(argv):
    model = ConvTasNet.make(get_model_param())
    dataset = Dataset(FLAGS.dataset_path, max_decoded=FLAGS.max_decoded)
    checkpoint_dir = FLAGS.checkpoint

    epoch = 0
    if path.exists(checkpoint_dir):
        checkpoints = [name for name in listdir(
            checkpoint_dir) if "ckpt" in name]
        checkpoints.sort()
        checkpoint_name = checkpoints[-1].split(".")[0]
        epoch = int(checkpoint_name) + 1
        model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    epochs_to_inc = FLAGS.epochs
    while epochs_to_inc is None or epochs_to_inc > 0:
        print(f"Epoch: {epoch}")
        history = model.fit(dataset.make_dataset(get_dataset_param()))
        model.save_weights(f"{checkpoint_dir}/{epoch:05d}.ckpt")
        epoch += 1
        if epochs_to_inc is not None:
            epochs_to_inc -= 1
        model.param.save(f"{checkpoint_dir}/config.txt")
        model.save(f"{checkpoint_dir}/model")
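Sorting checkpoint names lexicographically works here only because the files are saved with zero-padded names ({epoch:05d}.ckpt); with unpadded names, "9" would sort after "10". A more defensive selection that parses the epoch number (assuming TensorFlow-format .ckpt.index files):

from os import listdir


def latest_epoch(checkpoint_dir):
    # return the highest saved epoch, or None if no checkpoint exists
    epochs = [int(name.split(".")[0]) for name in listdir(checkpoint_dir)
              if name.endswith(".ckpt.index")]
    return max(epochs, default=None)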
Code Example #10
def main(train_dir, batch_size, sample_rate, segment, valid_dir, cv_maxlen,
         shuffle, num_workers, N, L, B, H, P, X, R, C, norm_type, causal,
         mask_nonlinear, use_cuda, optimizer, lr, momentum, l2, epochs,
         half_lr, early_stop, max_norm, save_folder, checkpoint,
         continue_from, model_path, print_freq, visdom, visdom_epoch,
         visdom_id):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(train_dir, batch_size,
                              sample_rate=sample_rate, segment=segment)
    cv_dataset = AudioDataset(valid_dir, batch_size=1,  # 1 -> use less GPU memory to do cv
                              sample_rate=sample_rate,
                              segment=-1, cv_maxlen=cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset, batch_size=1,
                                shuffle=shuffle,
                                num_workers=num_workers)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1,
                                num_workers=0)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(N, L, B, H, P, X, R, C, 
                       norm_type=norm_type, causal=causal,
                       mask_nonlinear=mask_nonlinear)
    print(model)
    if use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    weight_decay=l2)
    elif optimizer == 'adam':
        # fatemeh: changed 'optimizier' to 'optimizer'
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, use_cuda, epochs, half_lr,
                    early_stop, max_norm, save_folder, checkpoint,
                    continue_from, model_path, print_freq, visdom,
                    visdom_epoch, visdom_id)
    solver.train()
Code Example #11
def main(args):
    # Construct Solver
    # data

    if args.continue_from == '':
        return
    ev_dataset = EvalAllDataset(args.train_dir,
                                args.mix_json,
                                args.batch_size,
                                sample_rate=args.sample_rate)

    ev_loader = EvalAllDataLoader(ev_dataset,
                                  batch_size=1,
                                  num_workers=args.num_workers)

    data = {'tr_loader': None, 'ev_loader': ev_loader}
    # SEP model
    sep_model = ConvTasNet(args.N,
                           args.L,
                           args.B,
                           args.H,
                           args.P,
                           args.X,
                           args.R,
                           args.C,
                           norm_type=args.norm_type,
                           causal=args.causal,
                           mask_nonlinear=args.mask_nonlinear)

    # ASR model
    asr_model = AttentionModel(args.NUM_HIDDEN_NODES, args.NUM_ENC_LAYERS,
                               args.NUM_CLASSES)
    if args.use_cuda:
        sep_model = torch.nn.DataParallel(sep_model)
        asr_model = torch.nn.DataParallel(asr_model)
        sep_model.cuda()
        asr_model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        sep_optimizer = torch.optim.SGD(sep_model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.l2)
        asr_optimizer = torch.optim.SGD(asr_model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.l2)
    elif args.optimizer == 'adam':
        sep_optimizer = torch.optim.Adam(sep_model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.l2)
        asr_optimizer = torch.optim.Adam(asr_model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.l2)
    else:
        print("Unsupported optimizer: " + args.optimizer)
        return

    # solver
    solver = Solver(data,
                    sep_model,
                    asr_model,
                    sep_optimizer,
                    asr_optimizer,
                    args,
                    DEVICE,
                    ev=True)
    solver.eval(args.EOS_ID)
Code Example #12
def main(argv):
    checkpoint_dir = FLAGS.checkpoint
    if not path.exists(checkpoint_dir):
        raise ValueError(f"'{checkpoint_dir}' does not exist")

    checkpoints = [name for name in listdir(checkpoint_dir) if "ckpt" in name]
    if not checkpoints:
        raise ValueError(f"No checkpoint exists")
    checkpoints.sort()
    checkpoint_name = checkpoints[-1].split(".")[0]

    param = ConvTasNetParam.load(f"{checkpoint_dir}/config.txt")
    model = ConvTasNet.make(param)
    model.load_weights(f"{checkpoint_dir}/{checkpoint_name}.ckpt")

    video_id = FLAGS.video_id

    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "wav",
            "preferredquality": "44100",
        }],
        "outtmpl": "%(title)s.wav",
        "progress_hooks": [youtube_dl_hook],
    }

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(video_id, download=False)
        status = ydl.download([video_id])

    title = info.get("title", None)
    filename = title + ".wav"
    audio, sr = librosa.load(filename, sr=44100, mono=True)

    num_samples = audio.shape[0]

    num_portions = (num_samples - param.overlap) // (param.That *
                                                     (param.L - param.overlap))

    num_samples_output = num_portions * param.That * (param.L - param.overlap)

    num_samples = num_samples_output + param.overlap

    if FLAGS.interpolate:

        def filter_gen(n):
            if n < param.overlap:
                return n / param.overlap
            elif n > param.L - param.overlap:
                return (param.L - n) / param.overlap
            else:
                return 1

        output_filter = np.array([filter_gen(n) for n in range(param.L)])

    print("predicting...")

    audio = audio[:num_samples]

    model_input = np.zeros((num_portions, param.That, param.L))

    for i in range(num_portions):
        for j in range(param.That):
            begin = (i * param.That + j) * (param.L - param.overlap)
            end = begin + param.L
            model_input[i][j] = audio[begin:end]

    separated = model.predict(model_input)
    separated = np.transpose(separated, (1, 0, 2, 3))

    if FLAGS.interpolate:
        separated = output_filter * separated
        overlapped = separated[:, :, :, (param.L - param.overlap):]
        overlapped = np.pad(overlapped,
                            pad_width=((0, 0), (0, 0), (0, 0),
                                       (0, param.L - 2 * param.overlap)),
                            mode="constant",
                            constant_values=0)
        overlapped = np.reshape(overlapped, (param.C, num_samples_output))
        overlapped[:, 1:] = overlapped[:, :-1]
        overlapped[:, 0] = 0

    separated = separated[:, :, :, :(param.L - param.overlap)]
    separated = np.reshape(separated, (param.C, num_samples_output))

    if FLAGS.interpolate:
        separated += overlapped

    print("saving...")

    for idx, stem in enumerate(Dataset.STEMS):
        sf.write(f"{title}_{stem}.wav", separated[idx], sr)
Code Example #13
File: main.py Project: Owen864720655/Conv_TasNet
    def __init__(self, training):
        super(ModelAccess, self).__init__()
        self.training = training
        self.main_config = ConfigTables()
        self.ioconfig = self.main_config.io_config
        self.dataconfig = self.main_config.data_config
        self.trainconfig = self.main_config.train_config
        self.modelconfig = self.main_config.model_config
        self.mulaw = self.dataconfig["mulaw"]
        self.audio_length = self.dataconfig["audio_length"]
        self.lr = self.trainconfig["lr"]
        self.batch_size = self.trainconfig["batch_size"]
        self.optimizer = self.trainconfig["optimizer"].lower()
        assert self.optimizer in ["sgd", "adam", "rmsprop"], \
            "Unsupported optimizer: must be sgd, adam, or rmsprop"
        if self.optimizer == "sgd":
            self.optimizer = tf.train.GradientDescentOptimizer(self.lr)
        elif self.optimizer == "adam":
            self.optimizer = tf.train.AdamOptimizer(self.lr)
        else:
            self.optimizer = tf.train.RMSPropOptimizer(self.lr)
        self.ckpt_dir = self.trainconfig["ckpt_dir"]
        self.epoches = self.trainconfig["epoches"]
        self.output_dir = self.dataconfig["output_dir"]
        os.makedirs(self.ckpt_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)
        self.dataset = RecordMaker(training).dataset
        self.max_to_keep = self.trainconfig["max_to_keep"]
        self.filters_e = self.modelconfig["filters_e"]
        self.plot_pertire = self.trainconfig["plot_pertire"]
        self.kernel_size_e = self.modelconfig["kernel_size_e"]
        self.bottle_filter = self.modelconfig["bottle_filter"]
        self.filters_block = self.modelconfig["filters_block"]
        self.kernel_size_block = self.modelconfig["kernel_size_block"]
        self.num_conv_block = self.modelconfig["num_conv_block"]
        self.number_repeat = self.modelconfig["number_repeat"]
        self.spk_num = self.modelconfig["spk_num"]
        self.norm_type = self.modelconfig["norm_type"]
        self.causal = self.modelconfig["causal"]
        self.mask_nonlinear = self.modelconfig["mask_nonlinear"]
        self.savemodel_periter = self.trainconfig["savemodel_periter"]
        self.convtasnet = ConvTasNet(filters_e=self.filters_e,
                                     kernel_size_e=self.kernel_size_e,
                                     bottle_filter=self.bottle_filter,
                                     filters_block=self.filters_block,
                                     kernel_size_block=self.kernel_size_block,
                                     num_conv_block=self.num_conv_block,
                                     number_repeat=self.number_repeat,
                                     spk_num=self.spk_num,
                                     norm_type=self.norm_type,
                                     causal=self.causal,
                                     mask_nonlinear=self.mask_nonlinear,
                                     speech_length=self.audio_length)
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                              convtasnet=self.convtasnet)
        self.ckpt_manager = tf.contrib.checkpoint.CheckpointManager(
            self.checkpoint,
            directory=self.ckpt_dir,
            max_to_keep=self.max_to_keep)

        self.checkpoint.restore(tf.train.latest_checkpoint(self.ckpt_dir))

        if training:
            self.train_epochs()
        else:
            print("finish load model!")
Code Example #14
def get_initial_model_optimizer():
    """
    Loading pretrained model of convtasnet and ASR, for domain translation of asr's features
    to convtasnet, Domain translation block is used
    """
    from ETESpeechRecognition.model import E2E as ASR
    from domainTranslation import DomainTranslation
    import ETESpeechRecognition.config as asrConfig

    # loading convtasnet model
    trained_convtasnet_audio_model = torch.load(
        config.convtasnet_audio_model, map_location=torch.device('cpu'))

    convtasnet_audio_with_asr_model = DataParallel(ConvTasNet(C=2))

    model_state_dict = trained_convtasnet_audio_model['model_state_dict']

    # adding random weights to model for new block addition
    model_state_dict['module.separator.network.0.gamma'] = torch.cat(
        [model_state_dict['module.separator.network.0.gamma'],
         torch.randn(size=[1, 512, 1])], dim=1)
    model_state_dict['module.separator.network.0.beta'] = torch.cat(
        [model_state_dict['module.separator.network.0.beta'],
         torch.randn(size=[1, 512, 1])], dim=1)
    model_state_dict['module.separator.network.1.weight'] = torch.cat(
        [model_state_dict['module.separator.network.1.weight'],
         torch.randn(size=[512, 512, 1])], dim=1)

    convtasnet_audio_with_asr_model.load_state_dict(
        trained_convtasnet_audio_model['model_state_dict'])

    print('Total Parameters in ConvTasNet without ASR model: ',
          sum(p.numel() for p in convtasnet_audio_with_asr_model.parameters()))

    convtasnet_audio_with_asr_model.module.domainTranslation = \
        DomainTranslation()

    optimizer_init = torch.optim.Adam(
        convtasnet_audio_with_asr_model.parameters(), lr=config.lr[1])

    if config.use_cuda:
        convtasnet_audio_with_asr_model = \
            convtasnet_audio_with_asr_model.cuda()

    # Loading ASR model
    asr_model = DataParallel(
        ASR(idim=80, odim=5002, args=asrConfig.ModelArgs(), get_features=True))
    if config.use_cuda:
        trained_asr_model = torch.load(config.asr_model)
    else:
        trained_asr_model = torch.load(config.asr_model,
                                       map_location=torch.device('cpu'))

    asr_model.load_state_dict(trained_asr_model['model'])

    convtasnet_audio_with_asr_model.module.asr = asr_model

    print('Total Parameters in ConvTasNet with ASR model: ',
          sum(p.numel() for p in convtasnet_audio_with_asr_model.parameters()))

    return convtasnet_audio_with_asr_model, optimizer_init
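The state-dict surgery above appends randomly initialized channels along dim=1 so that a checkpoint trained on 512-channel features loads into layers that now expect 1024 channels (512 audio + 512 ASR); the new channels are then learned from scratch. A generic sketch of the pattern (function name hypothetical):

import torch


def widen_input_channels(state_dict, key, extra, dim=1):
    # append randomly initialized slices so a pretrained tensor
    # matches the widened layer it is loaded into
    old = state_dict[key]
    pad_shape = list(old.shape)
    pad_shape[dim] = extra
    state_dict[key] = torch.cat([old, torch.randn(pad_shape)], dim=dim)
    return state_dict

For example, widen_input_channels(sd, 'module.separator.network.1.weight', extra=512) reproduces the third torch.cat above.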
Code Example #15
def main(args):
    model_args = {
        'N': args.N,
        'L': args.L,
        'B': args.B,
        'H': args.H,
        'P': args.P,
        'X': args.X,
        'R': args.R,
        'C': args.C,
        'norm_type': args.norm_type,
        'causal': args.causal,
        'mask_nonlinear': args.mask_nonlinear
    }

    train_args = {
        'lr': args.lr,
        'batch_size': args.batch_size,
        'epochs': args.epochs
    }

    model = ConvTasNet(**model_args)

    if args.evaluate == 0 and args.separate == 0:
        dataset = AudioDataset(args.data_dir, sr=args.sr, mode='train',
                               seq_len=args.seq_len, verbose=0,
                               voice_only=args.voice_only)

        print('DataLoading Done')

        train(model, dataset, **train_args)
    elif args.evaluate == 1:
        model.load_state_dict(torch.load(args.model, map_location='cpu'))

        dataset = AudioDataset(args.data_dir, sr=args.sr, mode='test',
                               seq_len=args.seq_len, verbose=0,
                               voice_only=args.voice_only)

        evaluate(model, dataset, args.batch_size, 0, args.cal_sdr)
    else:
        model.load_state_dict(torch.load(args.model, map_location='cpu'))