def get_train_loader(args, model):
    """Returns a `torch.utils.data.DataLoader` over the training data.

    Args:
        args: An `argparse.Namespace` for the `argparse.ArgumentParser`
            returned by `get_parser`.
        model: A `deepspeech.models.Model`.
    """
    if len(args.train_subsets) == 0:
        logging.debug("no train_subsets specified")
        return

    todo_epochs = args.n_epochs - model.completed_epochs
    if todo_epochs <= 0:
        logging.debug("n_epochs <= model.completed_epochs")
        return

    train_cache = os.path.join(args.cachedir, "train")
    train_dataset = LibriSpeech(root=train_cache,
                                subsets=args.train_subsets,
                                transform=model.transform,
                                target_transform=model.target_transform,
                                download=True)

    return DataLoader(train_dataset,
                      collate_fn=collate_input_sequences,
                      pin_memory=torch.cuda.is_available(),
                      num_workers=args.train_num_workers,
                      batch_size=args.train_batch_size,
                      shuffle=True)
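

# A minimal sketch (not part of this module's real entry point) of how
# get_train_loader is meant to be consumed. `_sketch_train` is a hypothetical
# helper added purely for illustration; `args` and `model` are assumed to come
# from `get_parser` and a concrete `deepspeech.models.Model` as described in
# the docstring above.
def _sketch_train(args, model):
    loader = get_train_loader(args, model)
    if loader is None:
        return  # no train subsets, or all requested epochs already completed

    todo_epochs = args.n_epochs - model.completed_epochs
    for _ in range(todo_epochs):
        for batch in loader:
            pass  # one batch collated by collate_input_sequences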


def get_dev_loader(args, model):
    """Returns a `torch.utils.data.DataLoader` over the validation data.

    Args:
        args: An `argparse.Namespace` for the `argparse.ArgumentParser`
            returned by `get_parser`.
        model: A `deepspeech.models.Model`.
    """
    if len(args.dev_subsets) == 0:
        logging.debug("no dev_subsets specified")
        return

    if len(args.dev_log) == 0:
        logging.debug("no dev_log statistics specified")
        return

    dev_cache = os.path.join(args.cachedir, "dev")
    dev_dataset = LibriSpeech(root=dev_cache,
                              subsets=args.dev_subsets,
                              transform=model.transform,
                              target_transform=model.target_transform,
                              download=True)

    return DataLoader(dev_dataset,
                      collate_fn=collate_input_sequences,
                      pin_memory=torch.cuda.is_available(),
                      num_workers=args.dev_num_workers,
                      batch_size=args.dev_batch_size,
                      shuffle=False)
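

# A minimal sketch of wiring both loaders together; `_sketch_run` is a
# hypothetical helper added purely for illustration and assumes `get_parser`
# (referenced in the docstrings above) and a concrete deepspeech.models.Model
# instance are available. Either loader may be None, in which case that phase
# is simply skipped.
def _sketch_run(model):
    args = get_parser().parse_args()

    train_loader = get_train_loader(args, model)
    dev_loader = get_dev_loader(args, model)

    if train_loader is not None:
        for batch in train_loader:
            pass  # training step on one collated batch

    if dev_loader is not None:
        for batch in dev_loader:
            pass  # evaluation/statistics step on one collated batch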