Example #1
    def build_model(self):
        hps = self.hps
        ns = hps.ns
        emb_size = hps.emb_size
        # autoencoder components
        self.Encoder = cc(Encoder(ns=ns, dp=hps.enc_dp))
        self.Decoder = cc(Decoder(ns=ns, c_a=hps.n_speakers,
                                  emb_size=emb_size))
        self.Generator = cc(
            Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
        # adversarial components
        self.SpeakerClassifier = cc(
            SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
        self.GoodClassifier = cc(
            SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
        self.PatchDiscriminator = cc(
            nn.DataParallel(PatchDiscriminator(ns=ns, n_class=hps.n_speakers)))
        # one optimizer per training objective
        betas = (0.5, 0.9)
        params = list(self.Encoder.parameters()) + list(
            self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=hps.lr, betas=betas)
        self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                                  lr=hps.lr,
                                  betas=betas)
        self.good_clf_opt = optim.Adam(self.GoodClassifier.parameters(),
                                       lr=hps.lr,
                                       betas=betas)
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=hps.lr,
                                  betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=hps.lr,
                                    betas=betas)
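Several of these examples wrap modules in a cc helper that the snippets never define. A minimal sketch, assuming cc simply moves a module to the GPU when one is available (the behavior is inferred, not confirmed by the source):

import torch

def cc(net):
    # Hypothetical helper: move the module to GPU if CUDA is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return net.to(device)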
Example #2
    def __load_network(self):

        self.net = SpeakerClassifier(self.params, self.device)
        self.net.to(self.device)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.net = nn.DataParallel(self.net)
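nn.DataParallel wraps the model, so custom methods defined on the original module are reached through .module afterwards; Example #6 below handles exactly this case. A hypothetical illustration, where net and getEmbedding stand in for any wrapped model and custom method:

if isinstance(net, nn.DataParallel):
    emb = net.module.getEmbedding(x)
else:
    emb = net.getEmbedding(x)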
Example #3
    def build_model(self):
        hps = self.hps
        self.SpeakerClassifier = SpeakerClassifier(ns=hps.ns, dp=hps.dp,
                                                   n_class=hps.n_speakers)
        self.Encoder = Encoder(ns=hps.ns)
        if torch.cuda.is_available():
            self.SpeakerClassifier.cuda()
            self.Encoder.cuda()
        betas = (0.5, 0.9)
        self.opt = optim.Adam(self.SpeakerClassifier.parameters(),
                              lr=hps.lr, betas=betas)
Example #4
    def build_model(self, wavenet_mel):
        hps = self.hps
        ns = hps.ns
        emb_size = hps.emb_size
        # 80 mel bins for a WaveNet vocoder, 513 linear-spectrogram bins otherwise
        c = 80 if wavenet_mel else 513
        patch_classify_kernel = (3, 4) if wavenet_mel else (17, 4)
        self.Encoder = cc(Encoder(c_in=c, ns=ns, dp=hps.enc_dp))
        self.Decoder = cc(
            Decoder(c_out=c, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
        self.Generator = cc(
            Decoder(c_out=c, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
        self.SpeakerClassifier = cc(
            SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
        self.PatchDiscriminator = cc(
            nn.DataParallel(
                PatchDiscriminator(
                    ns=ns,
                    n_class=hps.n_speakers,
                    classify_kernel_size=patch_classify_kernel)))
        betas = (0.5, 0.9)
        params = list(self.Encoder.parameters()) + list(
            self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=hps.lr, betas=betas)
        self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                                  lr=hps.lr,
                                  betas=betas)
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=hps.lr,
                                  betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=hps.lr,
                                    betas=betas)
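The channel count c above switches between the two common spectrogram front ends: an 80-bin mel spectrogram (the usual WaveNet-vocoder input) and a 513-bin linear magnitude spectrogram. 513 is what a 1024-point FFT yields, assuming n_fft = 1024 (the snippet itself does not state the FFT size):

n_fft = 1024                  # assumed FFT size
linear_bins = n_fft // 2 + 1  # = 513, one bin per non-negative frequency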
Example #5
    def build_model(self):
        self.Encoder = cc(Encoder())
        self.Decoder = [cc(Decoder()) for _ in range(4)]
        self.ACLayer = cc(ACLayer())
        self.Discriminator = cc(Discriminator())
        self.ASRLayer = cc(ASRLayer())
        self.SpeakerClassifier = cc(SpeakerClassifier())
        # learning rate and betas for each optimizer
        ac_lr, ac_betas = 0.00005, (0.5, 0.999)
        vae_lr, vae_betas = 0.001, (0.9, 0.999)
        dis_lr = 0.002
        cls_lr, cls_betas = 0.0002, (0.5, 0.999)
        asr_lr, asr_betas = 0.00001, (0.5, 0.999)

        self.list_decoder = []

        for i in range(4):
            self.list_decoder += list(self.Decoder[i].parameters())
        self.vae_params = list(self.Encoder.parameters()) + self.list_decoder
        self.ac_optimizer = optim.Adam(self.ACLayer.parameters(),
                                       lr=ac_lr,
                                       betas=ac_betas)
        self.vae_optimizer = optim.Adam(self.vae_params,
                                        lr=vae_lr,
                                        betas=vae_betas)
        self.dis_optimizer = optim.Adam(self.Discriminator.parameters(),
                                        lr=dis_lr,
                                        betas=ac_betas)

        self.asr_optimizer = optim.Adam(self.ASRLayer.parameters(),
                                        lr=asr_lr,
                                        betas=asr_betas)
        self.cls_optimizer = optim.Adam(self.SpeakerClassifier.parameters(),
                                        lr=cls_lr,
                                        betas=cls_betas)
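Gathering parameters from a list of modules, as vae_params does above, can also be written with itertools.chain. A minimal equivalent sketch, where encoder and decoders stand in for self.Encoder and the self.Decoder list:

import itertools

# One flat parameter list covering the encoder and every decoder.
vae_params = list(itertools.chain(
    encoder.parameters(),
    *(dec.parameters() for dec in decoders)))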
Example #6
class Trainer:
    def __init__(self, params, device):

        self.params = params
        self.device = device
        self.__load_network()
        self.__load_data()
        self.__load_optimizer()
        self.__load_criterion()
        self.__initialize_training_variables()

    def __load_previous_states(self):

        list_files = os.listdir(self.params.out_dir)
        list_files = [
            self.params.out_dir + '/' + f for f in list_files if '.chkpt' in f
        ]
        if list_files:
            file2load = max(list_files, key=os.path.getctime)
            checkpoint = torch.load(file2load, map_location=self.device)
            try:
                self.net.load_state_dict(checkpoint['model'])
            except RuntimeError:
                self.net.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.params = checkpoint['settings']
            self.starting_epoch = checkpoint['epoch'] + 1
            self.step = checkpoint['step']
            print('Model "%s" is Loaded for requeue process' % file2load)
        else:
            self.step = 0
            self.starting_epoch = 1

    def __initialize_training_variables(self):

        if self.params.requeue:
            self.__load_previous_states()
        else:
            self.step = 0
            self.starting_epoch = 0

        self.best_EER = 50.0
        self.stopping = 0

    def __load_network(self):

        self.net = SpeakerClassifier(self.params, self.device)
        self.net.to(self.device)

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.net = nn.DataParallel(self.net)

    def __load_data(self):
        print('Loading Data and Labels')
        with open(self.params.train_labels_path, 'r') as data_labels_file:
            train_labels = data_labels_file.readlines()

        data_loader_parameters = {
            'batch_size': self.params.batch_size,
            'shuffle': True,
            'num_workers': self.params.num_workers
        }
        self.training_generator = DataLoader(
            Dataset(train_labels, self.params), **data_loader_parameters)

    def __load_optimizer(self):
        if self.params.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.net.parameters(),
                                        lr=self.params.learning_rate,
                                        weight_decay=self.params.weight_decay)
        elif self.params.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.net.parameters(),
                                       lr=self.params.learning_rate,
                                       weight_decay=self.params.weight_decay)
        elif self.params.optimizer == 'RMSprop':
            self.optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=self.params.learning_rate,
                weight_decay=self.params.weight_decay)

    def __update_optimizer(self):

        if self.params.optimizer in ('SGD', 'Adam'):
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= 0.5
            print('New Learning Rate: {}'.format(
                self.optimizer.param_groups[0]['lr']))

    def __load_criterion(self):
        self.criterion = nn.CrossEntropyLoss()

    def __initialize_batch_variables(self):

        self.print_time = time.time()
        self.train_loss = 0.0
        self.train_accuracy = 0.0
        self.train_batch = 0

    def __extractInputFromFeature(self, sline):

        features1 = normalizeFeatures(
            featureReader(self.params.valid_data_dir + '/' + sline[0] +
                          '.pickle'),
            normalization=self.params.normalization)
        features2 = normalizeFeatures(
            featureReader(self.params.valid_data_dir + '/' + sline[1] +
                          '.pickle'),
            normalization=self.params.normalization)

        input1 = torch.FloatTensor(features1).to(self.device)
        input2 = torch.FloatTensor(features2).to(self.device)

        return input1.unsqueeze(0), input2.unsqueeze(0)

    def __extract_scores(self, trials):

        scores = []
        for line in trials:
            sline = line[:-1].split()

            input1, input2 = self.__extractInputFromFeature(sline)

            if torch.cuda.device_count() > 1:
                emb1, emb2 = self.net.module.getEmbedding(
                    input1), self.net.module.getEmbedding(input2)
            else:
                emb1, emb2 = self.net.getEmbedding(
                    input1), self.net.getEmbedding(input2)

            dist = scoreCosineDistance(emb1, emb2)
            scores.append(dist.item())

        return scores

    def __calculate_EER(self, CL, IM):

        thresholds = np.arange(-1, 1, 0.01)
        FRR, FAR = np.zeros(len(thresholds)), np.zeros(len(thresholds))
        for idx, th in enumerate(thresholds):
            FRR[idx] = Score(CL, th, 'FRR')
            FAR[idx] = Score(IM, th, 'FAR')

        EER_Idx = np.argwhere(np.diff(np.sign(FAR - FRR)) != 0).reshape(-1)
        if len(EER_Idx) > 0:
            if len(EER_Idx) > 1:
                EER_Idx = EER_Idx[0]
            EER = round((FAR[int(EER_Idx)] + FRR[int(EER_Idx)]) / 2, 4)
        else:
            EER = 50.00
        return EER

    def __getAnnealedFactor(self):
        if torch.cuda.device_count() > 1:
            return self.net.module.predictionLayer.getAnnealedFactor(self.step)
        else:
            return self.net.predictionLayer.getAnnealedFactor(self.step)

    def __validate(self):

        with torch.no_grad():
            valid_time = time.time()
            self.net.eval()
            # EER Validation
            with open(self.params.valid_clients, 'r') as clients_in, \
                    open(self.params.valid_impostors, 'r') as impostors_in:
                # score clients
                CL = self.__extract_scores(clients_in)
                IM = self.__extract_scores(impostors_in)
            # Compute EER
            EER = self.__calculate_EER(CL, IM)

            annealedFactor = self.__getAnnealedFactor()
            print('Annealed Factor is {}.'.format(annealedFactor))
            print(
                '--Validation Epoch:{epoch: d}, Updates:{Num_Batch: d}, EER:{eer: 3.3f}, elapse:{elapse: 3.3f} min'
                .format(epoch=self.epoch,
                        Num_Batch=self.step,
                        eer=EER,
                        elapse=(time.time() - valid_time) / 60))
            # early stopping and save the best model
            if EER < self.best_EER:
                self.best_EER = EER
                self.stopping = 0
                print('We found a better model!')
                chkptsave(self.params, self.net, self.optimizer, self.epoch,
                          self.step)
            else:
                self.stopping += 1
                print('Best EER so far: {}. {} validations without improvement'
                      .format(self.best_EER, self.stopping))
            self.print_time = time.time()
            self.net.train()

    def __update(self):

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1

        if self.step % int(self.params.print_every) == 0:
            print(
                'Training Epoch:{epoch: d}, Updates:{Num_Batch: d} -----> xent:{xnet: .3f}, Accuracy:{acc: .2f}, elapse:{elapse: 3.3f} min'
                .format(epoch=self.epoch,
                        Num_Batch=self.step,
                        xnet=self.train_loss / self.train_batch,
                        acc=self.train_accuracy * 100 / self.train_batch,
                        elapse=(time.time() - self.print_time) / 60))
            self.__initialize_batch_variables()

        # validation
        if self.step % self.params.validate_every == 0:
            self.__validate()

    def __updateTrainingVariables(self):

        if (self.stopping + 1) % 15 == 0:
            self.__update_optimizer()

    def __randomSlice(self, inputTensor):
        index = random.randrange(200, self.params.window_size * 100)
        return inputTensor[:, :index, :]

    def train(self):

        print('Start Training')
        # loop over the dataset multiple times
        for self.epoch in range(self.starting_epoch, self.params.max_epochs):
            self.net.train()
            self.__initialize_batch_variables()
            for inputs, labels in self.training_generator:
                inputs = inputs.float().to(self.device)
                labels = labels.long().to(self.device)
                if self.params.randomSlicing:
                    inputs = self.__randomSlice(inputs)
                prediction, AMPrediction = self.net(inputs,
                                                    label=labels,
                                                    step=self.step)
                loss = self.criterion(AMPrediction, labels)
                loss.backward()
                self.train_accuracy += Accuracy(prediction, labels)
                self.train_loss += loss.item()

                self.train_batch += 1
                if self.train_batch % self.params.gradientAccumulation == 0:
                    self.__update()

            if self.stopping > self.params.early_stopping:
                print('--Best Model EER: %.2f%%' % self.best_EER)
                break

            self.__updateTrainingVariables()

        print('Finished Training')
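__calculate_EER above relies on a Score helper that the snippet does not define. A minimal sketch, assuming it returns error rates in percent: the false-rejection rate of client scores below the threshold, or the false-acceptance rate of impostor scores at or above it (the actual helper may differ):

import numpy as np

def Score(scores, threshold, rate):
    # Hypothetical stand-in for the undefined helper.
    scores = np.asarray(scores)
    if rate == 'FRR':
        # fraction of genuine (client) trials rejected at this threshold
        return np.mean(scores < threshold) * 100
    if rate == 'FAR':
        # fraction of impostor trials accepted at this threshold
        return np.mean(scores >= threshold) * 100
    raise ValueError(rate)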
Example #7
class Classifier(object):
    def __init__(self, hps, data_loader, valid_data_loader, log_dir='./log/'):
        self.hps = hps
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.model_kept = []
        self.max_keep = 10
        self.build_model()
        self.logger = Logger(log_dir)

    def build_model(self):
        hps = self.hps
        self.SpeakerClassifier = SpeakerClassifier(ns=hps.ns, dp=hps.dp, n_class=hps.n_speakers)
        self.Encoder = Encoder(ns=hps.ns)
        if torch.cuda.is_available():
            self.SpeakerClassifier.cuda()
            self.Encoder.cuda()
        betas = (0.5, 0.9)
        self.opt = optim.Adam(self.SpeakerClassifier.parameters(), lr=self.hps.lr, betas=betas)

    def load_encoder(self, model_path):
        print('load model from {}'.format(model_path))
        with open(model_path, 'rb') as f_in:
            all_model = torch.load(f_in)
            self.Encoder.load_state_dict(all_model['encoder'])

    def save_model(self, model_path, iteration):
        new_model_path = '{}-{}'.format(model_path, iteration)
        torch.save(self.SpeakerClassifier.state_dict(), new_model_path)
        self.model_kept.append(new_model_path)
        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_model(self, model_path):
        print('load model from {}'.format(model_path))
        self.SpeakerClassifier.load_state_dict(torch.load(model_path))

    def set_eval(self):
        self.SpeakerClassifier.eval()

    def set_train(self):
        self.SpeakerClassifier.train()

    def permute_data(self, data):
        C = to_var(data[0], requires_grad=False)
        X = to_var(data[2]).permute(0, 2, 1)
        return C, X

    def encode_step(self, x):
        enc = self.Encoder(x)
        return enc

    def forward_step(self, enc):
        logits = self.SpeakerClassifier(enc)
        return logits

    def cal_loss(self, logits, y_true):
        # calculate loss 
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, y_true)
        return loss

    def valid(self, n_batches=10):
        # input: valid data, output: (loss, acc)
        total_loss, total_acc = 0., 0.
        self.set_eval()
        with torch.no_grad():
            for i in range(n_batches):
                data = next(self.valid_data_loader)
                y, x = self.permute_data(data)
                enc = self.Encoder(x)
                logits = self.SpeakerClassifier(enc)
                loss = self.cal_loss(logits, y)
                acc = cal_acc(logits, y)
                total_loss += loss.item()
                total_acc += acc
        self.set_train()
        return total_loss / n_batches, total_acc / n_batches

    def train(self, model_path, flag='train'):
        # load hyperparams
        hps = self.hps
        for iteration in range(hps.iters):
            data = next(self.data_loader)
            y, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # forward to classifier
            logits = self.forward_step(enc)
            # calculate loss
            loss = self.cal_loss(logits, y)
            # optimize
            reset_grad([self.SpeakerClassifier])
            loss.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.opt.step()
            # calculate acc
            acc = cal_acc(logits, y)
            # print info
            info = {
                f'{flag}/loss': loss.item(),
                f'{flag}/acc': acc,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(info.values())
            log = 'iter:[%06d/%06d], loss=%.3f, acc=%.3f'
            print(log % slot_value, end='\r')
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                valid_loss, valid_acc = self.valid(n_batches=10)
                # print info
                info = {
                    f'{flag}/valid_loss': valid_loss, 
                    f'{flag}/valid_acc': valid_acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(info.values())
                log = 'iter:[%06d/%06d], valid_loss=%.3f, valid_acc=%.3f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration)
                self.save_model(model_path, iteration)
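This example leans on several small helpers (to_var, cal_acc, reset_grad, grad_clip) that are not shown. A minimal sketch of plausible implementations under standard PyTorch semantics; the names come from the snippet but the bodies are assumptions:

import torch
import torch.nn as nn

def to_var(x, requires_grad=True):
    # move a tensor to GPU (if available) and set requires_grad
    x = x.cuda() if torch.cuda.is_available() else x
    return x.requires_grad_(requires_grad)

def cal_acc(logits, y_true):
    # fraction of predictions whose argmax matches the target label
    preds = torch.argmax(logits, dim=1)
    return (preds == y_true).float().mean().item()

def reset_grad(net_list):
    # clear accumulated gradients on each module before backward()
    for net in net_list:
        net.zero_grad()

def grad_clip(net_list, max_grad_norm):
    # clip the global gradient norm of each module's parameters
    for net in net_list:
        nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)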