def main(conf):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)

    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
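
# A minimal sketch of what `make_beta_schedule` is assumed to return for the
# "linear" schedule configured above: `n_timestep` noise variances spaced evenly
# between `beta_start` and `beta_end`. The project's actual helper may support
# other schedules (e.g. quadratic or sigmoid); this only illustrates the linear
# case, and `make_beta_schedule_linear` is a hypothetical name for it.
import torch


def make_beta_schedule_linear(beta_start, beta_end, n_timestep):
    # Evenly spaced betas in float64 for numerical stability when cumulative
    # products over the schedule are computed later.
    return torch.linspace(beta_start, beta_end, n_timestep, dtype=torch.float64)
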
def main(conf):
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = conf.model.make()
    model = model.to(device)
    ema = conf.model.make()
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])
        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(
        conf, train_loader, model, ema, diffusion, optimizer, scheduler, device, wandb
    )
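
# The `train` loop and the internals of `GaussianDiffusion` are not shown here.
# As a rough, illustrative sketch of the standard DDPM objective such a class
# typically computes (noise prediction trained with an MSE loss under the usual
# q(x_t | x_0) parameterization) -- not necessarily this repository's exact API:
import torch
import torch.nn.functional as F


def ddpm_loss_sketch(model, x_0, betas):
    # Move the schedule onto the same device/dtype as the images.
    betas = betas.to(x_0)
    # alpha_bar_t = prod_{s <= t} (1 - beta_s)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    t = torch.randint(0, betas.shape[0], (x_0.shape[0],), device=x_0.device)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x_0)
    # Sample x_t ~ q(x_t | x_0), then train the model to predict the added noise.
    x_t = a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise
    return F.mse_loss(model(x_t, t), noise)
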
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    logger = get_logger(mode=conf.logger)
    logger.info(conf.dict())

    model = conf.arch.make().to(device)
    model_ema = conf.arch.make().to(device)

    logger.info(model)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        model_module = model.module
        accumulate(model_ema, model_module, 0)
    else:
        model_module = model
        accumulate(model_ema, model, 0)

    grad_accum = conf.training.grad_accumulation

    if conf.training.progressive.step > 0:
        progressive_stage = 0
        train_loader, valid_loader, train_sampler, grad_accum = make_progressive_loader(
            progressive_stage, model_module, conf
        )
    else:
        train_set, valid_set = make_dataset(
            conf.dataset_path,
            conf.training.train_size,
            conf.training.valid_size,
            {
                "n_augment": conf.training.randaug_layer,
                "magnitude": conf.training.randaug_magnitude,
                "increasing": conf.training.randaug_increasing,
                "magnitude_std": conf.training.randaug_magnitude_std,
                "cutout": conf.training.randaug_cutout,
            },
            {
                "mixup": conf.training.mixup,
                "cutmix": conf.training.cutmix,
                "mix_before_aug": conf.training.mix_before_aug,
            },
            conf.training.erasing,
        )
        batch_size = conf.training.dataloader.batch_size // grad_accum
        train_loader, valid_loader, train_sampler = make_dataloader(
            train_set,
            valid_set,
            batch_size,
            conf.distributed,
            conf.training.dataloader.num_workers,
        )

    criterion_train = MixLoss(eps=0.1)
    criterion_valid = nn.CrossEntropyLoss()

    parameters, names = add_weight_decay(
        model.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = math.ceil(len(train_loader) / grad_accum)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)

    step = 0

    scaler = amp.GradScaler(enabled=conf.fp16)

    checker = conf.checker.make()

    for epoch in range(conf.training.epoch):
        if conf.distributed:
            train_sampler.set_epoch(epoch)

        train(
            conf,
            step,
            epoch,
            train_loader,
            model,
            model_ema,
            criterion_train,
            optimizer,
            scheduler,
            scaler,
            grad_accum,
        )
        step += epoch_len

        if conf.training.ema == 0:
            prec1, prec5, losses = valid(
                conf, valid_loader, model_module, criterion_valid
            )
        else:
            prec1, prec5, losses = valid(conf, valid_loader, model_ema, criterion_valid)

        checker.log(
            step=epoch + 1,
            prec1=prec1,
            prec5=prec5,
            loss=losses.avg,
            lr=optimizer.param_groups[0]["lr"],
        )

        try:
            checker.checkpoint(
                {
                    "model": model_module.state_dict(),
                    "ema": model_ema.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "conf": conf.dict(),
                },
                f"epoch-{str(epoch + 1).zfill(3)}.pt",
            )
        except Exception as e:
            print(e)

        if (
            conf.training.progressive.step > 0
            and (epoch + 1) % conf.training.progressive.step == 0
        ):
            progressive_stage += 1

            if progressive_stage < conf.training.epoch // conf.training.progressive.step:
                train_loader, valid_loader, train_sampler, grad_accum = make_progressive_loader(
                    progressive_stage, model_module, conf
                )
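
# `accumulate(model_ema, model, decay)` is defined elsewhere. A common
# implementation (assumed here) keeps an exponential moving average of the
# parameters; with decay=0, as called above, it simply initializes the EMA model
# as an exact copy of the live model. `accumulate_sketch` is a hypothetical name.
import torch


def accumulate_sketch(model_ema, model, decay=0.999):
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name in ema_params:
            # ema <- decay * ema + (1 - decay) * current
            ema_params[name].mul_(decay).add_(params[name], alpha=1 - decay)
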
def main(conf):
    conf.distributed = dist.get_world_size() > 1
    device = "cuda"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="asr")
    else:
        wandb = None

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    train_set = ASRDataset(
        conf.dataset.path,
        indices=split_indices["train"],
        alignment=conf.dataset.alignment,
    )
    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    valid_sampler = dist.data_sampler(
        valid_set, shuffle=False, distributed=conf.distributed
    )

    if conf.training.batch_sampler is not None:
        train_lens = []

        for i in split_indices["train"]:
            train_lens.append(train_set.mel_lengths[i])

        opts = conf.training.batch_sampler
        bins = ((opts.base ** np.linspace(opts.start, 1, 2 * opts.k + 1)) * 1000).tolist()
        groups, bins, n_samples = create_groups(train_lens, bins)
        batch_sampler = GroupedBatchSampler(
            train_sampler, groups, conf.training.dataloader.batch_size
        )
        conf.training.dataloader.batch_size = 1
        train_loader = conf.training.dataloader.make(
            train_set, batch_sampler=batch_sampler, collate_fn=collate_data_imputer
        )
    else:
        train_loader = conf.training.dataloader.make(
            train_set, collate_fn=collate_data_imputer
        )

    valid_loader = conf.training.dataloader.make(
        valid_set, sampler=valid_sampler, collate_fn=collate_data
    )

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model
        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])
        # scheduler.load_state_dict(ckpt["scheduler"])

        model_p.copy_embed(1)

    model_training = ModelTraining(
        model,
        optimizer,
        scheduler,
        train_set,
        train_loader,
        valid_loader,
        device,
        wandb,
    )

    train(conf, model_training)
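
# `collate_data` and `collate_data_imputer` are defined elsewhere. As an
# illustrative, assumed sketch: ASR batches typically pad the variable-length
# mel spectrograms (and token sequences) to the longest item in the batch and
# keep the original lengths for masking. The item layout and the name
# `collate_data_sketch` below are hypothetical.
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_data_sketch(batch):
    # Each item is assumed to be (mel [T, n_mels], tokens [L]).
    mels, tokens = zip(*batch)
    mel_lens = torch.tensor([m.shape[0] for m in mels])
    token_lens = torch.tensor([t.shape[0] for t in tokens])
    mels = pad_sequence(mels, batch_first=True)  # [B, T_max, n_mels]
    tokens = pad_sequence(tokens, batch_first=True)  # [B, L_max]
    return mels, tokens, mel_lens, token_lens
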
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    model = conf.arch.make().to(device)
    model_ema = conf.arch.make().to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        model_module = model.module
        accumulate(model_ema, model_module, 0)
    else:
        model_module = model
        accumulate(model_ema, model, 0)

    if conf.training.progressive.step > 0:
        progressive_stage = 0
        train_loader, valid_loader = make_progressive_loader(
            progressive_stage, model_module, conf
        )
    else:
        train_set, valid_set = make_dataset(
            conf.dataset_path,
            conf.training.train_size,
            conf.training.valid_size,
            {
                "n_augment": conf.training.randaug_layer,
                "magnitude": conf.training.magnitude,
                "increasing": conf.training.randaug_increasing,
                "magnitude_std": conf.training.randaug_magnitude_std,
            },
            conf.training.mixup,
            conf.training.cutmix,
        )
        train_loader, valid_loader = make_dataloader(
            train_set,
            valid_set,
            conf.training.dataloader.batch_size,
            conf.distributed,
            conf.training.dataloader.num_workers,
        )

    criterion_train = MixLoss(eps=0.1)
    criterion_valid = nn.CrossEntropyLoss()

    parameters, names = add_weight_decay(
        model.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = len(train_loader)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)

    step = 0

    def checker_save(filename, *args):
        torch.save(
            {
                "model": model_module.state_dict(),
                "ema": model_ema.state_dict(),
                "scheduler": scheduler.state_dict(),
                "optimizer": optimizer.state_dict(),
                "conf": conf,
            },
            filename,
        )

    checker = conf.checker.make(checker_save)
    checker.save("test")

    for epoch in range(conf.training.epoch):
        train(
            conf,
            step,
            epoch,
            train_loader,
            model,
            model_ema,
            criterion_train,
            optimizer,
            scheduler,
        )
        step += epoch_len

        if conf.training.ema == 0:
            prec1, prec5, losses = valid(
                conf, valid_loader, model_module, criterion_valid
            )
        else:
            prec1, prec5, losses = valid(conf, valid_loader, model_ema, criterion_valid)

        checker.log(
            step=epoch + 1,
            prec1=prec1,
            prec5=prec5,
            loss=losses.avg,
            lr=optimizer.param_groups[0]["lr"],
        )
        checker.save(f"epoch-{str(epoch + 1).zfill(3)}")

        if (
            conf.training.progressive.step > 0
            and (epoch + 1) % conf.training.progressive.step == 0
        ):
            progressive_stage += 1

            if progressive_stage < conf.training.epoch // conf.training.progressive.step:
                train_loader, valid_loader = make_progressive_loader(
                    progressive_stage, model_module, conf
                )
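
# `add_weight_decay` and `wd_skip_fn` are defined elsewhere. A common pattern
# (assumed here) is to put biases, norm parameters, and anything matched by the
# skip function into a parameter group with weight_decay=0 and everything else
# into a decayed group, returning the groups plus the skipped names. The name
# `add_weight_decay_sketch` is hypothetical.
def add_weight_decay_sketch(named_parameters, weight_decay, skip_fn):
    decay, no_decay, skipped_names = [], [], []
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        # No decay for 1-D tensors (biases, norm scales) and skip-listed names.
        if param.ndim <= 1 or skip_fn(name):
            no_decay.append(param)
            skipped_names.append(name)
        else:
            decay.append(param)
    return (
        [
            {"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": weight_decay},
        ],
        skipped_names,
    )
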
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    logger = get_logger(mode=conf.logger)
    logger.info(conf.dict())

    student = conf.arch.make().to(device)
    student.set_drop_path(conf.task.student_drop_path)
    teacher = conf.arch.make().to(device)

    logger.info(student)

    if conf.distributed:
        teacher = nn.parallel.DistributedDataParallel(
            teacher,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        student = nn.parallel.DistributedDataParallel(
            student,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        teacher_module = teacher.module
        student_module = student.module
        teacher_module.load_state_dict(student_module.state_dict())
    else:
        teacher_module = teacher
        student_module = student
        teacher_module.load_state_dict(student.state_dict())

    for p in teacher.parameters():
        p.requires_grad = False

    grad_accum = conf.training.grad_accumulation

    train_set, valid_set = make_augment_dataset(
        conf.dataset_path,
        DINOAugment(
            conf.task.global_crop_size,
            conf.task.local_crop_size,
            conf.task.global_crop_scale,
            conf.task.local_crop_scale,
            conf.task.n_local_crop,
        ),
        None,
    )

    batch_size = conf.training.dataloader.batch_size // grad_accum
    train_loader, valid_loader, train_sampler = make_dataloader(
        train_set,
        valid_set,
        batch_size,
        conf.distributed,
        conf.training.dataloader.num_workers,
    )

    criterion_train = DINOLoss(
        conf.arch.dim_head_out,
        conf.task.n_local_crop + 2,
        conf.task.warmup_teacher_temperature,
        conf.task.teacher_temperature,
        conf.task.warmup_teacher_temperature_epoch,
        conf.training.epoch,
    ).to(device)

    parameters, names = add_weight_decay(
        student.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

    def make_scheduler(train_conf, optimizer, epoch_len):
        warmup = train_conf.scheduler.warmup * epoch_len
        n_iter = epoch_len * train_conf.epoch
        lr = train_conf.base_lr * train_conf.dataloader.batch_size / 256

        if train_conf.scheduler.type == "exp_epoch":
            return train_conf.scheduler.make(
                optimizer, epoch_len, lr=lr, max_iter=train_conf.epoch, warmup=warmup
            )
        else:
            return train_conf.scheduler.make(
                optimizer, lr=lr, n_iter=n_iter, warmup=warmup
            )

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = math.ceil(len(train_loader) / grad_accum)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)
    wd_schedule = cosine_schedule(
        conf.training.weight_decay,
        conf.task.weight_decay_end,
        epoch_len * conf.training.epoch,
    )
    momentum_schedule = cosine_schedule(
        conf.task.teacher_momentum, 1, epoch_len * conf.training.epoch
    )

    scaler = amp.GradScaler(enabled=conf.fp16)

    checker = conf.checker.make()

    step = 0

    for epoch in range(conf.training.epoch):
        if conf.distributed:
            train_sampler.set_epoch(epoch)

        train(
            conf,
            step,
            epoch,
            train_loader,
            teacher,
            student,
            criterion_train,
            optimizer,
            scheduler,
            wd_schedule,
            momentum_schedule,
            scaler,
            grad_accum,
            checker,
        )
        step += epoch_len

        try:
            checker.checkpoint(
                {
                    "student": student_module.state_dict(),
                    "teacher": teacher_module.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "conf": conf.dict(),
                },
                f"epoch-{str(epoch + 1).zfill(3)}.pt",
            )
        except Exception as e:
            print(e)
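
# `cosine_schedule(start, end, n_iter)` above is assumed to follow the usual
# DINO-style schedule: a per-iteration array that moves from `start` to `end`
# along a half cosine, used here for the weight-decay and teacher-momentum
# ramps. A minimal sketch of that assumption (no warmup segment):
import numpy as np


def cosine_schedule_sketch(start, end, n_iter):
    steps = np.arange(n_iter)
    # Equals `start` at step 0 and approaches `end` by the final iteration.
    return end + 0.5 * (start - end) * (1 + np.cos(np.pi * steps / n_iter))
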