def train(conf, loader, model, ema, diffusion, optimizer, scheduler, device,
          wandb):
    loader = sample_data(loader)

    pbar = range(conf.training.n_iter + 1)

    if dist.is_primary():
        pbar = tqdm(pbar, dynamic_ncols=True)

    for i in pbar:
        epoch, img = next(loader)
        img = img.to(device)
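        # draw a uniform random diffusion timestep for each image in the batch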
        time = torch.randint(
            0,
            conf.diffusion.beta_schedule["n_timestep"],
            (img.shape[0], ),
            device=device,
        )
        loss = diffusion.p_loss(model, img, time)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        scheduler.step()
        optimizer.step()

        # during warmup the EMA copy just mirrors the raw weights (decay 0), then 0.9999
        accumulate(ema, model.module if conf.distributed else model,
                   0 if i < conf.training.scheduler.warmup else 0.9999)

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]
            pbar.set_description(
                f"epoch: {epoch}; loss: {loss.item():.4f}; lr: {lr:.5f}")

            if wandb is not None and i % conf.evaluate.log_every == 0:
                wandb.log({
                    "epoch": epoch,
                    "loss": loss.item(),
                    "lr": lr
                },
                          step=i)

            if i % conf.evaluate.save_every == 0:
                if conf.distributed:
                    model_module = model.module

                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "conf": conf,
                    },
                    f"checkpoint/diffusion_{str(i).zfill(6)}.pt",
                )
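(The loops above rely on two helpers that are not shown in these examples; the sketches below are assumptions inferred from usage, not the original implementations.)

def sample_data(loader):
    # assumed helper: wrap a DataLoader into an infinite generator yielding
    # (epoch, batch), so training can run for a fixed number of iterations
    epoch = 0
    while True:
        for batch in loader:
            yield epoch, batch
        epoch += 1


def accumulate(ema_model, model, decay=0.9999):
    # assumed helper: exponential moving average of parameters,
    # ema = decay * ema + (1 - decay) * param (decay=0 copies the model)
    ema_params = dict(ema_model.named_parameters())
    with torch.no_grad():
        for name, param in model.named_parameters():
            ema_params[name].mul_(decay).add_(param, alpha=1 - decay)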
Example #2
def valid(conf, loader, model, criterion):
    device = "cuda"

    batch_time = Meter()
    losses = Meter()
    top1 = Meter()
    top5 = Meter()

    model.eval()

    logger = get_logger(mode=conf.logger)

    start = perf_counter()
    for i, (input, label) in enumerate(loader):
        input = input.to(device)
        label = label.to(device)

        out = model(input)
        loss = criterion(out, label)
        prec1, prec5 = accuracy(out, label, topk=(1, 5))
        batch = input.shape[0]

        loss_dict = {
            "prec1": prec1 * batch,
            "prec5": prec5 * batch,
            "loss": loss * batch,
            "batch": torch.tensor(batch, dtype=torch.float32).to(device),
        }
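        # sum the batch-weighted metrics across ranks; dividing by the reduced
        # batch count below turns them back into per-sample averages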
        loss_reduced = dist.reduce_dict(loss_dict, average=False)
        batch = loss_reduced["batch"].to(torch.int64).item()
        losses.update(loss_reduced["loss"].item() / batch, batch)
        top1.update(loss_reduced["prec1"].item() / batch, batch)
        top5.update(loss_reduced["prec5"].item() / batch, batch)

        batch_time.update(perf_counter() - start)
        start = perf_counter()

        if dist.is_primary() and i % conf.log_freq == 0:
            logger.info(
                f"valid: {i}/{len(loader)}; time: {batch_time.val:.3f} ({batch_time.avg:.3f}); "
                f"loss: {losses.val:.4f} ({losses.avg:.4f}); "
                f"prec@1: {top1.val:.3f} ({top1.avg:.3f}); "
                f"prec@5: {top5.val:.3f} ({top5.avg:.3f})")

    if dist.is_primary():
        logger.info(
            f"validation finished: prec@1 {top1.avg:.3f}, prec@5 {top5.avg:.3f}"
        )

    return top1.avg, top5.avg, losses
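(Meter is not defined in these snippets; a minimal sketch of the assumed running-average meter:)

class Meter:
    # assumed helper: tracks the latest value and a count-weighted average
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count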
Example #3
def load_config(config_model, config, overrides=(), show=True):
    conf = config_model(**read_config(config, overrides=overrides))

    if show and is_primary():
        pprint(conf.dict())

    return conf
def main(conf):
    wandb = None
    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="denoising diffusion")

    device = "cuda"
    beta_schedule = "linear"  # unused here; the schedule comes from conf.diffusion.beta_schedule

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    train_set = MultiResolutionDataset(conf.dataset.path, transform,
                                       conf.dataset.resolution)
    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    train_loader = conf.training.dataloader.make(train_set,
                                                 sampler=train_sampler)

    model = conf.model.make()
    model = model.to(device)
    ema = conf.model.make()
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        if conf.distributed:
            model.module.load_state_dict(ckpt["model"])

        else:
            model.load_state_dict(ckpt["model"])

        ema.load_state_dict(ckpt["ema"])

    betas = conf.diffusion.beta_schedule.make()
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler,
          device, wandb)
Example #5
def main(conf):
    conf.distributed = False

    device = "cpu"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    valid_sampler = dist.data_sampler(
        valid_set, shuffle=False, distributed=conf.distributed
    )

    valid_loader = conf.training.dataloader.make(
        valid_set, sampler=valid_sampler, collate_fn=collate_data
    )

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model

        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])

    model_valid = ModelValid(model, valid_set, valid_loader, device, None)

    valid(conf, model_valid, 0, block_size=8, max_decode_iter=2)
def train(conf, loader, model, ema, diffusion, optimizer, scheduler, device):
    loader = sample_data(loader)

    pbar = range(conf.training.n_iter + 1)

    for i in pbar:
        epoch, img = next(loader)
        img = img.to(device)
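        # uniform integer timesteps in [0, 1000); same distribution as
        # torch.randint(0, 1000, (batch,))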
        time = (torch.rand(img.shape[0]) * 1000).type(torch.int64).to(device)
        loss = diffusion.p_loss(model, img, time)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        scheduler.step()
        optimizer.step()

        accumulate(
            ema, model.module if conf.distributed else model,
            0 if i < conf.training.scheduler.warmup else 0.9999,
        )

        if dist.is_primary():
            if i % conf.evaluate.log_every == 0:
                lr = optimizer.param_groups[0]["lr"]
                print(f"epoch: {epoch}; iter: {i}; loss: {loss.item():.4f}; lr: {lr:.6f}")

            if i % conf.evaluate.save_every == 0:
                if conf.distributed:
                    model_module = model.module
                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "ema": ema.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "conf": conf,
                        "epoch": epoch,
                        "iter": i
                    },
                    f"checkpoint/diffusion_{str(i).zfill(6)}.pt",
                )

            if i % conf.evaluate.valid_every == 0 and i > 0:
                # periodically render a sample grid from the EMA model
                grid_im_np = generate_sample(ema)
                grid_im_pil = Image.fromarray(grid_im_np)
                grid_im_pil.save(f"sample_{str(i).zfill(6)}.png")
Example #7
    def catalog(self, conf):
        if not dist.is_primary():
            return

        if not isinstance(conf, dict):
            conf = conf.dict()

        conf = pformat(conf)

        argvs = " ".join([os.path.basename(sys.executable)] + sys.argv)

        template = f"""{argvs}

{conf}"""
        template = template.encode("utf-8")

        for storage in self.storages:
            storage.save(template, "catalog.txt")
def progressive_adaptive_regularization(
    stage,
    max_stage,
    train_sizes,
    valid_sizes,
    randaug_layers,
    randaug_magnitudes,
    mixups,
    cutmixes,
    dropouts,
    drop_paths,
    verbose=True,
):
    train_size = int(lerp(*train_sizes, stage, max_stage))
    valid_size = int(lerp(*valid_sizes, stage, max_stage))
    randaug_layer = int(lerp(*randaug_layers, stage, max_stage))
    randaug_magnitude = lerp(*randaug_magnitudes, stage, max_stage)
    mixup = lerp(*mixups, stage, max_stage)
    cutmix = lerp(*cutmixes, stage, max_stage)
    dropout = lerp(*dropouts, stage, max_stage)
    drop_path = lerp(*drop_paths, stage, max_stage)

    if verbose and dist.is_primary():
        log = f"""Progressive Training with Adaptive Regularization
Stage: {stage + 1} / {max_stage}
Image Size: train={train_size}, valid={valid_size}
RandAugment: n_augment={randaug_layer}, magnitude={randaug_magnitude}
Mixup: {mixup}, Cutmix: {cutmix}, Dropout={dropout}, DropPath={drop_path}"""
        print(log)

    return SimpleNamespace(
        train_size=train_size,
        valid_size=valid_size,
        randaug_layer=randaug_layer,
        randaug_magnitude=randaug_magnitude,
        mixup=mixup,
        cutmix=cutmix,
        dropout=dropout,
        drop_path=drop_path,
    )
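(lerp is not shown; from the calls above — a (start, end) pair unpacked as endpoints, then stage and max_stage — the assumed implementation is:)

def lerp(start, end, step, max_step):
    # assumed helper: linear interpolation from start to end over max_step stages
    return start + (end - start) * (step / max_step)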
Example #9
    def __init__(
        self,
        project,
        group=None,
        name=None,
        notes=None,
        resume=None,
        tags=None,
        id=None,
    ):
        if dist.is_primary():
            import wandb

            wandb.init(
                project=project,
                group=group,
                name=name,
                notes=notes,
                resume=resume,
                tags=tags,
                id=id,
            )

            self.wandb = wandb
def train(conf, model_training):
    criterion = ImputerLoss(reduction="mean", zero_infinity=True)

    loader = sample_data(model_training.train_loader)
    pbar = range(conf.training.n_iter + 1)
    if dist.is_primary():
        pbar = tqdm(pbar)

    device = model_training.device
    model = model_training.model
    optimizer = model_training.optimizer
    scheduler = model_training.scheduler
    """for i in pbar:
        epoch, (
            mels,
            token_in,
            targets_ctc,
            targets_ce,
            mask,
            mel_lengths,
            targets_ctc_lengths,
            texts,
            files,
        ) = next(loader)
        mels = mels.to(device)
        token_in = token_in.to(device)
        targets_ctc = targets_ctc.to(device)
        targets_ce = targets_ce.to(device)
        mask = mask.to(device)

        mel_len_reduce = torch.ceil(
            mel_lengths.to(torch.float32) / conf.model.reduction
        ).to(torch.int64)

        out = model(mels, token_in)

        loss = criterion(
            out, targets_ctc, targets_ce, mask, mel_len_reduce, targets_ctc_lengths
        )"""

    for i in pbar:
        epoch, (
            mels,
            token_in,
            targets,
            force_emits,
            mel_lengths,
            token_lengths,
            texts,
            files,
        ) = next(loader)
        mels = mels.to(device)
        token_in = token_in.to(device)
        targets = targets.to(device)
        force_emits = force_emits.to(device)

        mel_len_reduce = torch.ceil(
            mel_lengths.to(torch.float32) / conf.model.reduction).to(
                torch.int64)

        # per-frame log-probabilities over tokens; transposed to time-major below for the loss
        out = torch.log_softmax(model(mels, token_in), 2)

        loss = criterion(
            out.transpose(0, 1).contiguous(),
            targets,
            force_emits,
            mel_len_reduce,
            token_lengths,
        )

        optimizer.zero_grad()
        loss.backward()
        scheduler.step()
        optimizer.step()

        if dist.is_primary() and conf.training.scheduler.type == "lr_find":
            scheduler.record_loss(loss)

        if i % conf.evaluate.log_every == 0:
            loss_dict = {"loss": loss}
            loss_reduced = dist.reduce_dict(loss_dict)
            loss_ctc = loss_reduced["loss"].mean().item()

            lr = optimizer.param_groups[0]["lr"]

        if i > 0 and i % conf.evaluate.valid_every == 0:
            valid(conf, model_training, i)
            valid(conf, model_training, i, block_size=8, max_decode_iter=2)
            valid(conf, model_training, i, block_size=8, max_decode_iter=4)

        if dist.is_primary():
            if i % conf.evaluate.log_every == 0:
                pbar.set_description(
                    f"epoch: {epoch}; loss: {loss_ctc:.4f}; lr: {lr:.5f}")

                if conf.evaluate.wandb and model_training.wandb is not None:
                    model_training.wandb.log(
                        {
                            "training/epoch": epoch,
                            "training/loss": loss_ctc,
                            "training/lr": lr,
                        },
                        step=i,
                    )

            if i > 0 and i % conf.evaluate.save_every == 0:
                if conf.distributed:
                    model_module = model.module

                else:
                    model_module = model

                torch.save(
                    {
                        "model": model_module.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "conf": conf,
                    },
                    f"checkpoint/imputer_{str(i).zfill(6)}.pt",
                )

    if dist.is_primary() and conf.training.scheduler.type == "lr_find":
        scheduler.write_log("loss.log")
Example #11
def train(
    conf,
    step,
    epoch,
    loader,
    model,
    model_ema,
    criterion,
    optimizer,
    scheduler,
    scaler,
    grad_accum,
):
    device = "cuda"

    batch_time = Meter()
    data_time = Meter()
    losses = Meter()
    top1 = Meter()
    top5 = Meter()

    model.train()

    # adaptive gradient clipping (AGC) is applied to everything except "linear" layers
    agc_params = [
        p[1] for p in model.named_parameters() if "linear" not in p[0]
    ]
    params = list(model.parameters())

    logger = get_logger(mode=conf.logger)

    start = perf_counter()
    for i, (input, label1, label2, ratio) in enumerate(loader):
        # measure data loading time
        input = input.to(device)
        label1 = label1.to(device)
        label2 = label2.to(device)
        ratio = ratio.to(device=device, dtype=torch.float32)
        data_time.update(perf_counter() - start)

        with amp.autocast(enabled=conf.fp16):
            out = model(input)
            loss = criterion(out, label1, label2, ratio) / grad_accum

        prec1, prec5 = accuracy(out, label1, topk=(1, 5))
        batch = input.shape[0]
        losses.update(loss.item() * grad_accum, batch)
        top1.update(prec1.item(), batch)
        top5.update(prec5.item(), batch)

        scaler.scale(loss).backward()

        # update weights only at gradient-accumulation boundaries (or on the last batch)
        if ((i + 1) % grad_accum == 0) or (i + 1) == len(loader):
            if conf.training.agc > 0 or conf.training.clip_grad_norm > 0:
                if conf.fp16:
                    scaler.unscale_(optimizer)

                if conf.training.agc > 0:
                    adaptive_grad_clip(agc_params, conf.training.agc)

                if conf.training.clip_grad_norm > 0:
                    nn.utils.clip_grad_norm_(params,
                                             conf.training.clip_grad_norm)

            scheduler.step()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # optimizer.step()
        t = step + i

        if conf.training.ema > 0:
            if conf.distributed:
                model_module = model.module

            else:
                model_module = model

            accumulate(
                model_ema,
                model_module,
                min(conf.training.ema, (1 + t) / (10 + t)),
                ema_bn=conf.training.ema_bn,
            )

        batch_time.update(perf_counter() - start)
        start = perf_counter()

        if dist.is_primary() and i % conf.log_freq == 0:
            lr = optimizer.param_groups[0]["lr"]

            logger.info(
                f"epoch: {epoch} ({i}/{len(loader)}); time: {batch_time.val:.3f} ({batch_time.avg:.2f}); "
                f"data: {data_time.val:.3f} ({data_time.avg:.2f}); "
                f"loss: {losses.val:.3f} ({losses.avg:.3f}); "
                f"prec@1: {top1.val:.2f} ({top1.avg:.2f}); "
                f"prec@5: {top5.val:.2f} ({top5.avg:.2f})")

    return losses
Example #12
    def __del__(self):
        if dist.is_primary():
            self.wandb.finish()
Example #13
def main(conf):
    device = "cuda"
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)
    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    if dist.is_primary():
        bind_model(conf, model, ema, scheduler)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
Example #14
def valid(conf, model_training, step, block_size=1, max_decode_iter=1):
    criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)
    pbar = model_training.valid_loader

    if dist.is_primary():
        pbar = tqdm(pbar)

    device = model_training.device
    model = model_training.model
    decoder = model_training.dataset.decode

    was_training = model.training
    model.eval()

    dist.synchronize()

    total_dist = 0
    total_length = 0
    show_text = 0
    text_table = []

    for mels, tokens, mel_lengths, token_lengths, texts, _ in pbar:
        mels = mels.to(device)
        tokens = tokens.to(device)
        mel_len_reduce = torch.ceil(
            mel_lengths.to(torch.float32) / conf.model.reduction
        ).to(torch.int64)

        pred_token = None

        # iterative refinement: start from an all-mask alignment, then re-decode blocks
        for decode_candid in range(0, block_size, block_size // max_decode_iter):
            if decode_candid == 0:
                align_in = tokens.new_ones(
                    mels.shape[0], math.ceil(mels.shape[1] / conf.model.reduction)
                )
                out = None
                mask = None

            else:
                align_in, mask = decode_argmax(
                    out, block_size, block_size // max_decode_iter, mask
                )

            out = torch.log_softmax(model(mels, align_in), 2)

            """if pred_token is None:
                pred_token = out.argmax(2)

            else:
                pred_token = (1 - mask) * out.argmax(2) + mask * pred_token"""
            pred_token = out.argmax(2)

        loss = criterion(
            out.transpose(0, 1).contiguous(), tokens, mel_len_reduce, token_lengths
        )

        pred_token = pred_token.to("cpu").tolist()

        for mel_len, pred_tok, gt in zip(mel_len_reduce.tolist(), pred_token, texts):
            pred = "".join(decoder(ctc_decode(pred_tok[:mel_len])))
            editdist, reflen = char_distance(gt, pred)
            total_dist += editdist
            total_length += reflen

            if dist.is_primary() and show_text < 8:
                pbar.write(f"gt: {gt}\t\ttranscription: {pred}")
                show_text += 1
                text_table.append([gt, pred, str(editdist), str(editdist / reflen)])

        dist.synchronize()

        comm = {
            "loss": loss.item(),
            "total_dist": total_dist,
            "total_length": total_length,
        }
        comm = dist.all_gather(comm)

        part_dist = 0
        part_len = 0
        part_loss = 0
        for eval_parts in comm:
            part_dist += eval_parts["total_dist"]
            part_len += eval_parts["total_length"]
            part_loss += eval_parts["loss"]

        if dist.is_primary():
            n_part = len(comm)
            cer = part_dist / part_len * 100
            pbar.set_description(f"loss: {part_loss / n_part:.4f}; cer: {cer:.2f}")

        dist.synchronize()

    if dist.is_primary():
        n_part = len(comm)
        cer = part_dist / part_len * 100
        pbar.write(f"loss: {part_loss / n_part:.4f}; cer: {cer:.2f}")

        if conf.evaluate.wandb and model_training.wandb is not None:
            model_training.wandb.log(
                {
                    f"valid/iter-{max_decode_iter}/loss": part_loss / n_part,
                    f"valid/iter-{max_decode_iter}/cer": cer,
                    f"valid/iter-{max_decode_iter}/text": model_training.wandb.Table(
                        data=text_table,
                        columns=["Reference", "Transcription", "Edit Distance", "CER"],
                    ),
                },
                step=step,
            )

    if was_training:
        model.train()
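(ctc_decode is not shown; a minimal sketch of standard CTC greedy collapsing, assuming blank id 0:)

def ctc_decode(tokens, blank=0):
    # assumed helper: collapse repeated tokens, then drop blanks
    out = []
    prev = None
    for t in tokens:
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out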
Example #15
    def save(self, data, name):
        if dist.is_primary():
            for storage in self.storages:
                storage.save(data, name)
def main(conf):
    conf.distributed = dist.get_world_size() > 1

    device = "cuda"

    if dist.is_primary():
        from pprint import pprint

        pprint(conf.dict())

    if dist.is_primary() and conf.evaluate.wandb:
        wandb = load_wandb()
        wandb.init(project="asr")

    else:
        wandb = None

    with open("trainval_indices.pkl", "rb") as f:
        split_indices = pickle.load(f)

    train_set = ASRDataset(
        conf.dataset.path,
        indices=split_indices["train"],
        alignment=conf.dataset.alignment,
    )
    valid_set = ASRDataset(conf.dataset.path, indices=split_indices["val"])

    train_sampler = dist.data_sampler(train_set,
                                      shuffle=True,
                                      distributed=conf.distributed)
    valid_sampler = dist.data_sampler(valid_set,
                                      shuffle=False,
                                      distributed=conf.distributed)

    if conf.training.batch_sampler is not None:
        train_lens = []

        for i in split_indices["train"]:
            train_lens.append(train_set.mel_lengths[i])

        opts = conf.training.batch_sampler

        # geometrically spaced bin edges (base ** linspace) for grouping
        # similar-length utterances into batches
        bins = ((opts.base**np.linspace(opts.start, 1, 2 * opts.k + 1)) *
                1000).tolist()
        groups, bins, n_samples = create_groups(train_lens, bins)
        batch_sampler = GroupedBatchSampler(
            train_sampler, groups, conf.training.dataloader.batch_size)

        conf.training.dataloader.batch_size = 1
        train_loader = conf.training.dataloader.make(
            train_set,
            batch_sampler=batch_sampler,
            collate_fn=collate_data_imputer)

    else:
        train_loader = conf.training.dataloader.make(
            train_set, collate_fn=collate_data_imputer)

    valid_loader = conf.training.dataloader.make(valid_set,
                                                 sampler=valid_sampler,
                                                 collate_fn=collate_data)

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    if conf.ckpt is not None:
        ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)

        model_p = model

        if conf.distributed:
            model_p = model.module

        model_p.load_state_dict(ckpt["model"])
        # scheduler.load_state_dict(ckpt["scheduler"])

        model_p.copy_embed(1)

    model_training = ModelTraining(
        model,
        optimizer,
        scheduler,
        train_set,
        train_loader,
        valid_loader,
        device,
        wandb,
    )

    train(conf, model_training)
def train(
    conf,
    step,
    epoch,
    loader,
    teacher,
    student,
    criterion,
    optimizer,
    scheduler,
    wd_schedule,
    momentum_schedule,
    scaler,
    grad_accum,
    checker,
):
    device = "cuda"

    batch_time = Meter()
    data_time = Meter()
    losses = Meter()

    student.train()

    agc_params = [
        p[1] for p in student.named_parameters() if "linear" not in p[0]
    ]
    params = list(student.parameters())

    logger = get_logger(mode=conf.logger)

    start = perf_counter()
    for i, (inputs, _) in enumerate(loader):
        # measure data loading time
        inputs = [i.to(device) for i in inputs]
        data_time.update(perf_counter() - start)

        with amp.autocast(enabled=conf.fp16):
            with torch.no_grad():
                teacher_out = teacher(inputs[:2])

            student_out = student(inputs)

            loss = criterion(student_out, teacher_out, epoch) / grad_accum

        losses.update(loss.item() * grad_accum, inputs[0].shape[0])

        scaler.scale(loss).backward()

        # apply the scheduled weight decay to the parameter groups that decay
        for param_group in optimizer.param_groups:
            if "no_decay" not in param_group:
                param_group["weight_decay"] = wd_schedule[step]

        if ((i + 1) % grad_accum == 0) or (i + 1) == len(loader):
            if conf.training.agc > 0 or conf.training.clip_grad_norm > 0:
                if conf.fp16:
                    scaler.unscale_(optimizer)

                if conf.training.agc > 0:
                    adaptive_grad_clip(agc_params, conf.training.agc)

                if conf.training.clip_grad_norm > 0:
                    nn.utils.clip_grad_norm_(params,
                                             conf.training.clip_grad_norm)

            cancel_last_layer_grad(epoch, student, conf.task.freeze_last_layer)

            scheduler.step()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            # momentum (EMA) update of the teacher from the student
            with torch.no_grad():
                m = momentum_schedule[step]

                for param_q, param_k in zip(student.parameters(),
                                            teacher.parameters()):
                    param_k.detach().mul_(m).add_(param_q.detach(),
                                                  alpha=1 - m)

        batch_time.update(perf_counter() - start)
        start = perf_counter()

        if dist.is_primary() and i % conf.log_freq == 0:
            lr = optimizer.param_groups[0]["lr"]
            """logger.info(
                f"epoch: {epoch} ({i}/{len(loader)}); time: {batch_time.val:.3f} ({batch_time.avg:.2f}); "
                f"data: {data_time.val:.3f} ({data_time.avg:.2f}); "
                f"loss: {losses.val:.3f} ({losses.avg:.3f}); "
                f"lr: {lr:.5f}; "
                f"wd: {wd_schedule[step]:4f}; "
                f"moment: {momentum_schedule[step]:.4f}"
            )"""

            checker.log(
                step=step,
                weight_decay=wd_schedule[step],
                momentum=momentum_schedule[step],
                loss=losses.avg,
                lr=optimizer.param_groups[0]["lr"],
            )

        step += 1

    return losses
Example #18
    def log(self, step=None, **kwargs):
        if dist.is_primary():
            for reporter in self.reporters:
                reporter.log(step, **kwargs)
Example #19
    def checkpoint(self, obj, name):
        if dist.is_primary():
            for storage in self.storages:
                storage.checkpoint(obj, name)