コード例 #1
0
def get_train_loader(args, model):
    """Returns a `torch.utils.data.DataLoader` over the training data.

    The LibriSpeech subsets named in `args.train_subsets` are loaded from a
    dataset rooted at `<args.cachedir>/train` (constructed with
    `download=True`), transformed with the model's `transform` and
    `target_transform`, and served in shuffled batches.

    Args:
        args: An `argparse.Namespace` for the `argparse.ArgumentParser`
            returned by `get_parser`.
        model: A `deepspeech.models.Model`.

    Returns:
        A `torch.utils.data.DataLoader`, or `None` when no training subsets
        are specified or the model has already completed `args.n_epochs`
        (or more) epochs.
    """
    # Truthiness instead of `len(...) == 0` so a missing (None) subset list
    # means "nothing to train on" rather than a TypeError.
    if not args.train_subsets:
        logging.debug("no  train_subsets  specified")
        return None

    todo_epochs = args.n_epochs - model.completed_epochs
    if todo_epochs <= 0:
        logging.debug(" n_epochs  <=  model.completed_epochs ")
        return None

    train_cache = os.path.join(args.cachedir, "train")
    train_dataset = LibriSpeech(root=train_cache,
                                subsets=args.train_subsets,
                                transform=model.transform,
                                target_transform=model.target_transform,
                                download=True)

    return DataLoader(train_dataset,
                      collate_fn=collate_input_sequences,
                      # Only pin host memory when a CUDA device is present.
                      pin_memory=torch.cuda.is_available(),
                      num_workers=args.train_num_workers,
                      batch_size=args.train_batch_size,
                      shuffle=True)
コード例 #2
0
def get_val_loader(args, model):
    """Returns a `torch.utils.data.DataLoader` over the validation data.

    The LibriSpeech subsets named in `args.dev_subsets` are loaded from a
    dataset rooted at `<args.cachedir>/dev` (constructed with
    `download=True`), transformed with the model's `transform` and
    `target_transform`, and served in a fixed (unshuffled) order.

    Args:
        args: An `argparse.Namespace` for the `argparse.ArgumentParser`
            returned by `get_parser`.
        model: A `deepspeech.models.Model`.

    Returns:
        A `torch.utils.data.DataLoader`, or `None` when no dev subsets are
        specified or `args.dev_log` (the dev statistics to log) is empty.
    """
    # Truthiness instead of `len(...) == 0` so an unset (None) option is
    # treated like an empty one rather than raising TypeError.
    if not args.dev_subsets:
        logging.debug("no  dev_subsets  specified")
        return None

    if not args.dev_log:
        logging.debug("no  dev_log  statistics specified")
        return None

    dev_cache = os.path.join(args.cachedir, "dev")
    dev_dataset = LibriSpeech(root=dev_cache,
                              subsets=args.dev_subsets,
                              transform=model.transform,
                              target_transform=model.target_transform,
                              download=True)

    return DataLoader(dev_dataset,
                      collate_fn=collate_input_sequences,
                      # Only pin host memory when a CUDA device is present.
                      pin_memory=torch.cuda.is_available(),
                      num_workers=args.dev_num_workers,
                      batch_size=args.dev_batch_size,
                      shuffle=False)
コード例 #3
0
def get_dev_loader(args, model):
    """Builds a `torch.utils.data.DataLoader` over the dev (validation) data.

    No guard is applied here: `args.dev_subsets` is passed straight to
    `LibriSpeech`, whose dataset is rooted at `<args.cachedir>/dev` and is
    constructed with `download=True`. Batches are served in a fixed
    (unshuffled) order.

    Args:
        args: An `argparse.Namespace` providing `cachedir`, `dev_subsets`,
            `dev_num_workers` and `dev_batch_size`.
        model: A `deepspeech.models.Model` supplying `transform` and
            `target_transform`.

    Returns:
        A `torch.utils.data.DataLoader` over the dev dataset.
    """
    dataset = LibriSpeech(
        root=os.path.join(args.cachedir, 'dev'),
        subsets=args.dev_subsets,
        transform=model.transform,
        target_transform=model.target_transform,
        download=True,
    )

    # Pinned host memory is requested only when a CUDA device is present.
    use_pinned_memory = torch.cuda.is_available()

    return DataLoader(
        dataset,
        collate_fn=collate_input_sequences,
        pin_memory=use_pinned_memory,
        num_workers=args.dev_num_workers,
        batch_size=args.dev_batch_size,
        shuffle=False,
    )