def main(conf):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)
    # A second copy of the network used to hold the exponential-moving-average
    # weights maintained during training.
    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
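

# A minimal sketch of what the "linear" branch of a helper like
# make_beta_schedule could return for the constants above
# (beta_start=1e-4, beta_end=2e-2, n_timestep=1000). The helper actually
# imported by this script may differ; this only illustrates the standard DDPM
# linear schedule of evenly spaced betas.
def linear_beta_schedule_sketch(beta_start=1e-4, beta_end=2e-2, n_timestep=1000):
    return torch.linspace(beta_start, beta_end, n_timestep)
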
def main(conf):
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    train_set = MultiResolutionDataset(conf.dataset.path, transform,
                                       conf.dataset.resolution)
    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    train_loader = conf.training.dataloader.make(train_set,
                                                 sampler=train_sampler)

    model = conf.model.make()
    model = model.to(device)
    ema = conf.model.make()
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])

        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler,
          device, wandb)
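

# The train() loop is not shown in this file, but a typical way to keep `ema`
# tracking `model` is an exponential moving average of the parameters after
# every optimizer step. The helper name and decay value below are illustrative
# assumptions, not taken from this codebase.
@torch.no_grad()
def ema_accumulate_sketch(ema_model, model, decay=0.9999):
    # If `model` is wrapped in DistributedDataParallel, pass model.module so
    # the parameter names line up with the unwrapped EMA copy.
    ema_params = dict(ema_model.named_parameters())

    for name, param in model.named_parameters():
        # ema = decay * ema + (1 - decay) * current
        ema_params[name].mul_(decay).add_(param.detach(), alpha=1 - decay)
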
def make_dataloader(train_set, valid_set, batch, distributed, n_worker):
    batch_size = batch // dist.get_world_size()

    train_sampler = dist.data_sampler(train_set, shuffle=True, distributed=distributed)

    train_loader = DataLoader(
        train_set, batch_size=batch_size, sampler=train_sampler, num_workers=n_worker
    )
    valid_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        sampler=dist.data_sampler(valid_set, shuffle=False, distributed=distributed),
        num_workers=n_worker,
    )

    return train_loader, valid_loader, train_sampler
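

# Hedged usage example for make_dataloader; the dataset objects and numbers are
# placeholders. Note that `batch` is the global batch size: it is divided by
# the world size, so each rank loads batch // world_size samples per step.
#
#   train_loader, valid_loader, train_sampler = make_dataloader(
#       train_set, valid_set, batch=64, distributed=conf.distributed, n_worker=4
#   )
#
# When training is distributed, train_sampler is typically given a
# set_epoch(epoch) call at the start of each epoch so shuffling differs between
# epochs.
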
def main(conf):
    conf.distributed = dist.get_world_size() > 1

    device = "cuda"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="asr")

    else:
        wandb = None

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    train_set = ASRDataset(
        conf.dataset.path,
        indices=split_indices["train"],
        alignment=conf.dataset.alignment,
    )
    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    valid_sampler = dist.data_sampler(valid_set,
                                      shuffle=False,
                                      distributed=conf.distributed)

    if conf.training.batch_sampler is not None:
        train_lens = []

        for i in split_indices["train"]:
            train_lens.append(train_set.mel_lengths[i])

        opts = conf.training.batch_sampler

        bins = ((opts.base**np.linspace(opts.start, 1, 2 * opts.k + 1)) *
                1000).tolist()
        groups, bins, n_samples = create_groups(train_lens, bins)
        batch_sampler = GroupedBatchSampler(
            train_sampler, groups, conf.training.dataloader.batch_size)

        # DataLoader does not allow batch_size > 1 together with a
        # batch_sampler, so the plain batch size is reset before building
        # the loader.
        conf.training.dataloader.batch_size = 1
        train_loader = conf.training.dataloader.make(
            train_set,
            batch_sampler=batch_sampler,
            collate_fn=collate_data_imputer)

    else:
        train_loader = conf.training.dataloader.make(
            train_set, collate_fn=collate_data_imputer)

    valid_loader = conf.training.dataloader.make(valid_set,
                                                 sampler=valid_sampler,
                                                 collate_fn=collate_data)

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model

        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])
        # scheduler.load_state_dict(ckpt["scheduler"])

        model_p.copy_embed(1)

    model_training = ModelTraining(
        model,
        optimizer,
        scheduler,
        train_set,
        train_loader,
        valid_loader,
        device,
        wandb,
    )

    train(conf, model_training)
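

# Worked example of the batch_sampler bin computation above, using hypothetical
# values base=10, start=0.5, k=2 (the real values come from
# conf.training.batch_sampler):
#
#   np.linspace(0.5, 1, 2 * 2 + 1)  -> [0.5, 0.625, 0.75, 0.875, 1.0]
#   10 ** (that)                    -> [~3.16, ~4.22, ~5.62, ~7.50, 10.0]
#   * 1000                          -> [~3162, ~4217, ~5623, ~7499, 10000]
#
# i.e. geometrically spaced mel-length boundaries. GroupedBatchSampler then
# batches together utterances that fall into the same length bin, which keeps
# padding waste in each batch low.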