# Example 1
def train():
    """Entry point: parse options, set up logging/dataloaders and launch Dual-Path-RNN training."""
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training Dual-Path-RNN')
    arg_parser.add_argument('--opt', type=str, help='Path to option YAML file.', default='config/Dual_RNN/train.yaml')
    cli_args = arg_parser.parse_args()
    opt = option.parse(cli_args.opt)

    # Dataloaders come first: the logger setup below needs their lengths.
    print('Building the dataloader of Dual-Path-RNN')
    train_dataloader, val_dataloader = make_dataloader(opt)

    # NOTE(review): `tensorboard` is fed from opt['logger']['tofile'] — this
    # looks like it should read a dedicated 'tensorboard' key; confirm
    # against the config schema before changing it.
    set_logger.setup_logger(opt['logger']['name'], opt['logger']['path'],
                            screen=opt['logger']['screen'], tofile=opt['logger']['tofile'], tensorboard=opt['logger']['tofile'],
                            num_iter_train=len(train_dataloader), num_iter_val=len(val_dataloader), sr=opt['datasets']['audio_setting']['sample_rate'])
    logger = logging.getLogger(opt['logger']['name'])

    # NOTE(review): no model is actually constructed here — presumably the
    # Trainer builds it from `opt`; verify against trainer_Dual_RNN.Trainer.
    logger.info("Building the model of Dual-Path-RNN")
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    logger.info('Building the Trainer of Dual-Path-RNN')
    trainer_Dual_RNN.Trainer(train_dataloader, val_dataloader, opt).run()
def train():
    """Build the DANet model, optimizer and dataloaders, then run the Trainer."""
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training DANet')
    arg_parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    opt = option.parse(arg_parser.parse_args().opt)

    # Route all subsequent messages through the configured named logger.
    log_cfg = opt['logger']
    set_logger.setup_logger(log_cfg['name'],
                            log_cfg['path'],
                            screen=log_cfg['screen'],
                            tofile=log_cfg['tofile'])
    logger = logging.getLogger(log_cfg['name'])

    logger.info("Building the model of DANet")
    danet = model.DANet(**opt['DANet'])

    logger.info("Building the optimizer of DANet")
    optimizer = make_optimizer(danet.parameters(), opt)

    logger.info('Building the dataloader of DANet')
    train_dataloader, val_dataloader = make_dataloader(opt)

    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))
    logger.info('Building the Trainer of DANet')
    Trainer(train_dataloader, val_dataloader, danet, optimizer, opt).run()
# Example 3
def train():
    """Build the Deep Clustering (DPCL) model, optimizer and dataloaders, then train."""
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training Deep Clustering')
    arg_parser.add_argument('--opt',
                            type=str,
                            help='Path to option YAML file.',
                            default='./config/train.yml')
    opt = option.parse(arg_parser.parse_args().opt)

    # Route all subsequent messages through the configured named logger.
    log_cfg = opt['logger']
    set_logger.setup_logger(log_cfg['name'],
                            log_cfg['path'],
                            screen=log_cfg['screen'],
                            tofile=log_cfg['tofile'])
    logger = logging.getLogger(log_cfg['name'])

    logger.info("Building the model of Deep Clustering")
    dpcl = model.DPCL(**opt['DPCL'])

    logger.info("Building the optimizer of Deep Clustering")
    optimizer = make_optimizer(dpcl.parameters(), opt)

    logger.info('Building the dataloader of Deep Clustering')
    train_dataloader, val_dataloader = make_dataloader(opt)

    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))
    logger.info('Building the Trainer of Deep Clustering')
    Trainer(train_dataloader, val_dataloader, dpcl, optimizer, opt).run()
# Example 4
 def __init__(self, mix_path, yaml_path, model, gpuid):
     """Load the mixture wav and restore a Dual-Path-RNN checkpoint for separation."""
     super(Separation, self).__init__()
     self.mix = read_wav(mix_path)
     cfg = parse(yaml_path)
     # Build the network and restore its weights on CPU first.
     network = Dual_RNN_model(**cfg['Dual_Path_RNN'])
     checkpoint = torch.load(model, map_location='cpu')
     network.load_state_dict(checkpoint["model_state_dict"])
     setup_logger(cfg['logger']['name'], cfg['logger']['path'],
                         screen=cfg['logger']['screen'], tofile=cfg['logger']['tofile'])
     self.logger = logging.getLogger(cfg['logger']['name'])
     self.logger.info('Load checkpoint from {}, epoch {: d}'.format(model, checkpoint["epoch"]))
     self.net = network
     self.gpuid = gpuid
# Example 5
 def __init__(self, mix_path, yaml_path, model, gpuid):
     """Prepare a Conv-TasNet separation session from a saved checkpoint."""
     super(Separation, self).__init__()
     # Mixture audio is read at a fixed 8 kHz sample rate.
     self.mix = AudioReader(mix_path, sample_rate=8000)
     cfg = parse(yaml_path)
     # Build the network and restore its weights on CPU first.
     network = Conv_TasNet(**cfg['Conv_Tasnet'])
     checkpoint = torch.load(model, map_location='cpu')
     network.load_state_dict(checkpoint["model_state_dict"])
     setup_logger(cfg['logger']['name'], cfg['logger']['path'],
                         screen=cfg['logger']['screen'], tofile=cfg['logger']['tofile'])
     self.logger = logging.getLogger(cfg['logger']['name'])
     self.logger.info('Load checkpoint from {}, epoch {: d}'.format(model, checkpoint["epoch"]))
     # NOTE(review): .cuda() runs unconditionally even though the device
     # below falls back to 'cpu' when gpuid is empty — confirm a GPU is
     # always expected here.
     self.net = network.cuda()
     self.device = torch.device('cuda:{}'.format(
         gpuid[0]) if len(gpuid) > 0 else 'cpu')
     self.gpuid = tuple(gpuid)
def train():
    """Assemble model, optimizer, dataloaders and LR scheduler, then train Dual-Path-RNN."""
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training Dual-Path-RNN')
    arg_parser.add_argument('--opt',
                            type=str,
                            help='Path to option YAML file.',
                            default='config/Dual_RNN/train.yaml')
    opt = option.parse(arg_parser.parse_args().opt)

    # Route all subsequent messages through the configured named logger.
    log_cfg = opt['logger']
    set_logger.setup_logger(log_cfg['name'],
                            log_cfg['path'],
                            screen=log_cfg['screen'],
                            tofile=log_cfg['tofile'])
    logger = logging.getLogger(log_cfg['name'])

    logger.info("Building the model of Dual-Path-RNN")
    Dual_Path_RNN = model_rnn.Dual_RNN_model(**opt['Dual_Path_RNN'])

    logger.info("Building the optimizer of Dual-Path-RNN")
    optimizer = make_optimizer(Dual_Path_RNN.parameters(), opt)

    logger.info('Building the dataloader of Dual-Path-RNN')
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # Lower the learning rate when the monitored (validation) loss plateaus.
    sched_cfg = opt['scheduler']
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=sched_cfg['factor'],
                                  patience=sched_cfg['patience'],
                                  verbose=True,
                                  min_lr=sched_cfg['min_lr'])

    logger.info('Building the Trainer of Dual-Path-RNN')
    trainer_Dual_RNN.Trainer(train_dataloader, val_dataloader,
                             Dual_Path_RNN, optimizer, scheduler,
                             opt).run()
# Example 7
    def __init__(self, mix_json, aux_json, ref_json, yaml_path, model, gpuid):
        """Set up a Separation session: test data, network, checkpoint, logger.

        Args:
            mix_json, aux_json, ref_json: json manifests of the test wavs.
            yaml_path: path to the option YAML file.
            model: path to the checkpoint file to restore.
            gpuid: sequence of GPU ids; the first one hosts the device.
        """
        super(Separation, self).__init__()
        opt = parse(yaml_path)
        self.data_loader = Datasets_Test(
            mix_json,
            aux_json,
            ref_json,
            sample_rate=opt['datasets']['audio_setting']['sample_rate'],
            model=opt['model']['MODEL'])
        self.mix_infos = self.data_loader.mix
        self.total_wavs = len(self.mix_infos)
        self.model = opt['model']['MODEL']
        # Pick the network class by configured task. Fail fast on an unknown
        # name instead of crashing later with an opaque UnboundLocalError
        # when `net` is first used.
        if self.model in ('DPRNN_Speaker_Extraction',
                          'DPRNN_Speaker_Suppression'):
            # Extraction and Suppression model
            net = model_function.Extractin_Suppression_Model(
                **opt['Dual_Path_Aux_Speaker'])
        elif self.model == 'DPRNN_Speech_Separation':
            # Separation model
            net = model_function.Speech_Serapation_Model(
                **opt['Dual_Path_Aux_Speaker'])
        else:
            raise ValueError(
                'Unsupported model type: {}'.format(self.model))
        # Wrap in DataParallel before loading — presumably the checkpoint's
        # state_dict keys carry the 'module.' prefix; confirm against the
        # training-side save code.
        net = torch.nn.DataParallel(net)
        dicts = torch.load(model, map_location='cpu')
        net.load_state_dict(dicts["model_state_dict"])
        setup_logger(opt['logger']['name'],
                     opt['logger']['path'],
                     screen=opt['logger']['screen'],
                     tofile=opt['logger']['tofile'])
        self.logger = logging.getLogger(opt['logger']['name'])
        self.logger.info('Load checkpoint from {}, epoch {: d}'.format(
            model, dicts["epoch"]))
        self.net = net.cuda()

        self.device = torch.device(
            'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
        self.gpuid = tuple(gpuid)

        # When True, separated wavs are written out (used elsewhere in the class).
        self.Output_wav = False
# Example 8
def train():
    """Assemble model, optimizer, dataloaders and LR scheduler, then train Conv-TasNet."""
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training Conv-TasNet')
    arg_parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    opt = option.parse(arg_parser.parse_args().opt)

    # Route all subsequent messages through the configured named logger.
    log_cfg = opt['logger']
    set_logger.setup_logger(log_cfg['name'],
                            log_cfg['path'],
                            screen=log_cfg['screen'],
                            tofile=log_cfg['tofile'])
    logger = logging.getLogger(log_cfg['name'])

    logger.info("Building the model of Conv-Tasnet")
    Conv_Tasnet = model.Conv_TasNet(**opt['Conv_Tasnet'])

    logger.info("Building the optimizer of Conv-Tasnet")
    optimizer = make_optimizer(Conv_Tasnet.parameters(), opt)

    logger.info('Building the dataloader of Conv-Tasnet')
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # Lower the learning rate when the monitored (validation) loss plateaus.
    sched_cfg = opt['scheduler']
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=sched_cfg['factor'],
                                  patience=sched_cfg['patience'],
                                  verbose=True,
                                  min_lr=sched_cfg['min_lr'])

    logger.info('Building the Trainer of Conv-Tasnet')
    trainer_Tasnet.Trainer(train_dataloader, val_dataloader,
                           Conv_Tasnet, optimizer, scheduler, opt).run()
# Example 9
def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training Model')
    # configuration fiule
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])
    day_time = datetime.date.today().strftime('%y%m%d')

    # build model
    model = opt['model']['MODEL']
    logger.info("Building the model of {}".format(model))
    # Extraction and Suppression model
    if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction' or opt['model'][
            'MODEL'] == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(
            **opt['Dual_Path_Aux_Speaker'])
    if opt['train']['gpuid']:
        if len(opt['train']['gpuid']) > 1:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))
        else:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))

        device = torch.device('cuda:{}'.format(opt['train']['gpuid'][0]))
        gpuids = opt['train']['gpuid']
        if len(gpuids) > 1:
            net = torch.nn.DataParallel(net, device_ids=gpuids)
        net = net.to(device)
    logger.info('Loading {} parameters: {:.3f} Mb'.format(
        model, check_parameters(net)))

    # build optimizer
    logger.info("Building the optimizer of {}".format(model))
    Optimizer = make_optimizer(net.parameters(), opt)

    Scheduler = ReduceLROnPlateau(Optimizer,
                                  mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True,
                                  min_lr=opt['scheduler']['min_lr'])

    # build dataloader
    logger.info('Building the dataloader of {}'.format(model))
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build trainer
    logger.info('............. Training ................')

    total_epoch = opt['train']['epoch']
    num_spks = opt['num_spks']
    print_freq = opt['logger']['print_freq']
    checkpoint_path = opt['train']['path']
    early_stop = opt['train']['early_stop']
    max_norm = opt['optim']['clip_norm']
    best_loss = np.inf
    no_improve = 0
    ce_loss = torch.nn.CrossEntropyLoss()
    weight = 0.1

    epoch = 0
    # Resume training settings
    if opt['resume']['state']:
        opt['resume']['path'] = opt['resume'][
            'path'] + '/' + '200722_epoch{}.pth.tar'.format(
                opt['resume']['epoch'])
        ckp = torch.load(opt['resume']['path'], map_location='cpu')
        epoch = ckp['epoch']
        logger.info("Resume from checkpoint {}: epoch {:.3f}".format(
            opt['resume']['path'], epoch))
        net.load_state_dict(ckp['model_state_dict'])
        net.to(device)
        Optimizer.load_state_dict(ckp['optim_state_dict'])

    while epoch < total_epoch:

        epoch += 1
        logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            epoch, 0))
        num_steps = len(train_dataloader)

        # trainning process
        total_SNRloss = 0.0
        total_CEloss = 0.0
        num_index = 1
        start_time = time.time()
        for inputs, targets in train_dataloader:
            # Separation train
            if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                mix = inputs
                ref = targets
                net.train()

                mix = mix.to(device)
                ref = [ref[i].to(device) for i in range(num_spks)]

                net.zero_grad()
                train_out = net(mix)
                SNR_loss = Loss(train_out, ref)
                loss = SNR_loss

            # Extraction train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                mix, aux = inputs
                ref, aux_len, sp_label = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                sp_label = sp_label.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                CE_loss = torch.mean(ce_loss(train_out[1], sp_label))
                loss = SNR_loss + weight * CE_loss
                total_CEloss += CE_loss.item()

            # Suppression train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                mix, aux = inputs
                ref, aux_len = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                loss = SNR_loss

            # BP processs
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)
            Optimizer.step()

            total_SNRloss += SNR_loss.item()

            if num_index % print_freq == 0:
                message = '<Training epoch:{:d} / {:d} , iter:{:d} / {:d}, lr:{:.3e}, SI-SNR_loss:{:.3f}, CE loss:{:.3f}>'.format(
                    epoch, total_epoch, num_index, num_steps,
                    Optimizer.param_groups[0]['lr'], total_SNRloss / num_index,
                    total_CEloss / num_index)
                logger.info(message)

            num_index += 1

        end_time = time.time()
        mean_SNRLoss = total_SNRloss / num_index
        mean_CELoss = total_CEloss / num_index

        message = 'Finished Training *** <epoch:{:d} / {:d}, iter:{:d}, lr:{:.3e}, ' \
                  'SNR loss:{:.3f}, CE loss:{:.3f}, Total time:{:.3f} min> '.format(
            epoch, total_epoch, num_index, Optimizer.param_groups[0]['lr'], mean_SNRLoss, mean_CELoss, (end_time - start_time) / 60)
        logger.info(message)

        # development processs
        val_num_index = 1
        val_total_loss = 0.0
        val_CE_loss = 0.0
        val_acc_total = 0.0
        val_acc = 0.0
        val_start_time = time.time()
        val_num_steps = len(val_dataloader)
        for inputs, targets in val_dataloader:
            net.eval()
            with torch.no_grad():
                # Separation development
                if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                    mix = inputs
                    ref = targets
                    mix = mix.to(device)
                    ref = [ref[i].to(device) for i in range(num_spks)]
                    Optimizer.zero_grad()
                    val_out = net(mix)
                    val_loss = Loss(val_out, ref)
                    val_total_loss += val_loss.item()

                # Extraction development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                    mix, aux = inputs
                    ref, aux_len, label = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    label = label.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_ce = torch.mean(ce_loss(val_out[1], label))
                    val_acc = accuracy_speaker(val_out[1], label)
                    val_acc_total += val_acc
                    val_total_loss += val_loss.item()
                    val_CE_loss += val_ce.item()

                # suppression development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                    mix, aux = inputs
                    ref, aux_len = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_total_loss += val_loss.item()

                if val_num_index % print_freq == 0:
                    message = '<Valid-Epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                              'val_SISNR_loss:{:.3f}, val_CE_loss:{:.3f}, val_acc :{:.3f}>' .format(
                        epoch, total_epoch, val_num_index, val_num_steps, Optimizer.param_groups[0]['lr'],
                        val_total_loss / val_num_index,
                        val_CE_loss / val_num_index,
                        val_acc_total / val_num_index)
                    logger.info(message)
            val_num_index += 1

        val_end_time = time.time()
        mean_val_total_loss = val_total_loss / val_num_index
        mean_val_CE_loss = val_CE_loss / val_num_index
        mean_acc = val_acc_total / val_num_index
        message = 'Finished *** <epoch:{:d}, iter:{:d}, lr:{:.3e}, val SI-SNR loss:{:.3f}, val_CE_loss:{:.3f}, val_acc:{:.3f}' \
                  ' Total time:{:.3f} min> '.format(epoch, val_num_index, Optimizer.param_groups[0]['lr'],
                                                    mean_val_total_loss, mean_val_CE_loss, mean_acc,
                                                    (val_end_time - val_start_time) / 60)
        logger.info(message)

        Scheduler.step(mean_val_total_loss)

        if mean_val_total_loss >= best_loss:
            no_improve += 1
            logger.info(
                'No improvement, Best SI-SNR Loss: {:.4f}'.format(best_loss))

        if mean_val_total_loss < best_loss:
            best_loss = mean_val_total_loss
            no_improve = 0
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info(
                'Epoch: {:d}, Now Best SI-SNR Loss Change: {:.4f}'.format(
                    epoch, best_loss))

        if no_improve == early_stop:
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info("Stop training cause no impr for {:d} epochs".format(
                no_improve))
            break