def decodes(self, pred):
    gen_logits, generated, disc_logits, is_replaced, non_pad = pred
    gen_pred = gen_logits.argmax(dim=-1)
    disc_pred = disc_logits > 0
    return gen_pred, generated, disc_pred, is_replaced

# %% [markdown]
# # 5. Train

# %%
# Generator and discriminator: share the token embeddings between the two models,
# and optionally tie the generator's input and output embeddings.
if c.my_model:
    gen_hparam['tie_in_out_embedding'] = c.tie_gen_in_out_embedding
    generator = ModelForGenerator(gen_hparam)
    discriminator = ModelForDiscriminator(disc_hparam)
    discriminator.electra.embedding = generator.electra.embedding
else:
    generator = ElectraForMaskedLM(gen_config)
    discriminator = ElectraForPreTraining(disc_config)
    discriminator.electra.embeddings = generator.electra.embeddings
    if c.tie_gen_in_out_embedding:
        generator.generator_predictions.dense.weight = generator.electra.embeddings.word_embeddings.weight

# ELECTRA training loop
electra_model = ELECTRAModel(generator, discriminator, hf_tokenizer)
electra_loss_func = ELECTRALoss(gen_label_smooth=c.gen_smooth_label,
                                disc_label_smooth=c.disc_smooth_label)

# jit (haven't figured out how to make it work)
# input_ids, sentA_lenths = dls.one_batch()
# masked_inputs, labels, is_mlm_applied = mlm_cb.mask_tokens(input_ids)
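# %% [markdown]
# A quick aside on the weight sharing above: assigning one module's `Parameter` to
# another makes both modules hold the very same tensor, so the generator and
# discriminator (and, when tied, the generator's input and output projections) train
# a single shared embedding. A minimal sketch with toy sizes; the module names below
# are illustrative stand-ins, not ELECTRA's actual attributes:

# %%
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)            # stands in for word_embeddings (vocab=10, dim=4)
proj = nn.Linear(4, 10, bias=False)  # stands in for the output projection back to the vocab
proj.weight = emb.weight             # tie: both modules now reference one Parameter

out = proj(emb(torch.tensor([1, 2])))
out.sum().backward()
assert proj.weight is emb.weight     # one tensor, one gradient buffer for both uses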
def get_glue_learner(task, run_name=None, inference=False):
    is_wsc_trick = task == "wnli" and c.wsc_trick

    # Number of epochs: the two smallest datasets get more epochs
    if task in ["rte", "stsb"]:
        num_epochs = 10
    else:
        num_epochs = 3

    # Dataloaders
    dls = glue_dls[task]
    if isinstance(c.device, str):
        dls.to(torch.device(c.device))
    elif isinstance(c.device, list):
        dls.to(torch.device("cuda", c.device[0]))
    else:
        dls.to(torch.device("cuda:0"))

    # Load pretrained model
    if not c.pretrained_checkpoint:
        discriminator = ElectraForPreTraining.from_pretrained(
            f"google/electra-{c.size}-discriminator")
    else:
        discriminator = (ModelForDiscriminator(hparam) if c.my_model
                         else ElectraForPreTraining(electra_config))
        load_part_model(c.pretrained_ckp_path, discriminator, "discriminator")

    # Seeds & PyTorch benchmark (`i` is the run index set by the enclosing multi-run loop)
    torch.backends.cudnn.benchmark = True
    if c.seeds:
        dls[0].rng = random.Random(c.seeds[i])  # for the fastai dataloader
        random.seed(c.seeds[i])
        np.random.seed(c.seeds[i])
        torch.manual_seed(c.seeds[i])

    # Create finetuning model
    if is_wsc_trick:
        model = ELECTRAWSCTrickModel(discriminator, hf_tokenizer.pad_token_id)
    else:
        model = SentencePredictor(discriminator.electra,
                                  electra_config.hidden_size,
                                  num_class=NUM_CLASS[task])

    # Discriminative learning rates
    splitter = partial(hf_electra_param_splitter, wsc_trick=is_wsc_trick)
    layer_lrs = get_layer_lrs(
        lr=c.lr,
        decay_rate=c.layer_lr_decay,
        num_hidden_layers=electra_config.num_hidden_layers,
    )

    # Optimizer
    if c.adam_bias_correction:
        opt_func = partial(Adam, eps=1e-6, mom=0.9, sqr_mom=0.999, wd=c.weight_decay)
    else:
        opt_func = partial(Adam_no_bias_correction, eps=1e-6, mom=0.9,
                           sqr_mom=0.999, wd=c.weight_decay)

    # Learner
    learn = Learner(
        dls, model,
        loss_func=LOSS_FUNC[task],
        opt_func=opt_func,
        metrics=METRICS[task],
        splitter=splitter if not inference else trainable_params,
        lr=layer_lrs if not inference else defaults.lr,
        path="./checkpoints/glue",
        model_dir=c.group_name,
    )

    # Multi GPU
    if isinstance(c.device, list) or c.device is None:
        learn.create_opt()
        learn.model = nn.DataParallel(learn.model, device_ids=c.device)

    # Mixed precision
    learn.to_native_fp16(init_scale=2.0**14)

    # Gradient clipping
    learn.add_cb(GradientClipping(1.0))

    # Logging
    if run_name and not inference:
        if c.logger == "neptune":
            neptune.create_experiment(name=run_name,
                                      params={"task": task, **c, **hparam_update})
            learn.add_cb(LightNeptuneCallback(False))
        elif c.logger == "wandb":
            wandb_run = wandb.init(
                name=run_name,
                project="electra_glue",
                config={"task": task, **c, **hparam_update},
                reinit=True,
            )
            learn.add_cb(LightWandbCallback(wandb_run))

    # Learning rate schedule
    if c.schedule == "one_cycle":
        return learn, partial(learn.fit_one_cycle, n_epoch=num_epochs, lr_max=layer_lrs)
    elif c.schedule == "adjusted_one_cycle":
        return learn, partial(
            learn.fit_one_cycle,
            n_epoch=num_epochs,
            lr_max=layer_lrs,
            div=1e5,
            pct_start=0.1,
        )
    else:
        lr_sched_func = (linear_warmup_and_then_decay if c.schedule == "separate_linear"
                         else linear_warmup_and_decay)
        lr_schedule = ParamScheduler({
            "lr": partial(
                lr_sched_func,
                lr_max=np.array(layer_lrs),
                warmup_pct=0.1,
                total_steps=num_epochs * len(dls.train),
            )
        })
        return learn, partial(learn.fit, n_epoch=num_epochs, cbs=[lr_schedule])
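# %% [markdown]
# For reference, the layer-wise learning rate decay used above can be sketched as
# follows. This is an assumption about what `get_layer_lrs` computes, not its actual
# implementation: each parameter group one step further from the classifier head has
# its learning rate multiplied by `decay_rate` once more, so the embeddings train
# slowest and the head trains at the full `lr`.

# %%
def layer_lrs_sketch(lr, decay_rate, num_hidden_layers):
    # Assumed grouping: embeddings, each encoder layer, then the head (+2 groups).
    depths = range(num_hidden_layers + 2)
    return [lr * decay_rate ** (num_hidden_layers + 1 - depth) for depth in depths]

# e.g. lr=3e-4, decay_rate=0.8, 12 layers: embeddings get 3e-4 * 0.8**13,
# the classifier head gets the full 3e-4.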
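# %% [markdown]
# Hypothetical usage (the task and run names are illustrative): the helper returns
# the `Learner` together with a fully bound fit function, so a multi-seed sweep only
# has to construct and call. Note the seeding code inside `get_glue_learner` reads
# the global run index `i` set by this loop.

# %%
for i in range(len(c.seeds) if c.seeds else 1):
    learn, fit_fn = get_glue_learner("cola", run_name=f"cola_run_{i}")
    fit_fn()  # executes the schedule selected by c.schedule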