def main(conf):
    device = "cuda:0" if torch.cuda.is_available() else 'cpu'
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

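    # Random horizontal flips, then Normalize((0.5, ...), (0.5, ...)) maps images from [0, 1] to [-1, 1].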
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)
    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

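    # Only the trained model is wrapped in DDP; the EMA copy stays an unwrapped module on the same device.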
    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

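    # Linear noise schedule (1e-4 to 2e-2 over 1000 steps) backing the Gaussian diffusion process.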
    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
def main(conf):
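    # Only the primary rank initializes Weights & Biases; other ranks keep wandb as None.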
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    train_set = MultiResolutionDataset(conf.dataset.path, transform,
                                       conf.dataset.resolution)
    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    train_loader = conf.training.dataloader.make(train_set,
                                                 sampler=train_sampler)

    model = conf.model.make()
    model = model.to(device)
    ema = conf.model.make()
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

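    # Optionally resume: map_location keeps the checkpoint on CPU, and under DDP the raw module lives at model.module.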
    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])

        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler,
          device, wandb)
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    logger = get_logger(mode=conf.logger)
    logger.info(conf.dict())

    model = conf.arch.make().to(device)
    model_ema = conf.arch.make().to(device)

    logger.info(model)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        model_module = model.module

        accumulate(model_ema, model_module, 0)

    else:
        model_module = model
        accumulate(model_ema, model, 0)

    grad_accum = conf.training.grad_accumulation

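    # With progressive training enabled, loaders and grad accumulation are rebuilt per stage (see the end of the epoch loop).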
    if conf.training.progressive.step > 0:
        progressive_stage = 0
        train_loader, valid_loader, train_sampler, grad_accum = make_progressive_loader(
            progressive_stage, model_module, conf)

    else:
        train_set, valid_set = make_dataset(
            conf.dataset_path,
            conf.training.train_size,
            conf.training.valid_size,
            {
                "n_augment": conf.training.randaug_layer,
                "magnitude": conf.training.randaug_magnitude,
                "increasing": conf.training.randaug_increasing,
                "magnitude_std": conf.training.randaug_magnitude_std,
                "cutout": conf.training.randaug_cutout,
            },
            {
                "mixup": conf.training.mixup,
                "cutmix": conf.training.cutmix,
                "mix_before_aug": conf.training.mix_before_aug,
            },
            conf.training.erasing,
        )

        batch_size = conf.training.dataloader.batch_size // grad_accum

        train_loader, valid_loader, train_sampler = make_dataloader(
            train_set,
            valid_set,
            batch_size,
            conf.distributed,
            conf.training.dataloader.num_workers,
        )

    criterion_train = MixLoss(eps=0.1)
    criterion_valid = nn.CrossEntropyLoss()

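    # Group parameters so that those matched by wd_skip_fn are excluded from weight decay.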
    parameters, names = add_weight_decay(
        model.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = math.ceil(len(train_loader) / grad_accum)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)

    step = 0

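    # GradScaler is effectively a no-op when fp16 is disabled.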
    scaler = amp.GradScaler(enabled=conf.fp16)

    checker = conf.checker.make()

    for epoch in range(conf.training.epoch):
        if conf.distributed:
            train_sampler.set_epoch(epoch)

        train(
            conf,
            step,
            epoch,
            train_loader,
            model,
            model_ema,
            criterion_train,
            optimizer,
            scheduler,
            scaler,
            grad_accum,
        )
        step += epoch_len

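        # Validate the raw model when EMA is disabled, otherwise the EMA weights.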
        if conf.training.ema == 0:
            prec1, prec5, losses = valid(conf, valid_loader, model_module,
                                         criterion_valid)

        else:
            prec1, prec5, losses = valid(conf, valid_loader, model_ema,
                                         criterion_valid)

        checker.log(
            step=epoch + 1,
            prec1=prec1,
            prec5=prec5,
            loss=losses.avg,
            lr=optimizer.param_groups[0]["lr"],
        )
        try:
            checker.checkpoint(
                {
                    "model": model_module.state_dict(),
                    "ema": model_ema.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "conf": conf.dict(),
                },
                f"epoch-{str(epoch + 1).zfill(3)}.pt",
            )
        except Exception as e:
            print(e)

        if (conf.training.progressive.step > 0
                and (epoch + 1) % conf.training.progressive.step == 0):
            progressive_stage += 1

            if (progressive_stage <
                    conf.training.epoch // conf.training.progressive.step):
                train_loader, valid_loader, train_sampler, grad_accum = make_progressive_loader(
                    progressive_stage, model_module, conf)
def main(conf):
    conf.distributed = dist.get_world_size() > 1

    device = "cuda"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="asr")

    else:
        wandb = None

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    train_set = ASRDataset(
        conf.dataset.path,
        indices=split_indices["train"],
        alignment=conf.dataset.alignment,
    )
    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    valid_sampler = dist.data_sampler(valid_set,
                                      shuffle=False,
                                      distributed=conf.distributed)

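    # Bucket training utterances by mel-spectrogram length so each batch groups similarly sized samples.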
    if conf.training.batch_sampler is not None:
        train_lens = []

        for i in split_indices["train"]:
            train_lens.append(train_set.mel_lengths[i])

        opts = conf.training.batch_sampler

        bins = ((opts.base**np.linspace(opts.start, 1, 2 * opts.k + 1)) *
                1000).tolist()
        groups, bins, n_samples = create_groups(train_lens, bins)
        batch_sampler = GroupedBatchSampler(
            train_sampler, groups, conf.training.dataloader.batch_size)

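        # DataLoader treats batch_sampler as mutually exclusive with a custom batch_size, so reset it to the default of 1.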
        conf.training.dataloader.batch_size = 1
        train_loader = conf.training.dataloader.make(
            train_set,
            batch_sampler=batch_sampler,
            collate_fn=collate_data_imputer)

    else:
        train_loader = conf.training.dataloader.make(
            train_set, collate_fn=collate_data_imputer)

    valid_loader = conf.training.dataloader.make(valid_set,
                                                 sampler=valid_sampler,
                                                 collate_fn=collate_data)

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model

        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])
        # scheduler.load_state_dict(ckpt["scheduler"])

        model_p.copy_embed(1)

    model_training = ModelTraining(
        model,
        optimizer,
        scheduler,
        train_set,
        train_loader,
        valid_loader,
        device,
        wandb,
    )

    train(conf, model_training)
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    model = conf.arch.make().to(device)
    model_ema = conf.arch.make().to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        model_module = model.module

        accumulate(model_ema, model_module, 0)

    else:
        model_module = model
        accumulate(model_ema, model, 0)

    if conf.training.progressive.step > 0:
        progressive_stage = 0
        train_loader, valid_loader = make_progressive_loader(
            progressive_stage, model_module, conf)

    else:
        train_set, valid_set = make_dataset(
            conf.dataset_path,
            conf.training.train_size,
            conf.training.valid_size,
            {
                "n_augment": conf.training.randaug_layer,
                "magnitude": conf.training.magnitude,
                "increasing": conf.training.randaug_increasing,
                "magnitude_std": conf.training.randaug_magnitude_std,
            },
            conf.training.mixup,
            conf.training.cutmix,
        )
        train_loader, valid_loader = make_dataloader(
            train_set,
            valid_set,
            conf.training.dataloader.batch_size,
            conf.distributed,
            conf.training.dataloader.num_workers,
        )

    criterion_train = MixLoss(eps=0.1)
    criterion_valid = nn.CrossEntropyLoss()

    parameters, names = add_weight_decay(
        model.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = len(train_loader)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)

    step = 0

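    # Closure used by the checker to serialize the full training state (model, EMA, scheduler, optimizer, conf).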
    def checker_save(filename, *args):
        torch.save(
            {
                "model": model_module.state_dict(),
                "ema": model_ema.state_dict(),
                "scheduler": scheduler.state_dict(),
                "optimizer": optimizer.state_dict(),
                "conf": conf,
            },
            filename,
        )

    checker = conf.checker.make(checker_save)
    checker.save("test")

    for epoch in range(conf.training.epoch):
        train(
            conf,
            step,
            epoch,
            train_loader,
            model,
            model_ema,
            criterion_train,
            optimizer,
            scheduler,
        )
        step += epoch_len

        if conf.training.ema == 0:
            prec1, prec5, losses = valid(conf, valid_loader, model_module,
                                         criterion_valid)

        else:
            prec1, prec5, losses = valid(conf, valid_loader, model_ema,
                                         criterion_valid)

        checker.log(
            step=epoch + 1,
            prec1=prec1,
            prec5=prec5,
            loss=losses.avg,
            lr=optimizer.param_groups[0]["lr"],
        )
        checker.save(f"epoch-{str(epoch + 1).zfill(3)}")

        if (conf.training.progressive.step > 0
                and (epoch + 1) % conf.training.progressive.step == 0):
            progressive_stage += 1

            if (progressive_stage <
                    conf.training.epoch // conf.training.progressive.step):
                train_loader, valid_loader = make_progressive_loader(
                    progressive_stage, model_module, conf)
def main(conf):
    device = "cuda"
    conf.distributed = conf.n_gpu > 1
    torch.backends.cudnn.benchmark = True

    logger = get_logger(mode=conf.logger)
    logger.info(conf.dict())

    student = conf.arch.make().to(device)
    student.set_drop_path(conf.task.student_drop_path)
    teacher = conf.arch.make().to(device)

    logger.info(student)

    if conf.distributed:
        teacher = nn.parallel.DistributedDataParallel(
            teacher,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        student = nn.parallel.DistributedDataParallel(
            student,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        teacher_module = teacher.module
        student_module = student.module

        teacher_module.load_state_dict(student_module.state_dict())

    else:
        teacher_module = teacher
        student_module = student

        teacher_module.load_state_dict(student.state_dict())

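    # The teacher gets no gradients; it is only updated from the student via momentum (see momentum_schedule below).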
    for p in teacher.parameters():
        p.requires_grad = False

    grad_accum = conf.training.grad_accumulation

    train_set, valid_set = make_augment_dataset(
        conf.dataset_path,
        DINOAugment(
            conf.task.global_crop_size,
            conf.task.local_crop_size,
            conf.task.global_crop_scale,
            conf.task.local_crop_scale,
            conf.task.n_local_crop,
        ),
        None,
    )

    batch_size = conf.training.dataloader.batch_size // grad_accum

    train_loader, valid_loader, train_sampler = make_dataloader(
        train_set,
        valid_set,
        batch_size,
        conf.distributed,
        conf.training.dataloader.num_workers,
    )

    criterion_train = DINOLoss(
        conf.arch.dim_head_out,
        conf.task.n_local_crop + 2,
        conf.task.warmup_teacher_temperature,
        conf.task.teacher_temperature,
        conf.task.warmup_teacher_temperature_epoch,
        conf.training.epoch,
    ).to(device)

    parameters, names = add_weight_decay(
        student.named_parameters(),
        conf.training.weight_decay,
        wd_skip_fn(conf.training.wd_skip),
    )

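    # Local scheduler factory; the base learning rate is scaled linearly with batch size (base_lr * batch_size / 256).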
    def make_scheduler(train_conf, optimizer, epoch_len):
        warmup = train_conf.scheduler.warmup * epoch_len
        n_iter = epoch_len * train_conf.epoch
        lr = train_conf.base_lr * train_conf.dataloader.batch_size / 256

        if train_conf.scheduler.type == "exp_epoch":
            return train_conf.scheduler.make(optimizer,
                                             epoch_len,
                                             lr=lr,
                                             max_iter=train_conf.epoch,
                                             warmup=warmup)

        else:
            return train_conf.scheduler.make(optimizer,
                                             lr=lr,
                                             n_iter=n_iter,
                                             warmup=warmup)

    optimizer = make_optimizer(conf.training, parameters)
    epoch_len = math.ceil(len(train_loader) / grad_accum)
    scheduler = make_scheduler(conf.training, optimizer, epoch_len)
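    # Cosine schedules for weight decay and for the teacher momentum, which ramps toward 1 over training.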
    wd_schedule = cosine_schedule(
        conf.training.weight_decay,
        conf.task.weight_decay_end,
        epoch_len * conf.training.epoch,
    )
    momentum_schedule = cosine_schedule(conf.task.teacher_momentum, 1,
                                        epoch_len * conf.training.epoch)

    scaler = amp.GradScaler(enabled=conf.fp16)

    checker = conf.checker.make()

    step = 0

    for epoch in range(conf.training.epoch):
        if conf.distributed:
            train_sampler.set_epoch(epoch)

        train(
            conf,
            step,
            epoch,
            train_loader,
            teacher,
            student,
            criterion_train,
            optimizer,
            scheduler,
            wd_schedule,
            momentum_schedule,
            scaler,
            grad_accum,
            checker,
        )
        step += epoch_len

        try:
            checker.checkpoint(
                {
                    "student": student_module.state_dict(),
                    "teacher": teacher_module.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "conf": conf.dict(),
                },
                f"epoch-{str(epoch + 1).zfill(3)}.pt",
            )

        except Exception as e:
            print(e)