Example #1
import os

import torch
from torch.nn.parallel import DistributedDataParallel

# WaveGrad, dataset_from_path and _train_impl are project-local helpers.
def train_distributed(replica_id, replica_count, port, args, params):
  # Each replica joins the same process group through a local TCP rendezvous.
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = str(port)
  torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count)

  # Pin this replica to its own GPU, then wrap the model so gradients are
  # synchronized across replicas.
  device = torch.device('cuda', replica_id)
  torch.cuda.set_device(device)
  model = WaveGrad(params).to(device)
  model = DistributedDataParallel(model, device_ids=[replica_id])
  _train_impl(replica_id, model, dataset_from_path(args.data_dirs, params, is_distributed=True), args, params)
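Since train_distributed takes the process index as its first argument, it lines up with how torch.multiprocessing.spawn calls its target. The launcher below is only a hedged sketch, not part of the original example: the spawn_training name, the port value and the replica_count argument are assumptions.

import torch.multiprocessing as mp

def spawn_training(args, params, replica_count, port=12355):
  # Hypothetical launcher: spawn() starts replica_count processes and passes
  # the process index as the first argument, which becomes replica_id above.
  mp.spawn(train_distributed,
           args=(replica_count, port, args, params),
           nprocs=replica_count,
           join=True)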
Example #2
def train(args, params):
  # Single-GPU path: build the dataset once and train on the default CUDA device.
  dataset = dataset_from_path(args.data_dirs, params)
  model = WaveGrad(params).cuda()
  _train_impl(0, model, dataset, args, params)
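Example #2 only reads args.data_dirs, so for a quick programmatic run any object exposing that attribute will do; the path below and the params object are placeholders for the project's own values, not taken from the example.

from types import SimpleNamespace

# Hypothetical invocation: data_dirs is the only attribute Example #2 reads
# from args; params stands in for the project's hyperparameter object.
args = SimpleNamespace(data_dirs=['/path/to/training/audio'])
train(args, params)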
Example #3
def main(args):
    # params is expected to be available at module scope (e.g. imported project defaults).
    train(dataset_from_path(args.data_dirs, params), args, params)
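A command-line wrapper in the spirit of Example #3 might look like the following; the argument name, help text and __main__ guard are assumptions for illustration rather than code from the original project.

from argparse import ArgumentParser

if __name__ == '__main__':
  # Hypothetical CLI wiring: collect one or more data directories and pass the
  # parsed namespace to main(), which forwards everything to the training code.
  parser = ArgumentParser(description='train a model on the given data directories')
  parser.add_argument('data_dirs', nargs='+', help='directories containing training data')
  main(parser.parse_args())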