def __qanet_adam_optimizer(model, config):
    optimizer_config = config.get('optimizer', {})
    adam_config = load_dataclass(optimizer_config, AdamConf)
    optimizer = optim.Adam(
        model.parameters(),
        lr=adam_config.learning_rate,
        betas=(adam_config.beta_1, adam_config.beta_2),
        eps=adam_config.eps,
        weight_decay=adam_config.weight_decay
    )
    return optimizer
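# Illustrative only: the 'optimizer' section read above is mapped onto AdamConf.
# The key names below come from the attribute accesses in __qanet_adam_optimizer;
# the literal values are assumptions (QANet-paper-style settings), not this repo's
# defaults.
#
#   config = {'optimizer': {'learning_rate': 1e-3, 'beta_1': 0.8, 'beta_2': 0.999,
#                           'eps': 1e-7, 'weight_decay': 3e-7}}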
def __qanet_adam_scheduler(optimizer, config, args):
    scheduler_config = config.get('scheduler', {})
    warmup_conf = load_dataclass(scheduler_config, WarmupSchedulerConf)
    # Logarithmic warm-up: the LR multiplier grows as
    # log(examples_seen + 1) / log(warmup_steps - 1) and is clamped at 1
    # once roughly `warmup_steps` examples have been processed.
    scheduler = sched.LambdaLR(
        optimizer,
        lambda batch: min(
            1,
            1 / math.log(warmup_conf.warmup_steps - 1)
            * math.log(batch * args.batch_size + 1)
        )
    )
    return scheduler
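# A minimal sketch (hypothetical helper, not used by the training code) that mirrors
# the warm-up lambda above, handy for sanity-checking the schedule: the multiplier
# starts near 0 and reaches 1.0 once batch_idx * batch_size + 1 >= warmup_steps - 1.
def _warmup_multiplier_example(batch_idx, batch_size, warmup_steps):
    return min(1.0, math.log(batch_idx * batch_size + 1) / math.log(warmup_steps - 1))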
def init_qanet2_training(args, char_vectors, word_vectors, config):
    from models.qanet2 import QANet, QANetConf

    model_config = config.get('model', {})
    model = QANet(
        word_mat=word_vectors,
        char_mat=char_vectors,
        config=load_dataclass(model_config, QANetConf)
    )
    optimizer = __qanet_adam_optimizer(model, config)
    scheduler = __qanet_adam_scheduler(optimizer, config, args)
    return model, optimizer, scheduler
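# Illustrative only: a hedged sketch of how init_qanet2_training might be wired up.
# Empty 'model'/'optimizer' sections are assumed to fall back to the dataclass
# defaults via load_dataclass; the warmup_steps value is an assumption, not the
# repo's default.
def _example_qanet2_setup(args, char_vectors, word_vectors):
    config = {
        'model': {},
        'optimizer': {},
        'scheduler': {'warmup_steps': 1000},
    }
    return init_qanet2_training(args, char_vectors, word_vectors, config)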
def init_baseline_training(args, word_vectors, config):
    from models.bidaf import BiDAF

    model_config = config.get('model', {})
    bidaf_conf: BiDAFConf = load_dataclass(model_config, BiDAFConf)
    model = BiDAF(
        word_vectors=word_vectors,
        hidden_size=bidaf_conf.hidden_size,
        drop_prob=bidaf_conf.drop_prob
    )

    optimizer_config = config.get('optimizer', {})
    adadelta_config = load_dataclass(optimizer_config, AdadeltaConf)
    optimizer = optim.Adadelta(
        model.parameters(),
        lr=adadelta_config.learning_rate,
        weight_decay=adadelta_config.weight_decay
    )

    scheduler_config = config.get('scheduler', {})
    multiplicative_scheduler_config: MultiplicativeSchedulerConf = load_dataclass(
        scheduler_config, MultiplicativeSchedulerConf
    )
    # Multiply the LR by `multiplier` for every 1000 training examples seen.
    scheduler = sched.LambdaLR(
        optimizer,
        lambda batch: multiplicative_scheduler_config.multiplier
        ** ((batch * args.batch_size) // 1000)
    )
    return model, optimizer, scheduler
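# Illustrative only: the analogous sketch for the BiDAF baseline. The key names
# mirror the dataclass fields accessed above; the literal values are assumptions,
# not the repo's defaults.
def _example_baseline_setup(args, word_vectors):
    config = {
        'model': {'hidden_size': 100, 'drop_prob': 0.2},
        'optimizer': {'learning_rate': 0.5, 'weight_decay': 0.0},
        'scheduler': {'multiplier': 0.9999},  # LR scaled by this per 1000 examples
    }
    return init_baseline_training(args, word_vectors, config)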