import os
import random
import time

import torch
from torch.utils.data import DataLoader

# project-local modules (module paths assumed from this repo's layout)
from hparam import hparam as hp
from data_load import SpeakerDatasetTIMITPreprocessed
from speech_embedder_net import SpeechEmbedder, GE2ELoss


def train(model_path):
    device = torch.device(hp.device)

    train_dataset = SpeakerDatasetTIMITPreprocessed()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    # both net and loss have trainable parameters
    optimizer = torch.optim.SGD(
        [{'params': embedder_net.parameters()},
         {'params': ge2e_loss.parameters()}],
        lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    embedder_net.train()
    iteration = 0
    for e in range(hp.train.epochs):
        total_loss = 0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
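            # shuffle the N*M utterances for the forward pass and build the inverse
            # permutation so they can be regrouped by speaker afterwards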
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            # gradients accumulate across calls to backward(), so clear them each step
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            # get loss, call backward, step optimizer
            loss = ge2e_loss(embeddings)  # expects (speakers N, utterances M, embedding dim)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batch_id + 1,
                    len(train_dataset) // hp.train.N, iteration, loss,
                    total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    # append so earlier log lines are not overwritten
                    with open(hp.train.log_file, 'a') as f:
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (
                e + 1) % hp.train.checkpoint_interval == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(
                e + 1) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir,
                                           ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    #save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(
        batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir,
                                   save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
import csv
import glob
import os
import sys

import torch

# assumes you are calling the SVE repo from outside (i.e. the LegalSpeech repo)
sys.path.append("./SpeakerVerificationEmbedding/src")

from hparam import hparam_SCOTUS as hp
from speech_embedder_net import SpeechEmbedder
from VAD_segments import VAD_chunk
from utils import concat_segs, get_STFTs, align_embeddings

#initialize SpeechEmbedder
embedder_net = SpeechEmbedder()
print(hp.model.model_path)
embedder_net.load_state_dict(torch.load(hp.model.model_path))
embedder_net.to(hp.device)

#dataset path
case_path = glob.glob(os.path.dirname(hp.unprocessed_data))

min_va = 2  # minimum voice activity length
label = 20  # unknown speaker label counter (leave room for 20 judges)

cnt = 0  # counter for judge_dict
judge_dict = dict()

# File Use Tracking

verbose = hp.data.verbose
embedder_net.eval()
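# --- illustrative continuation (not part of the original script) ---
# A minimal sketch of how the utilities imported above are typically chained for
# one audio file, mirroring the pipeline in main() further below.  The argument
# order and return values of concat_segs are assumptions here.
import numpy as np

def embed_file(wav_path):
    times, segs = VAD_chunk(hp.data.aggressiveness, wav_path)  # voice-activity segments
    segs, times = concat_segs(segs, times)                     # merge adjacent segments (signature assumed)
    frames, times = get_STFTs(segs, times)                     # fixed-length STFT windows
    frames = torch.tensor(
        np.transpose(np.stack(frames, axis=2), axes=(2, 1, 0))).to(hp.device)
    with torch.no_grad():
        embeddings = embedder_net(frames)
    # collapse frame-level d-vectors into a sequence of window-level embeddings
    sequence, times = align_embeddings(embeddings.cpu().numpy(), times)
    return sequence, times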
def main():
    args = get_args()
    if args.corpus == 'CAAML':
        dataset = CAAMLDataset(args.data_path, args.save_path, args.split)
    elif args.corpus == 'ICSI':
        dataset = ICSIDataset(args.data_path, args.save_path)
    elif args.corpus == 'TIMIT':
        print('Dataset not yet implemented...')
        exit()

    # Load speech embedder net
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(hp.model.model_path))
    embedder_net = embedder_net.to(device)
    embedder_net.eval()

    all_seqs = []
    all_cids = []
    all_times = []
    sessions = []
    for path in dataset.data_path:
        print('\n============== Processing {} ============== '.format(path))

        # Get session name
        session = dataset.get_session(path)
        if session is None:
            print('ERR: Session not found in any split, skipping...')
            continue
 
        # Get annotations
        annotations = dataset.get_annotations(path)
        if annotations is None:
            print('ERR: No suitable annotations found, skipping...')
            continue

        # Segment the audio with VAD
        times, segments = timeit(VAD_chunk, msg='Getting VAD chunks')(hp.data.aggressiveness, path)
        if segments == []:
            print('ERR: No segments found, skipping...')
            continue

        # Concatenate segments
        segments, times = concat(segments, times)

        # Get STFT frames
        frames, times = get_STFTs(segments, times)
        frames = np.stack(frames, axis=2)
        frames = torch.tensor(np.transpose(frames, axes=(2,1,0))).to(device)

        # Get speaker embeddings
        embeddings = get_speaker_embeddings(embedder_net, frames)

        # Align speaker embeddings into a standard sequence of embeddings
        sequence, times = align_embeddings(embeddings.cpu().detach().numpy(), times)
        
        # Get cluster ids for each frame
        cluster_ids = get_cluster_ids(times, dataset, annotations)

        # Add the sequence and cluster ids to the list of all sessions
        all_seqs.append(sequence) 
        all_cids.append(cluster_ids)
        all_times.append(times)
        sessions.append(session)

    # Save split dataset 
    dataset.save(all_seqs, all_cids, all_times, sessions)
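# --- illustrative helper (not defined in the snippet above) ---
# get_speaker_embeddings() is called in main() but its definition is not shown;
# a minimal sketch, assuming it simply runs the STFT frames through the embedder
# in chunks without tracking gradients (the batch size of 256 is arbitrary):
def get_speaker_embeddings(embedder_net, frames, batch_size=256):
    outputs = []
    with torch.no_grad():
        for start in range(0, frames.size(0), batch_size):
            outputs.append(embedder_net(frames[start:start + batch_size]))
    return torch.cat(outputs, dim=0)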