def __qanet_adam_optimizer(model, config):
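    """Build an Adam optimizer for QANet from the 'optimizer' section of the config."""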
    optimizer_config = config.get('optimizer', {})
    adam_config = load_dataclass(optimizer_config, AdamConf)
    optimizer = optim.Adam(
        model.parameters(),
        lr=adam_config.learning_rate,
        betas=(adam_config.beta_1, adam_config.beta_2),
        eps=adam_config.eps,
        weight_decay=adam_config.weight_decay
    )
    return optimizer


def __qanet_adam_scheduler(optimizer, config, args):
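    """Build the logarithmic warm-up schedule used with the QANet Adam optimizer."""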
    scheduler_config = config.get('scheduler', {})
    warmup_conf = load_dataclass(scheduler_config, WarmupSchedulerConf)
    # LambdaLR multiplies the base learning rate by the returned factor.
    # The factor ramps up as log(examples_seen + 1) / log(warmup_steps - 1)
    # and is capped at 1 once warm-up is complete.
    scheduler = sched.LambdaLR(
        optimizer,
        lambda batch: min(
            1,
            math.log(batch * args.batch_size + 1) / math.log(warmup_conf.warmup_steps - 1)
        )
    )
    return scheduler


def init_qanet2_training(args, char_vectors, word_vectors, config):
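    """Construct the QANet model, its Adam optimizer, and its warm-up scheduler."""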
    from models.qanet2 import QANet, QANetConf

    model_config = config.get('model', {})
    model = QANet(
        word_mat=word_vectors,
        char_mat=char_vectors,
        config=load_dataclass(model_config, QANetConf)
    )
    optimizer = __qanet_adam_optimizer(model, config)
    scheduler = __qanet_adam_scheduler(optimizer, config, args)
    return model, optimizer, scheduler


def init_baseline_training(args, word_vectors, config):
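    """Construct the BiDAF baseline, its Adadelta optimizer, and a multiplicative LR scheduler."""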
    from models.bidaf import BiDAF

    model_config = config.get('model', {})
    bidaf_conf: BiDAFConf = load_dataclass(model_config, BiDAFConf)

    model = BiDAF(
        word_vectors=word_vectors,
        hidden_size=bidaf_conf.hidden_size,
        drop_prob=bidaf_conf.drop_prob
    )

    optimizer_config = config.get('optimizer', {})
    adadelta_config = load_dataclass(optimizer_config, AdadeltaConf)
    optimizer = optim.Adadelta(
        model.parameters(),
        lr=adadelta_config.learning_rate,
        weight_decay=adadelta_config.weight_decay
    )

    scheduler_config = config.get('scheduler', {})
    multiplicative_conf: MultiplicativeSchedulerConf = load_dataclass(scheduler_config, MultiplicativeSchedulerConf)
    # Scale the base LR by `multiplier` raised to the number of 1000-example blocks processed.
    scheduler = sched.LambdaLR(
        optimizer,
        lambda batch: multiplicative_conf.multiplier ** ((batch * args.batch_size) // 1000)
    )

    return model, optimizer, scheduler
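

# Illustrative wiring (a minimal sketch, not part of the training pipeline itself;
# `train_loader` and `compute_loss` are hypothetical placeholders):
#
#     model, optimizer, scheduler = init_baseline_training(args, word_vectors, config)
#     for batch in train_loader:
#         optimizer.zero_grad()
#         loss = compute_loss(model, batch)
#         loss.backward()
#         optimizer.step()
#         scheduler.step()  # stepped once per batch, matching the batch-indexed lambdas above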