class Speaker_Loss(nn.Module):
    def __init__(self, opt, hidden_dim, calc_accuracy):
        super(Speaker_Loss, self).__init__()

        self.opt = opt
        self.hidden_dim = hidden_dim
        self.calc_accuracy = calc_accuracy

        # single linear layer mapping hidden representations to 251 output
        # classes, one per speaker in the LibriSpeech train-clean-100 subset
        self.linear_classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, 251)
        ).to(opt.device)

        self.label_num = 1
        self.speaker_loss = nn.CrossEntropyLoss()

        # create a mapping from raw speaker id to consecutive integer class label
        if torch.cuda.is_available():
            factor = torch.cuda.device_count()
        else:
            factor = 1

        # scale the batch size by the number of available GPUs, since the batch
        # is split across devices during multi-GPU (data-parallel) training
        opt.batch_size_multiGPU = opt.batch_size * factor

        # the loss module is initialized before the training dataset is loaded,
        # so we build the speaker_id_dict from a separately loaded copy of the dataset
        _, train_dataset, _, _ = get_dataloader.get_libri_dataloaders(opt)
        self.speaker_id_dict = {}
        for idx, key in enumerate(train_dataset.speaker_dict):
            self.speaker_id_dict[key] = idx
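
    # The forward pass is not shown in this snippet. Below is a minimal sketch,
    # not taken from the original file, of how the pieces initialized above
    # could be combined; the method name, argument format, and accuracy
    # computation are assumptions for illustration only.
    def forward(self, x, speaker_ids):
        # x: hidden representations of shape (batch, hidden_dim)
        # speaker_ids: raw speaker ids, assumed to match the keys of speaker_id_dict
        labels = torch.tensor(
            [self.speaker_id_dict[s] for s in speaker_ids], device=x.device
        )
        logits = self.linear_classifier(x)
        total_loss = self.speaker_loss(logits, labels)

        accuracy = torch.zeros(1)
        if self.calc_accuracy:
            # fraction of samples whose predicted speaker matches the true label
            accuracy = (logits.argmax(dim=1) == labels).float().mean()

        return total_loss, accuracy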
if __name__ == "__main__":

    opt = arg_parser.parse_args()
    arg_parser.create_log_path(opt)

    # set random seeds
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    # load model
    model, optimizer = load_audio_model.load_model_and_optimizer(opt)

    # initialize logger
    logs = logger.Logger(opt)

    # get datasets and dataloaders
    train_loader, train_dataset, test_loader, test_dataset = get_dataloader.get_libri_dataloaders(
        opt
    )

    try:
        # Train the model
        train(opt, model)

    except KeyboardInterrupt:
        print("Training got interrupted, saving log-files now.")

    logs.create_log(model)
Example #3
    # load the pretrained context model and discard its training optimizer
    context_model, _ = load_audio_model.load_model_and_optimizer(
        opt,
        reload_model=True,
        calc_accuracy=True,
        num_GPU=1,
    )
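    # the pretrained context model is only used to extract representations for
    # the linear speaker classifier, so it is kept in evaluation mode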
    context_model.eval()

    # hidden dimensionality of the pretrained model's output representations;
    # this becomes the input size of the linear speaker classifier
    n_features = context_model.module.reg_hidden

    loss = loss_supervised_speaker.Speaker_Loss(
        opt, n_features, calc_accuracy=True
    )

    # only the parameters of the linear speaker classifier (inside the loss
    # module) are optimized; the pretrained context model itself is not updated
    optimizer = torch.optim.Adam(loss.parameters(), lr=opt.learning_rate)

    # load dataset
    train_loader, _, test_loader, _ = get_dataloader.get_libri_dataloaders(opt)

    logs = logger.Logger(opt)
    # default values in case training is interrupted before the final test runs
    accuracy = 0
    result_loss = 0

    try:
        # Train the model
        train(opt, context_model, loss)

        # Test the model
        result_loss, accuracy = test(opt, context_model, loss, test_loader)

    except KeyboardInterrupt:
        print("Training interrupted, saving log files")

    logs.create_log(loss, accuracy=accuracy, final_test=True, final_loss=result_loss)
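
# The train() and test() functions used above are not shown in this example.
# Below is a minimal sketch, not taken from the original code, of what an
# evaluation loop over test_loader could look like. The batch format, the way
# representations are extracted from the frozen model, and the (loss, accuracy)
# return signature of the Speaker_Loss module are assumptions for illustration.
def test(opt, context_model, loss, test_loader):
    total_loss, total_acc = 0.0, 0.0

    with torch.no_grad():
        for audio, speaker_ids in test_loader:
            # batch format is an assumption; the real loader may yield more fields
            audio = audio.to(opt.device)

            # assumption: calling the frozen model yields per-timestep features of
            # shape (batch, time, hidden); average-pool them to one vector per clip
            representations = context_model(audio).mean(dim=1)

            # assumed (loss, accuracy) return signature, matching the sketch above
            batch_loss, batch_acc = loss(representations, speaker_ids)
            total_loss += batch_loss.item()
            total_acc += batch_acc.item()

    n_batches = len(test_loader)
    return total_loss / n_batches, total_acc / n_batches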