Beispiel #1
0
    def validate(self, net, cv_loader, criterion, feeder):
        """Compute the frame-weighted average loss of ``net`` on ``cv_loader``.

        Each batch loss is multiplied by the batch's total frame count (from
        ``countFrames``) before accumulation, and the sum is divided by the
        total frame count over the whole loader. The network is switched to
        eval mode and left there; callers that keep training must call
        ``net.train()`` afterwards.

        Args:
            net: the model. When more than one GPU id is configured it is
                expected to be wrapped (``net.module`` is evaluated).
            cv_loader: iterable of dicts with keys ``'mix'``, ``'sph'`` and
                ``'n_samples'``.
            criterion: loss callable ``criterion(est, lbl, loss_mask,
                n_frames)``.
            feeder: callable mapping ``(mix, sph)`` to ``(feat, lbl)``.

        Returns:
            The accumulated loss divided by the accumulated frame count.

        Raises:
            ValueError: if ``cv_loader`` yields no frames at all.
        """
        accu_cv_loss = 0.
        accu_n_frames = 0

        # Evaluate the bare module rather than going through DataParallel.
        if len(self.gpu_ids) > 1:
            net = net.module

        net.eval()
        # Nothing in validation needs gradients, so keep feature extraction
        # as well as the forward pass out of autograd.
        with torch.no_grad():
            for egs in cv_loader:
                mix = egs['mix'].to(self.device)
                sph = egs['sph'].to(self.device)
                n_samples = egs['n_samples'].to(self.device)

                n_frames = countFrames(n_samples, self.win_size,
                                       self.hop_size)

                feat, lbl = feeder(mix, sph)
                loss_mask = lossMask(shape=lbl.shape,
                                     n_frames=n_frames,
                                     device=self.device)
                est = net(feat)
                loss = criterion(est, lbl, loss_mask, n_frames)

                # Weight each batch by its frame count (presumably criterion
                # returns a per-frame mean) to get a frame-weighted average.
                batch_frames = sum(n_frames)
                accu_cv_loss += loss.item() * batch_frames
                accu_n_frames += batch_frames

        if accu_n_frames == 0:
            raise ValueError('validation loader produced no frames')
        return accu_cv_loss / accu_n_frames
Beispiel #2
0
    def train(self, args):
        """Train the network with periodic validation and checkpointing.

        Hyper-parameters and paths are copied from ``args`` onto ``self``.
        Training runs epoch by epoch until ``args.max_n_epochs``. Every
        ``args.logging_period`` iterations the following happens:
        - the model is evaluated with ``validate``;
        - a ``CheckPoint`` is saved to ``<ckpt_dir>/models/latest.pt``
          (``CheckPoint.save`` also receives the ``best.pt`` path and whether
          the CV loss improved);
        - the checkpoint is passed to ``lossLog``.
        ``scheduler.step()`` is called once per epoch, so the learning rate is
        multiplied by ``args.lr_decay_factor`` every ``args.lr_decay_period``
        epochs.

        Args:
            args: parsed command-line namespace. ``args.gpu_ids`` is a
                comma-separated list of device ids; ``"-1"`` selects the CPU
                and more than one id enables ``DataParallel``.
                ``args.clip_norm <= 0`` disables gradient clipping.
        """
        with open(args.tr_list, 'r') as f:
            self.tr_list = [line.strip() for line in f]
        self.tr_size = len(self.tr_list)
        self.cv_file = args.cv_file
        self.ckpt_dir = args.ckpt_dir
        self.logging_period = args.logging_period
        self.resume_model = args.resume_model
        self.time_log = args.time_log
        self.lr = args.lr
        self.lr_decay_factor = args.lr_decay_factor
        self.lr_decay_period = args.lr_decay_period
        self.clip_norm = args.clip_norm
        self.max_n_epochs = args.max_n_epochs
        self.batch_size = args.batch_size
        self.buffer_size = args.buffer_size
        self.loss_log = args.loss_log
        self.unit = args.unit
        self.segment_size = args.segment_size
        self.segment_shift = args.segment_shift

        self.gpu_ids = tuple(map(int, args.gpu_ids.split(',')))
        if len(self.gpu_ids) == 1 and self.gpu_ids[0] == -1:
            # cpu only
            self.device = torch.device('cpu')
        else:
            # gpu: the model lives on the first listed device
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        multi_gpu = len(self.gpu_ids) > 1

        os.makedirs(self.ckpt_dir, exist_ok=True)

        logger = getLogger(os.path.join(self.ckpt_dir, 'train.log'),
                           log_file=True)

        # create data loaders for training and cross validation
        tr_loader = AudioLoader(self.tr_list,
                                self.sample_rate,
                                self.unit,
                                self.segment_size,
                                self.segment_shift,
                                self.batch_size,
                                self.buffer_size,
                                self.in_norm,
                                mode='train')
        cv_loader = AudioLoader(self.cv_file,
                                self.sample_rate,
                                unit='utt',
                                segment_size=None,
                                segment_shift=None,
                                batch_size=1,
                                buffer_size=10,
                                in_norm=self.in_norm,
                                mode='eval')

        # create a network
        net = Net()
        logger.info('Model summary:\n{}'.format(net))

        net = net.to(self.device)
        if multi_gpu:
            net = DataParallel(net, device_ids=self.gpu_ids)

        # calculate model size (reported assuming 32 bits per parameter)
        param_count = numParams(net)
        logger.info('Trainable parameter count: {:,d} -> {:.2f} MB\n'.format(
            param_count, param_count * 32 / 8 / (2**20)))

        # net feeder
        feeder = NetFeeder(self.device, self.win_size, self.hop_size)

        # training criterion, optimizer and per-epoch step lr decay
        criterion = LossFunction()
        optimizer = Adam(net.parameters(), lr=self.lr, amsgrad=False)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=self.lr_decay_period,
                                        gamma=self.lr_decay_factor)

        # resume model if needed
        if self.resume_model:
            logger.info('Resuming model from {}'.format(self.resume_model))
            ckpt = CheckPoint()
            ckpt.load(self.resume_model, self.device)
            # Checkpoints store the unwrapped model's weights (see the save
            # below); DataParallel expects every key prefixed with 'module.'.
            prefix = 'module.' if multi_gpu else ''
            state_dict = {
                prefix + key: value
                for key, value in ckpt.net_state_dict.items()
            }
            net.load_state_dict(state_dict)
            optimizer.load_state_dict(ckpt.optim_state_dict)
            ckpt_info = ckpt.ckpt_info
            # The scheduler was just created and believes it is at epoch 0,
            # while the restored optimizer already carries the decayed lr.
            # Fast-forward it so the next decay follows the original schedule
            # instead of happening lr_decay_period epochs after resuming.
            scheduler.last_epoch = ckpt_info['cur_epoch']
            logger.info(
                'model info: epoch {}, iter {}, cv_loss - {:.4f}\n'.format(
                    ckpt.ckpt_info['cur_epoch'] + 1,
                    ckpt.ckpt_info['cur_iter'] + 1, ckpt.ckpt_info['cv_loss']))
        else:
            logger.info('Training from scratch...\n')
            ckpt_info = {
                'cur_epoch': 0,
                'cur_iter': 0,
                'tr_loss': None,
                'cv_loss': None,
                'best_loss': float('inf')
            }

        # NOTE(review): start_iter is never restored from ckpt_info['cur_iter'],
        # so a resumed run restarts the current epoch from its first batch.
        start_iter = 0
        # train model
        while ckpt_info['cur_epoch'] < self.max_n_epochs:
            accu_tr_loss = 0.
            accu_n_frames = 0
            net.train()
            for n_iter, egs in enumerate(tr_loader, start=start_iter):
                mix = egs['mix'].to(self.device)
                sph = egs['sph'].to(self.device)
                n_samples = egs['n_samples'].to(self.device)

                n_frames = countFrames(n_samples, self.win_size, self.hop_size)

                start_time = timeit.default_timer()

                # prepare features and labels
                feat, lbl = feeder(mix, sph)
                loss_mask = lossMask(shape=lbl.shape,
                                     n_frames=n_frames,
                                     device=self.device)
                # forward + backward + optimize; the loss is computed inside
                # enable_grad too so backward() works even if the caller
                # disabled autograd globally.
                optimizer.zero_grad()
                with torch.enable_grad():
                    est = net(feat)
                    loss = criterion(est, lbl, loss_mask, n_frames)
                loss.backward()
                # clip_grad_norm_ with max_norm == 0 would scale every
                # gradient to zero, so only clip for a strictly positive norm.
                if self.clip_norm > 0.0:
                    clip_grad_norm_(net.parameters(), self.clip_norm)
                optimizer.step()
                # accumulate a frame-weighted training loss
                running_loss = loss.item()
                batch_frames = sum(n_frames)
                accu_tr_loss += running_loss * batch_frames
                accu_n_frames += batch_frames

                end_time = timeit.default_timer()
                batch_time = end_time - start_time

                message = (
                    'Epoch [{}/{}], Iter [{}], tr_loss = {:.4f} / {:.4f}, batch_time (s) = {:.4f}'
                    .format(ckpt_info['cur_epoch'] + 1, self.max_n_epochs,
                            n_iter, running_loss,
                            accu_tr_loss / accu_n_frames, batch_time))
                if self.time_log:
                    with open(self.time_log, 'a+') as f:
                        print(message, file=f)
                else:
                    print(message, flush=True)

                if (n_iter + 1) % self.logging_period == 0:
                    avg_tr_loss = accu_tr_loss / accu_n_frames
                    avg_cv_loss = self.validate(net, cv_loader, criterion,
                                                feeder)
                    # validate() leaves the model in eval mode
                    net.train()

                    ckpt_info['cur_iter'] = n_iter
                    # bool() so a 0-d tensor loss still yields a plain flag
                    is_best = bool(avg_cv_loss < ckpt_info['best_loss'])
                    if is_best:
                        ckpt_info['best_loss'] = avg_cv_loss
                    latest_model = 'latest.pt'
                    best_model = 'best.pt'
                    ckpt_info['tr_loss'] = avg_tr_loss
                    ckpt_info['cv_loss'] = avg_cv_loss
                    # Always save the unwrapped model's weights.
                    model_to_save = net.module if multi_gpu else net
                    ckpt = CheckPoint(ckpt_info, model_to_save.state_dict(),
                                      optimizer.state_dict())
                    logger.info('Saving checkpoint into {}'.format(
                        os.path.join(self.ckpt_dir, latest_model)))
                    if is_best:
                        logger.info('Saving checkpoint into {}'.format(
                            os.path.join(self.ckpt_dir, best_model)))
                    logger.info(
                        'Epoch [{}/{}], ( tr_loss: {:.4f} | cv_loss: {:.4f} )\n'
                        .format(ckpt_info['cur_epoch'] + 1, self.max_n_epochs,
                                avg_tr_loss, avg_cv_loss))

                    model_path = os.path.join(self.ckpt_dir, 'models')
                    os.makedirs(model_path, exist_ok=True)

                    ckpt.save(os.path.join(model_path, latest_model), is_best,
                              os.path.join(model_path, best_model))

                    lossLog(os.path.join(self.ckpt_dir, self.loss_log), ckpt,
                            self.logging_period)

                    accu_tr_loss = 0.
                    accu_n_frames = 0

                    # NOTE(review): this end-of-epoch check only runs on
                    # logging iterations, so it can only fire when
                    # tr_size // batch_size is a multiple of logging_period;
                    # otherwise the epoch ends only when tr_loader is exhausted.
                    if n_iter + 1 == self.tr_size // self.batch_size:
                        start_iter = 0
                        ckpt_info['cur_iter'] = 0
                        break

            ckpt_info['cur_epoch'] += 1
            scheduler.step()  # learning rate decay
Beispiel #3
0
    def test(self, args):
        """Evaluate a trained model on every test list and write wav files.

        For each path listed in ``args.tt_list`` an utterance-level loader is
        built and the frame-weighted loss is logged. For each example ``k``
        the files ``{k}_mix.wav``, ``{k}_sph.wav`` and ``{k}_sph_est.wav`` are
        written, plus ``{k}_sph_idl.wav`` when ``args.write_ideal`` is set.
        The output directory is ``<est_path>/<list file name>``, with a
        trailing ``'.ex'`` removed from the list file name.

        Args:
            args: parsed command-line namespace providing ``tt_list``,
                ``model_file``, ``ckpt_dir``, ``est_path``, ``write_ideal``
                and ``gpu_ids`` (comma-separated; ``"-1"`` selects the CPU).
        """
        with open(args.tt_list, 'r') as f:
            self.tt_list = [line.strip() for line in f]
        self.model_file = args.model_file
        self.ckpt_dir = args.ckpt_dir
        self.est_path = args.est_path
        self.write_ideal = args.write_ideal
        self.gpu_ids = tuple(map(int, args.gpu_ids.split(',')))
        if len(self.gpu_ids) == 1 and self.gpu_ids[0] == -1:
            # cpu only
            self.device = torch.device('cpu')
        else:
            # gpu: only the first listed device is used for testing
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))

        os.makedirs(self.ckpt_dir, exist_ok=True)
        logger = getLogger(os.path.join(self.ckpt_dir, 'test.log'),
                           log_file=True)

        # create a network
        net = Net()
        logger.info('Model summary:\n{}'.format(net))

        net = net.to(self.device)

        # calculate model size (reported assuming 32 bits per parameter)
        param_count = numParams(net)
        logger.info('Trainable parameter count: {:,d} -> {:.2f} MB\n'.format(
            param_count, param_count * 32 / 8 / (2**20)))

        # evaluation criterion (no optimizer is needed for testing)
        criterion = LossFunction()

        # net feeder
        feeder = NetFeeder(self.device, self.win_size, self.hop_size)

        # resynthesizer
        resynthesizer = Resynthesizer(self.device, self.win_size,
                                      self.hop_size)

        # load model
        logger.info('Loading model from {}'.format(self.model_file))
        ckpt = CheckPoint()
        ckpt.load(self.model_file, self.device)
        net.load_state_dict(ckpt.net_state_dict)
        logger.info('model info: epoch {}, iter {}, cv_loss - {:.4f}\n'.format(
            ckpt.ckpt_info['cur_epoch'] + 1, ckpt.ckpt_info['cur_iter'] + 1,
            ckpt.ckpt_info['cv_loss']))

        net.eval()
        n_lists = len(self.tt_list)
        for i, tt_file in enumerate(self.tt_list):
            # create a data loader for testing
            tt_loader = AudioLoader(tt_file,
                                    self.sample_rate,
                                    unit='utt',
                                    segment_size=None,
                                    segment_shift=None,
                                    batch_size=1,
                                    buffer_size=10,
                                    in_norm=self.in_norm,
                                    mode='eval')
            logger.info('[{}/{}] Estimating on {}'.format(
                i + 1, n_lists, tt_file))

            # Name the output directory after the list file and drop only a
            # trailing '.ex'. str.replace would also delete '.ex' elsewhere in
            # the name, e.g. 'a.example.ex' -> 'aample'.
            est_name = os.path.basename(tt_file)
            if est_name.endswith('.ex'):
                est_name = est_name[:-len('.ex')]
            est_subdir = os.path.join(self.est_path, est_name)
            os.makedirs(est_subdir, exist_ok=True)

            accu_tt_loss = 0.
            accu_n_frames = 0
            for k, egs in enumerate(tt_loader):
                mix = egs['mix'].to(self.device)
                sph = egs['sph'].to(self.device)
                # moved to the device exactly as train() and validate() do
                n_samples = egs['n_samples'].to(self.device)

                n_frames = countFrames(n_samples, self.win_size, self.hop_size)

                with torch.no_grad():
                    feat, lbl = feeder(mix, sph)
                    loss_mask = lossMask(shape=lbl.shape,
                                         n_frames=n_frames,
                                         device=self.device)
                    est = net(feat)
                    loss = criterion(est, lbl, loss_mask, n_frames)

                    # waveforms from the label and from the network output
                    sph_idl = resynthesizer(lbl, mix)
                    sph_est = resynthesizer(est, mix)

                batch_frames = sum(n_frames)
                accu_tt_loss += loss.item() * batch_frames
                accu_n_frames += batch_frames

                # save estimates (first item of the batch of size 1)
                mix = mix[0].cpu().numpy()
                sph = sph[0].cpu().numpy()
                sph_est = sph_est[0].cpu().numpy()
                sph_idl = sph_idl[0].cpu().numpy()
                mix, sph, sph_est, sph_idl = wavNormalize(
                    mix, sph, sph_est, sph_idl)
                sf.write(os.path.join(est_subdir, '{}_mix.wav'.format(k)), mix,
                         self.sample_rate)
                sf.write(os.path.join(est_subdir, '{}_sph.wav'.format(k)), sph,
                         self.sample_rate)
                sf.write(os.path.join(est_subdir, '{}_sph_est.wav'.format(k)),
                         sph_est, self.sample_rate)
                if self.write_ideal:
                    sf.write(
                        os.path.join(est_subdir, '{}_sph_idl.wav'.format(k)),
                        sph_idl, self.sample_rate)

            # An empty list would otherwise divide by zero and abort the
            # remaining lists.
            if accu_n_frames == 0:
                logger.info('loss: n/a (no frames in {})'.format(tt_file))
                continue
            avg_tt_loss = accu_tt_loss / accu_n_frames
            logger.info('loss: {:.4f}'.format(avg_tt_loss))