def main():
    r"""Run single-example inference from the command line."""
    # Parse CLI arguments, then restore the experiment's saved configuration.
    args = parse_arg()
    cfg = load_cfg(exp_name=args.exp_name)

    # Fix randomness for reproducible generation.
    set_seed(cfg['seed'])

    # Restore encoder/decoder tokenizers from their own experiment configs.
    enc_cfg = load_cfg(exp_name=cfg['enc_tknzr_exp'])
    dec_cfg = load_cfg(exp_name=cfg['dec_tknzr_exp'])
    enc_tknzr = TKNZR_OPTS[enc_cfg['tknzr_name']].load(cfg=enc_cfg)
    dec_tknzr = TKNZR_OPTS[dec_cfg['tknzr_name']].load(cfg=dec_cfg)

    # Prefer GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Rebuild the model and restore weights from the requested checkpoint.
    model = MODEL_OPTS[cfg['model_name']](
        dec_tknzr_cfg=dec_cfg,
        enc_tknzr_cfg=enc_cfg,
        model_cfg=cfg,
    )
    model = load_model_from_ckpt(
        ckpt=args.ckpt,
        exp_name=args.exp_name,
        model=model,
    )
    model.eval()
    model = model.to(device)

    # Instantiate the requested inference strategy from CLI arguments.
    infr = INFR_OPTS[args.infr_name](**args.__dict__)

    # Generate a prediction for the single source text and print it.
    batch_pred = infr.gen(
        batch_text=[args.src],
        dec_max_len=cfg['dec_max_len'],
        dec_tknzr=dec_tknzr,
        device=device,
        enc_max_len=cfg['enc_max_len'],
        enc_tknzr=enc_tknzr,
        model=model,
    )
    print(batch_pred[0])
def main():
    r"""Train a sequence-to-sequence model.

    Parses CLI arguments, builds tokenizers, dataset, model, optimizer and
    objective, then runs the training loop with periodic checkpointing and
    tensorboard logging.
    """
    # Load command line arguments.
    args = parse_arg()

    # Control random seed for reproducibility.
    set_seed(args.seed)

    # Load encoder tokenizer and its configuration.
    enc_tknzr_cfg = load_cfg(exp_name=args.enc_tknzr_exp)
    enc_tknzr = TKNZR_OPTS[enc_tknzr_cfg['tknzr_name']].load(cfg=enc_tknzr_cfg)

    # Load decoder tokenizer and its configuration.
    dec_tknzr_cfg = load_cfg(exp_name=args.dec_tknzr_exp)
    dec_tknzr = TKNZR_OPTS[dec_tknzr_cfg['tknzr_name']].load(cfg=dec_tknzr_cfg)

    # Load training dataset and create dataloader (shuffled each epoch).
    dset = DSET_OPTS[args.dset_name]()
    dldr = torch.utils.data.DataLoader(
        dataset=dset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Get model running device; prefer GPU when available.
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Create model in training mode on the chosen device.
    model = MODEL_OPTS[args.model_name](
        dec_tknzr_cfg=dec_tknzr_cfg,
        enc_tknzr_cfg=enc_tknzr_cfg,
        model_cfg=args.__dict__,
    )
    model.train()
    model = model.to(device)

    # Create optimizer.
    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
    )

    # Create objective function.  Expects raw logits and class-index targets.
    objtv = torch.nn.CrossEntropyLoss()

    # Save model configuration so inference/evaluation can reload it.
    save_cfg(cfg=args.__dict__, exp_name=args.exp_name)

    # Global step.
    step = 0

    # Create experiment folder.
    exp_path = os.path.join(EXP_PATH, args.exp_name)
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)

    # Create tensorboard logger and log folder.
    writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(EXP_PATH, 'log', args.exp_name))

    # Running average loss over the current logging window, and the value of
    # the previous window (shown in the tqdm description between logs).
    total_loss = 0.0
    pre_total_loss = 0.0

    for cur_epoch in range(args.epoch):
        tqdm_dldr = tqdm(
            dldr, desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f}')
        for batch in tqdm_dldr:
            # Tokenize and pad source/target text to fixed maximum lengths.
            src, src_len = enc_tknzr.batch_enc(
                batch_text=batch[0],
                max_len=args.enc_max_len,
            )
            tgt, tgt_len = dec_tknzr.batch_enc(
                batch_text=batch[1],
                max_len=args.dec_max_len,
            )
            src = torch.tensor(src).to(device)
            src_len = torch.tensor(src_len).to(device)
            tgt = torch.tensor(tgt).to(device)
            tgt_len = torch.tensor(tgt_len).to(device)

            # Forward pass with teacher forcing: feed the target shifted
            # right, predict the target shifted left.
            logits = model(
                src=src,
                src_len=src_len,
                tgt=tgt[:, :-1],
                tgt_len=tgt_len - 1,
            )

            # Calculate loss against the next-token targets.
            loss = objtv(
                logits.reshape(-1, dec_tknzr_cfg['n_vocab']),
                tgt[:, 1:].reshape(-1),
            )

            # Accumulate the average over the logging window.  The window is
            # reset every `log_step` steps below, so divide by `log_step`
            # (previously this divided by `ckpt_step`, which skewed the
            # reported average whenever the two periods differed).
            total_loss += loss.item() / args.log_step

            # Backward pass.
            loss.backward()

            # Perform gradient clipping to stabilize training.
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(),
                max_norm=args.max_norm,
            )

            # Gradient descent.
            optim.step()

            # Clean up gradient.
            optim.zero_grad()

            # Increment global step.
            step += 1

            # Save checkpoint for each `ckpt_step`.
            if step % args.ckpt_step == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(exp_path, f'model-{step}.pt'),
                )

            if step % args.log_step == 0:
                # Log average loss on CLI.
                tqdm_dldr.set_description(
                    f'epoch: {cur_epoch}, loss: {total_loss:.6f}')

                # Log average loss on tensorboard.
                writer.add_scalar('loss', total_loss, step)

                # Reset the logging window.
                pre_total_loss = total_loss
                total_loss = 0.0

    # Save last checkpoint.
    torch.save(
        model.state_dict(),
        os.path.join(exp_path, f'model-{step}.pt'),
    )

    # Close logger.
    writer.close()
def main():
    r"""Evaluate a trained model over a whole dataset."""
    # Parse CLI arguments, then restore the experiment's saved configuration.
    args = parse_arg()
    cfg = load_cfg(exp_name=args.exp_name)

    # Fix randomness for reproducible generation.
    set_seed(cfg['seed'])

    # Restore encoder/decoder tokenizers from their own experiment configs.
    enc_cfg = load_cfg(exp_name=cfg['enc_tknzr_exp'])
    dec_cfg = load_cfg(exp_name=cfg['dec_tknzr_exp'])
    enc_tknzr = TKNZR_OPTS[enc_cfg['tknzr_name']].load(cfg=enc_cfg)
    dec_tknzr = TKNZR_OPTS[dec_cfg['tknzr_name']].load(cfg=dec_cfg)

    # Build the evaluation dataset and a deterministic (unshuffled) loader.
    dset = DSET_OPTS[args.dset_name]()
    dldr = torch.utils.data.DataLoader(
        dataset=dset,
        batch_size=args.batch_size,
        shuffle=False,
    )

    # Prefer GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Rebuild the model and restore weights from the requested checkpoint.
    model = MODEL_OPTS[cfg['model_name']](
        dec_tknzr_cfg=dec_cfg,
        enc_tknzr_cfg=enc_cfg,
        model_cfg=cfg,
    )
    model = load_model_from_ckpt(
        ckpt=args.ckpt,
        exp_name=args.exp_name,
        model=model,
    )
    model.eval()
    model = model.to(device)

    # Instantiate the requested inference strategy from CLI arguments.
    infr = INFR_OPTS[args.infr_name](**args.__dict__)

    # Collect predictions batch by batch, in dataset order.
    all_pred = []
    for batch in tqdm(dldr):
        batch_pred = infr.gen(
            batch_text=batch[0],
            dec_max_len=cfg['dec_max_len'],
            dec_tknzr=dec_tknzr,
            device=device,
            enc_max_len=cfg['enc_max_len'],
            enc_tknzr=enc_tknzr,
            model=model,
        )
        all_pred.extend(batch_pred)

    # Score predictions against the full target set and print the result.
    score = DSET_OPTS[args.dset_name].batch_eval(
        batch_tgt=dset.all_tgt(),
        batch_pred=all_pred,
    )
    print(score)