def main(conf):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # hard-coded linear beta schedule for the diffusion process
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)

    # a second UNet that holds the exponential moving average of the training weights
    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
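# The entry point above relies on a `make_beta_schedule` helper that is defined
# elsewhere in the codebase. As a point of reference, the sketch below shows one
# common way the "linear" branch of such a helper is written for DDPM-style training;
# the name `make_beta_schedule_sketch` and its exact return type are assumptions for
# illustration, not the repository's actual implementation.
import torch


def make_beta_schedule_sketch(schedule, beta_start, beta_end, n_timestep):
    if schedule == "linear":
        # evenly spaced betas between beta_start and beta_end, one per diffusion step
        return torch.linspace(beta_start, beta_end, n_timestep, dtype=torch.float64)

    raise NotImplementedError(f"unknown beta schedule: {schedule}")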
def main(conf):
    # only the primary process logs to wandb
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    # the model and its EMA copy are now built from the config instead of hard-coded
    model = conf.model.make()
    model = model.to(device)
    ema = conf.model.make()
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        # resume from a checkpoint; unwrap the DDP-wrapped model before loading
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])
        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(
        conf, train_loader, model, ema, diffusion, optimizer, scheduler, device, wandb
    )
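# Both mains build an `ema` copy of the model but never update it here, so the update
# presumably happens inside `train`. A standard way to do that is an exponential moving
# average over the parameters after each optimizer step. The helper below is a
# hypothetical sketch of that idea (name, decay value, and signature are assumptions,
# not code taken from this repository).
@torch.no_grad()
def accumulate_ema(ema_model, model, decay=0.9999):
    # pass model.module instead of model when the training model is wrapped in DDP,
    # so that parameter names line up between the two state dicts
    ema_params = dict(ema_model.named_parameters())
    model_params = dict(model.named_parameters())

    for name, param in model_params.items():
        # ema = decay * ema + (1 - decay) * current
        ema_params[name].mul_(decay).add_(param.detach(), alpha=1 - decay)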
def make_dataloader(train_set, valid_set, batch, distributed, n_worker):
    batch_size = batch // dist.get_world_size()

    train_sampler = dist.data_sampler(train_set, shuffle=True, distributed=distributed)
    train_loader = DataLoader(
        train_set, batch_size=batch_size, sampler=train_sampler, num_workers=n_worker
    )
    valid_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        sampler=dist.data_sampler(valid_set, shuffle=False, distributed=distributed),
        num_workers=n_worker,
    )

    return train_loader, valid_loader, train_sampler
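# Hypothetical usage of make_dataloader (the values batch=64, n_worker=4, and the
# n_epoch variable are illustrative, not taken from the repository). `batch` is the
# global batch size, which the helper splits evenly across ranks; the train sampler is
# returned so the caller can reshuffle the shards every epoch in the distributed case.
distributed = dist.get_world_size() > 1
train_loader, valid_loader, train_sampler = make_dataloader(
    train_set, valid_set, batch=64, distributed=distributed, n_worker=4
)

for epoch in range(n_epoch):
    if distributed:
        # DistributedSampler varies its shuffling seed with the epoch number
        train_sampler.set_epoch(epoch)

    for batch_data in train_loader:
        ...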
def main(conf):
    conf.distributed = dist.get_world_size() > 1
    device = "cuda"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    # only the primary process logs to wandb
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="asr")
    else:
        wandb = None

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    train_set = ASRDataset(
        conf.dataset.path,
        indices=split_indices["train"],
        alignment=conf.dataset.alignment,
    )
    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    valid_sampler = dist.data_sampler(
        valid_set, shuffle=False, distributed=conf.distributed
    )

    if conf.training.batch_sampler is not None:
        # bucket training utterances by mel length so each batch holds similarly
        # long samples and padding is kept small (see the sketch after this function)
        train_lens = []

        for i in split_indices["train"]:
            train_lens.append(train_set.mel_lengths[i])

        opts = conf.training.batch_sampler
        bins = (
            (opts.base ** np.linspace(opts.start, 1, 2 * opts.k + 1)) * 1000
        ).tolist()

        groups, bins, n_samples = create_groups(train_lens, bins)
        batch_sampler = GroupedBatchSampler(
            train_sampler, groups, conf.training.dataloader.batch_size
        )
        # the batch sampler now controls the batch size, so the loader itself uses 1
        conf.training.dataloader.batch_size = 1
        train_loader = conf.training.dataloader.make(
            train_set, batch_sampler=batch_sampler, collate_fn=collate_data_imputer
        )

    else:
        train_loader = conf.training.dataloader.make(
            train_set, collate_fn=collate_data_imputer
        )

    valid_loader = conf.training.dataloader.make(
        valid_set, sampler=valid_sampler, collate_fn=collate_data
    )

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        # resume from a checkpoint; unwrap the DDP-wrapped model before loading
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model
        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])
        # scheduler.load_state_dict(ckpt["scheduler"])
        model_p.copy_embed(1)

    model_training = ModelTraining(
        model,
        optimizer,
        scheduler,
        train_set,
        train_loader,
        valid_loader,
        device,
        wandb,
    )

    train(conf, model_training)
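# The batch-sampler branch above maps each training utterance to a length bucket before
# handing the group ids to GroupedBatchSampler. `create_groups` is part of the codebase;
# the sketch below only illustrates one way such a grouping could work, using
# np.digitize to assign each length to a bin. The name `create_groups_sketch` and the
# exact return values are assumptions, not the repository's actual implementation.
import numpy as np


def create_groups_sketch(lengths, bins):
    # group id per sample: index of the bin its mel length falls into
    groups = np.digitize(lengths, bins, right=True)
    # number of samples per group, useful for logging and sanity checks
    n_samples = np.bincount(groups, minlength=len(bins) + 1)

    return groups.tolist(), bins, n_samples.tolist()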