Example #1
0
def train():
    """Entry point: read the YAML config from --opt and train Deep Clustering."""
    cli = argparse.ArgumentParser(
        description='Parameters for training Deep Clustering')
    cli.add_argument('--opt',
                     type=str,
                     help='Path to option YAML file.',
                     default='./config/train.yml')
    parsed = cli.parse_args()
    conf = option.parse(parsed.opt)

    # Route log output (screen and/or file) according to the YAML logger section.
    set_logger.setup_logger(conf['logger']['name'],
                            conf['logger']['path'],
                            screen=conf['logger']['screen'],
                            tofile=conf['logger']['tofile'])
    log = logging.getLogger(conf['logger']['name'])

    log.info("Building the model of Deep Clustering")
    net = model.DPCL(**conf['DPCL'])

    log.info("Building the optimizer of Deep Clustering")
    optim = make_optimizer(net.parameters(), conf)

    log.info('Building the dataloader of Deep Clustering')
    tr_loader, cv_loader = make_dataloader(conf)

    log.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(tr_loader), len(cv_loader)))
    log.info('Building the Trainer of Deep Clustering')
    Trainer(tr_loader, cv_loader, net, optim, conf).run()
Example #2
0
 def __init__(self, mix_path, yaml_path, model, gpuid):
     """Load the mixture wav, build Conv-TasNet and restore its checkpoint.

     Args:
         mix_path: path to the mixture wav to separate.
         yaml_path: path to the option YAML file.
         model: path to the saved checkpoint (.pth) file.
         gpuid: sequence of GPU ids; first one becomes the compute device.
     """
     super(Separation, self).__init__()
     self.mix = read_wav(mix_path)
     # NOTE(review): 'is_tain' spelling is part of parse()'s API — do not "fix" here.
     conf = parse(yaml_path, is_tain=False)
     tasnet = ConvTasNet(**conf['Conv_Tasnet'])
     # Restore weights on CPU first, then move the network to GPU below.
     ckpt = torch.load(model, map_location='cpu')
     tasnet.load_state_dict(ckpt["model_state_dict"])
     self.logger = get_logger(__name__)
     self.logger.info('Load checkpoint from {}, epoch {: d}'.format(
         model, ckpt["epoch"]))
     self.net = tasnet.cuda()
     self.gpuid = tuple(gpuid)
     self.device = torch.device(
         'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
Example #3
0
    def __init__(self, mix_json, aux_json, ref_json, yaml_path, model, gpuid):
        """Build the test dataloader and the configured DPRNN network.

        Args:
            mix_json / aux_json / ref_json: dataset description files.
            yaml_path: path to the option YAML file.
            model: path to the saved checkpoint (.pth) file.
            gpuid: sequence of GPU ids; first one becomes the compute device.
        """
        super(Separation, self).__init__()
        conf = parse(yaml_path)
        self.data_loader = Datasets_Test(
            mix_json,
            aux_json,
            ref_json,
            sample_rate=conf['datasets']['audio_setting']['sample_rate'],
            model=conf['model']['MODEL'])
        self.mix_infos = self.data_loader.mix
        self.total_wavs = len(self.mix_infos)
        self.model = conf['model']['MODEL']
        # Pick the network class that matches the configured task.
        if self.model in ('DPRNN_Speaker_Extraction',
                          'DPRNN_Speaker_Suppression'):
            # Extraction and Suppression model
            network = model_function.Extractin_Suppression_Model(
                **conf['Dual_Path_Aux_Speaker'])
        elif self.model == 'DPRNN_Speech_Separation':
            # Separation model
            network = model_function.Speech_Serapation_Model(
                **conf['Dual_Path_Aux_Speaker'])
        # Checkpoint was saved from a DataParallel-wrapped model, so wrap
        # before load_state_dict to keep the 'module.' key prefixes matching.
        network = torch.nn.DataParallel(network)
        ckpt = torch.load(model, map_location='cpu')
        network.load_state_dict(ckpt["model_state_dict"])
        setup_logger(conf['logger']['name'],
                     conf['logger']['path'],
                     screen=conf['logger']['screen'],
                     tofile=conf['logger']['tofile'])
        self.logger = logging.getLogger(conf['logger']['name'])
        self.logger.info('Load checkpoint from {}, epoch {: d}'.format(
            model, ckpt["epoch"]))
        self.net = network.cuda()

        self.device = torch.device(
            'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
        self.gpuid = tuple(gpuid)

        self.Output_wav = False
Example #4
0
def train():
    """Entry point: read the YAML config from --opt and train Conv-TasNet."""
    cli = argparse.ArgumentParser(
        description='Parameters for training Conv-TasNet')
    cli.add_argument('--opt', type=str, help='Path to option YAML file.')
    conf = option.parse(cli.parse_args().opt)
    # Route log output (screen and/or file) according to the YAML logger section.
    set_logger.setup_logger(conf['logger']['name'],
                            conf['logger']['path'],
                            screen=conf['logger']['screen'],
                            tofile=conf['logger']['tofile'])
    log = logging.getLogger(conf['logger']['name'])
    # build model
    log.info("Building the model of Conv-Tasnet")
    net = model.Conv_TasNet(**conf['Conv_Tasnet'])
    # build optimizer
    log.info("Building the optimizer of Conv-Tasnet")
    optim = make_optimizer(net.parameters(), conf)
    # build dataloader
    log.info('Building the dataloader of Conv-Tasnet')
    tr_loader, cv_loader = make_dataloader(conf)

    log.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(tr_loader), len(cv_loader)))
    # build scheduler: decay the LR when the monitored (min) metric plateaus
    lr_sched = ReduceLROnPlateau(optim,
                                 mode='min',
                                 factor=conf['scheduler']['factor'],
                                 patience=conf['scheduler']['patience'],
                                 verbose=True,
                                 min_lr=conf['scheduler']['min_lr'])

    # build trainer
    log.info('Building the Trainer of Conv-Tasnet')
    trainer_Tasnet.Trainer(tr_loader, cv_loader, net, optim,
                           lr_sched, conf).run()
Example #5
0
def train():
    """Train a DPRNN separation / extraction / suppression model.

    All hyper-parameters come from the YAML file given by --opt. The network
    is chosen by opt['model']['MODEL']; the epoch loop below runs training,
    validation, LR scheduling, checkpointing and early stopping inline.
    """
    parser = argparse.ArgumentParser(
        description='Parameters for training Model')
    # configuration file
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])
    # Date stamp (YYMMDD) passed to save_checkpoint for checkpoint file names.
    day_time = datetime.date.today().strftime('%y%m%d')

    # build model
    model = opt['model']['MODEL']
    logger.info("Building the model of {}".format(model))
    # Extraction and Suppression model
    if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction' or opt['model'][
            'MODEL'] == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # NOTE(review): both branches below log the identical message; the
    # single-GPU branch presumably should say something different — confirm.
    if opt['train']['gpuid']:
        if len(opt['train']['gpuid']) > 1:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))
        else:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))

        device = torch.device('cuda:{}'.format(opt['train']['gpuid'][0]))
        gpuids = opt['train']['gpuid']
        if len(gpuids) > 1:
            net = torch.nn.DataParallel(net, device_ids=gpuids)
        net = net.to(device)
    logger.info('Loading {} parameters: {:.3f} Mb'.format(
        model, check_parameters(net)))

    # build optimizer
    logger.info("Building the optimizer of {}".format(model))
    Optimizer = make_optimizer(net.parameters(), opt)

    # Decay the LR when the validation loss (mode='min') stops improving.
    Scheduler = ReduceLROnPlateau(Optimizer,
                                  mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True,
                                  min_lr=opt['scheduler']['min_lr'])

    # build dataloader
    logger.info('Building the dataloader of {}'.format(model))
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build trainer
    logger.info('............. Training ................')

    total_epoch = opt['train']['epoch']
    num_spks = opt['num_spks']
    print_freq = opt['logger']['print_freq']
    checkpoint_path = opt['train']['path']
    early_stop = opt['train']['early_stop']
    max_norm = opt['optim']['clip_norm']
    best_loss = np.inf
    no_improve = 0
    ce_loss = torch.nn.CrossEntropyLoss()
    # Weight of the speaker CE loss in the extraction objective (see below).
    weight = 0.1

    epoch = 0
    # Resume training settings
    # NOTE(review): the resume file name hard-codes the '200722' date stamp
    # while new checkpoints are stamped with today's date — confirm intent.
    if opt['resume']['state']:
        opt['resume']['path'] = opt['resume'][
            'path'] + '/' + '200722_epoch{}.pth.tar'.format(
                opt['resume']['epoch'])
        ckp = torch.load(opt['resume']['path'], map_location='cpu')
        epoch = ckp['epoch']
        logger.info("Resume from checkpoint {}: epoch {:.3f}".format(
            opt['resume']['path'], epoch))
        net.load_state_dict(ckp['model_state_dict'])
        net.to(device)
        Optimizer.load_state_dict(ckp['optim_state_dict'])

    while epoch < total_epoch:

        epoch += 1
        logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            epoch, 0))
        num_steps = len(train_dataloader)

        # training process
        total_SNRloss = 0.0
        total_CEloss = 0.0
        # NOTE(review): num_index starts at 1 and is incremented after each
        # batch, so after N batches it equals N+1; the epoch means below
        # divide by N+1 rather than N — confirm this is intended.
        num_index = 1
        start_time = time.time()
        for inputs, targets in train_dataloader:
            # Separation train
            if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                mix = inputs
                ref = targets
                net.train()

                mix = mix.to(device)
                ref = [ref[i].to(device) for i in range(num_spks)]

                net.zero_grad()
                train_out = net(mix)
                SNR_loss = Loss(train_out, ref)
                loss = SNR_loss

            # Extraction train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                mix, aux = inputs
                ref, aux_len, sp_label = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                sp_label = sp_label.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                CE_loss = torch.mean(ce_loss(train_out[1], sp_label))
                # Joint objective: SI-SDR plus weighted speaker CE loss.
                loss = SNR_loss + weight * CE_loss
                total_CEloss += CE_loss.item()

            # Suppression train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                mix, aux = inputs
                ref, aux_len = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                loss = SNR_loss

            # BP process: backward, clip gradients, then step.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)
            Optimizer.step()

            total_SNRloss += SNR_loss.item()

            if num_index % print_freq == 0:
                message = '<Training epoch:{:d} / {:d} , iter:{:d} / {:d}, lr:{:.3e}, SI-SNR_loss:{:.3f}, CE loss:{:.3f}>'.format(
                    epoch, total_epoch, num_index, num_steps,
                    Optimizer.param_groups[0]['lr'], total_SNRloss / num_index,
                    total_CEloss / num_index)
                logger.info(message)

            num_index += 1

        end_time = time.time()
        mean_SNRLoss = total_SNRloss / num_index
        mean_CELoss = total_CEloss / num_index

        message = 'Finished Training *** <epoch:{:d} / {:d}, iter:{:d}, lr:{:.3e}, ' \
                  'SNR loss:{:.3f}, CE loss:{:.3f}, Total time:{:.3f} min> '.format(
            epoch, total_epoch, num_index, Optimizer.param_groups[0]['lr'], mean_SNRLoss, mean_CELoss, (end_time - start_time) / 60)
        logger.info(message)

        # development (validation) process
        val_num_index = 1
        val_total_loss = 0.0
        val_CE_loss = 0.0
        val_acc_total = 0.0
        val_acc = 0.0
        val_start_time = time.time()
        val_num_steps = len(val_dataloader)
        for inputs, targets in val_dataloader:
            net.eval()
            with torch.no_grad():
                # Separation development
                if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                    mix = inputs
                    ref = targets
                    mix = mix.to(device)
                    ref = [ref[i].to(device) for i in range(num_spks)]
                    # NOTE(review): zero_grad under torch.no_grad() in eval
                    # has no training effect; likely copied from the train loop.
                    Optimizer.zero_grad()
                    val_out = net(mix)
                    val_loss = Loss(val_out, ref)
                    val_total_loss += val_loss.item()

                # Extraction development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                    mix, aux = inputs
                    ref, aux_len, label = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    label = label.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_ce = torch.mean(ce_loss(val_out[1], label))
                    val_acc = accuracy_speaker(val_out[1], label)
                    val_acc_total += val_acc
                    val_total_loss += val_loss.item()
                    val_CE_loss += val_ce.item()

                # suppression development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                    mix, aux = inputs
                    ref, aux_len = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_total_loss += val_loss.item()

                if val_num_index % print_freq == 0:
                    message = '<Valid-Epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                              'val_SISNR_loss:{:.3f}, val_CE_loss:{:.3f}, val_acc :{:.3f}>' .format(
                        epoch, total_epoch, val_num_index, val_num_steps, Optimizer.param_groups[0]['lr'],
                        val_total_loss / val_num_index,
                        val_CE_loss / val_num_index,
                        val_acc_total / val_num_index)
                    logger.info(message)
            val_num_index += 1

        val_end_time = time.time()
        mean_val_total_loss = val_total_loss / val_num_index
        mean_val_CE_loss = val_CE_loss / val_num_index
        mean_acc = val_acc_total / val_num_index
        message = 'Finished *** <epoch:{:d}, iter:{:d}, lr:{:.3e}, val SI-SNR loss:{:.3f}, val_CE_loss:{:.3f}, val_acc:{:.3f}' \
                  ' Total time:{:.3f} min> '.format(epoch, val_num_index, Optimizer.param_groups[0]['lr'],
                                                    mean_val_total_loss, mean_val_CE_loss, mean_acc,
                                                    (val_end_time - val_start_time) / 60)
        logger.info(message)

        # Let the scheduler decide whether to decay the LR.
        Scheduler.step(mean_val_total_loss)

        if mean_val_total_loss >= best_loss:
            no_improve += 1
            logger.info(
                'No improvement, Best SI-SNR Loss: {:.4f}'.format(best_loss))

        # New best validation loss: reset the patience counter and checkpoint.
        if mean_val_total_loss < best_loss:
            best_loss = mean_val_total_loss
            no_improve = 0
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info(
                'Epoch: {:d}, Now Best SI-SNR Loss Change: {:.4f}'.format(
                    epoch, best_loss))

        # Early stopping after `early_stop` epochs without improvement.
        if no_improve == early_stop:
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info("Stop training cause no impr for {:d} epochs".format(
                no_improve))
            break
        # B x TF x nspk
        y = ibm * non_silent.expand_as(ibm)
        # attractors are the weighted average of the embeddings
        # calculated by V and Y
        # B x K x nspk
        v_y = torch.bmm(torch.transpose(v, 1, 2), y)
        # B x K x nspk
        sum_y = torch.sum(y, 1, keepdim=True).expand_as(v_y)
        # B x K x nspk
        attractor = v_y / (sum_y + self.eps)

        # calculate the distance bewteen embeddings and attractors
        # and generate the masks
        # B x TF x nspk
        dist = v.bmm(attractor)
        # B x TF x nspk
        mask = Fun.softmax(dist, dim=2)
        return mask, hidden

    def init_hidden(self, batch_size):
        """Delegate hidden-state initialization to the wrapped RNN module."""
        return self.rnn.init_hidden(batch_size)


if __name__ == "__main__":
    # Smoke test: run one DANet forward pass on random tensors.
    opt = option.parse('../config/train.yml')
    net = DANet(**opt['DANet'])
    # Renamed from `input` to avoid shadowing the builtin of that name.
    # Shapes follow the forward() comments: presumably B x T x F features,
    # B x TF x nspk binary mask, B x TF x 1 silence weights — TODO confirm.
    mix_feat = torch.randn(5, 10, 129)
    ibm = torch.randint(2, (5, 1290, 2))
    non_silent = torch.randn(5, 1290, 1)
    out = net(mix_feat, ibm, non_silent)