# Example no. 1
# 0
            iteration += 1

            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(time.ctime(), e+1,
                        batch_id+1, len(train_dataset)//hp.train.N, iteration,loss, total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    with open(hp.train.log_file,'a') as f:
                        f.write(mesg)
                    
        if hp.train.checkpoint_dir is not None and (e + 1) % hp.train.checkpoint_interval == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()
    
    ###save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir, save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)
    
    print("\nDone, trained model saved at", save_model_path)
    with torch.no_grad():
        avg_EER = 0
        i = 0
        for e in range(hp.test.epochs):
            batch_avg_EER = 0
            for batch_id, mel_db_batch in enumerate(test_loader):
def train(model_path):
    """Train a SpeechEmbedder with the GE2E loss on preprocessed TIMIT data.

    Hyper-parameters are read from the module-level ``hp`` config. Progress is
    printed every ``hp.train.log_interval`` batches and, when
    ``hp.train.log_file`` is set, appended to that file. A state-dict
    checkpoint (moved to CPU first) is written every
    ``hp.train.checkpoint_interval`` epochs, and a final ``.model`` state dict
    is written to ``hp.train.checkpoint_dir`` when training finishes.

    Args:
        model_path: Path of a saved state dict to restore from; only read when
            ``hp.train.restore`` is truthy.
    """
    device = torch.device(hp.device)

    train_dataset = SpeakerDatasetTIMITPreprocessed()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    # Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    embedder_net.train()
    iteration = 0
    for e in range(hp.train.epochs):
        total_loss = 0.0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)

            # Collapse the first two axes into N*M rows (axes 2 and 3 are kept).
            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            # Shuffle rows before the forward pass and build the inverse
            # permutation so the embeddings can be restored to their original
            # order before the loss is computed.
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            # Gradients accumulate by default, so clear them every step.
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            # Get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  # wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            # Accumulate a Python float: summing the loss tensor itself would
            # keep every batch's autograd graph alive for the whole epoch.
            loss_value = loss.item()
            total_loss += loss_value
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batch_id + 1,
                    len(train_dataset) // hp.train.N, iteration, loss_value,
                    total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    # Append so earlier progress lines are kept; mode 'w'
                    # would truncate the log on every write.
                    with open(hp.train.log_file, 'a') as f:
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (
                e + 1) % hp.train.checkpoint_interval == 0:
            # Save from CPU, then move the model back and resume training.
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(
                e + 1) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir,
                                           ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    # Save the final model (uses `e` / `batch_id` from the last loop pass).
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(
        batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir,
                                   save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
# Example no. 3
# 0
def train(model_path):
    """Train a SpeechEmbedder with GE2E loss, validating after every epoch.

    Hyper-parameters come from the module-level ``hp`` config. Data comes from
    ``SpeakerDatasetTIMITPreprocessed`` when ``hp.data.data_preprocessed`` is
    set, otherwise from ``SpeakerDatasetTIMIT``. After each training epoch the
    mean GE2E loss over the test loader is computed. Whenever that loss is the
    lowest seen so far (ties included), the weights are saved as
    ``model_best.pkl`` in ``hp.train.checkpoint_dir`` and copied to
    ``gs://edinquake/asr/baseline_TIMIT/`` with ``gsutil``. The latest weights
    are saved as ``model_last.pkl`` every epoch. ``gsutil`` exit codes are not
    checked.

    Args:
        model_path: Local path the restore checkpoint is downloaded to and
            loaded from; only used when ``hp.train.restore`` is truthy.
    """
    # Discard gsutil output. subprocess.DEVNULL needs no file handle to
    # close, unlike opening os.devnull directly.
    FNULL = subprocess.DEVNULL
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(is_training=True)
        test_dataset = SpeakerDatasetTIMITPreprocessed(is_training=False)
    else:
        train_dataset = SpeakerDatasetTIMIT(is_training=True)
        test_dataset = SpeakerDatasetTIMIT(is_training=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        subprocess.call([
            'gsutil', 'cp', 'gs://edinquake/asr/baseline_TIMIT/model_best.pkl',
            model_path
        ],
                        stdout=FNULL,
                        stderr=subprocess.STDOUT)
        # Best checkpoints are saved straight from `device` (possibly a GPU),
        # so remap their tensors onto the current device when loading.
        embedder_net.load_state_dict(
            torch.load(model_path, map_location=device))
    ge2e_loss = GE2ELoss(device)
    # Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    iteration = 0
    best_validate = float('inf')
    print('***Started training at {}***'.format(datetime.now()))
    for e in range(hp.train.epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(e),
                            leave=False,
                            disable=False)
        embedder_net.train()
        # Iterate over the training set
        for batch_id, mel_db_batch in enumerate(progress_bar):
            mel_db_batch = mel_db_batch.to(device)

            # Collapse the first two axes into N*M rows (axes 2 and 3 are kept).
            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            # Shuffle rows for the forward pass and build the inverse
            # permutation to restore the original order afterwards.
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            # Gradients accumulate by default, so clear them every step.
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            # Get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  # wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1

            # Update statistics for progress bar
            progress_bar.set_postfix(iteration=iteration,
                                     loss=loss.item(),
                                     total_loss=total_loss / (batch_id + 1))

        print('| Epoch {:03d}: total_loss {}'.format(e, total_loss))
        # Perform validation. no_grad() skips building autograd graphs that
        # would never be back-propagated.
        embedder_net.eval()
        validation_loss = 0
        with torch.no_grad():
            for batch_id, mel_db_batch in enumerate(test_loader):
                mel_db_batch = mel_db_batch.to(device)

                mel_db_batch = torch.reshape(
                    mel_db_batch,
                    (hp.test.N * hp.test.M, mel_db_batch.size(2),
                     mel_db_batch.size(3)))
                perm = random.sample(range(0, hp.test.N * hp.test.M),
                                     hp.test.N * hp.test.M)
                unperm = list(perm)
                for i, j in enumerate(perm):
                    unperm[j] = i
                mel_db_batch = mel_db_batch[perm]

                embeddings = embedder_net(mel_db_batch)
                embeddings = embeddings[unperm]
                embeddings = torch.reshape(
                    embeddings, (hp.test.N, hp.test.M, embeddings.size(1)))
                # Get loss
                loss = ge2e_loss(
                    embeddings)  # wants (Speaker, Utterances, embedding)
                validation_loss += loss.item()

        # NOTE(review): raises ZeroDivisionError if test_loader yields no
        # batches (fewer than hp.test.N test items with drop_last=True).
        validation_loss /= len(test_loader)
        print('validation_loss: {}'.format(validation_loss))
        if validation_loss <= best_validate:
            best_validate = validation_loss
            # Save best
            filename = 'model_best.pkl'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            subprocess.call([
                'gsutil', 'cp', ckpt_model_path,
                'gs://edinquake/asr/baseline_TIMIT/model_best.pkl'
            ],
                            stdout=FNULL,
                            stderr=subprocess.STDOUT)

        filename = 'model_last.pkl'
        ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
        torch.save(embedder_net.state_dict(), ckpt_model_path)