def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    if config.embedder.name.endswith('bert') or config.embedder.name == 'both':
        non_bert_params = (param for name, param in model.named_parameters()
                           if not name.startswith('text_field_embedder'))
        optimizer = optim.AdamW([{
            'params': model.text_field_embedder.parameters(),
            'lr': config.trainer.bert_lr
        }, {
            'params': non_bert_params,
            'lr': config.trainer.lr
        }, {
            'params': []
        }])
        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if config.embedder.name.endswith('bert') or config.embedder.name == 'both':
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=config.trainer.num_gradient_accumulation_steps)
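# Hedged example (toy modules and made-up learning rates, not the project's
# config): a self-contained sketch of the two-group AdamW setup used in
# _build_trainer above, handy for sanity-checking which parameters land in
# which group.
def _example_two_group_adamw():
    import torch.nn as nn
    import torch.optim as optim

    # Toy stand-in for the real model: a "pretrained" embedder plus a task head.
    model = nn.ModuleDict({
        'text_field_embedder': nn.Linear(8, 8),
        'classifier': nn.Linear(8, 2),
    })
    non_bert_params = (param for name, param in model.named_parameters()
                       if not name.startswith('text_field_embedder'))
    optimizer = optim.AdamW([
        {'params': model['text_field_embedder'].parameters(), 'lr': 2e-5},
        {'params': non_bert_params, 'lr': 1e-3},
        {'params': []},  # trailing empty group, mirroring the code above
    ])
    for i, group in enumerate(optimizer.param_groups):
        print(f"group {i}: lr={group['lr']}, tensors={len(group['params'])}")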
def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    is_bert_based = any(
        m.name.endswith('bert') for m in config.embedder.models)
    is_trainable_elmo_based = any(
        m.name == 'elmo' and m.params['requires_grad']
        for m in config.embedder.models)

    if is_bert_based or is_trainable_elmo_based:
        params_list = []
        non_pretrained_params = []
        if is_bert_based:
            # One parameter group per pretrained BERT block (embeddings, each
            # of the 12 encoder layers, pooler), so the scheduler can treat
            # the blocks separately.
            bert_groups = (
                ['transformer_model.embeddings.'] +
                [f'transformer_model.encoder.layer.{i}.' for i in range(12)] +
                ['transformer_model.pooler.'])
            bert_group2params = {bg: [] for bg in bert_groups}
            for name, param in model.named_parameters():
                is_bert_layer = False
                for bg in bert_groups:
                    if bg in name:
                        is_bert_layer = True
                        bert_group2params[bg].append(param)
                        logger.info('Param: %s assigned to %s group', name, bg)
                        break
                if not is_bert_layer:
                    non_pretrained_params.append(param)
                    logger.info('Param: %s assigned to non_pretrained group',
                                name)
            for bg in bert_groups:
                params_list.append({
                    'params': bert_group2params[bg],
                    'lr': config.trainer.bert_lr
                })
            params_list.append({
                'params': non_pretrained_params,
                'lr': config.trainer.lr
            })
            params_list.append({'params': []})
        elif is_trainable_elmo_based:
            pretrained_params = []
            for name, param in model.named_parameters():
                if '_elmo_lstm' in name:
                    logger.info('Pretrained param: %s', name)
                    pretrained_params.append(param)
                else:
                    logger.info('Non-pretrained param: %s', name)
                    non_pretrained_params.append(param)
            params_list = [{
                'params': pretrained_params,
                'lr': config.trainer.bert_lr
            }, {
                'params': non_pretrained_params,
                'lr': config.trainer.lr
            }, {
                'params': []
            }]
        optimizer = optim.AdamW(params_list)
        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if is_bert_based:
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=config.trainer.num_gradient_accumulation_steps)
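# Hedged illustration of what discriminative fine-tuning does with the
# fourteen BERT groups built above: SlantedTriangular rescales each group's
# learning rate by a per-group decay, so lower layers train more slowly.
# Assumptions: 0.38 is AllenNLP's default decay_factor for SlantedTriangular,
# and the base rate here is a placeholder, not a value from the config.
def _example_discriminative_lrs(base_lr=2e-5, decay=0.38):
    groups = (['transformer_model.embeddings.'] +
              [f'transformer_model.encoder.layer.{i}.' for i in range(12)] +
              ['transformer_model.pooler.'])
    # Earlier (deeper-in-the-stack) groups get exponentially smaller rates.
    for depth, name in enumerate(groups):
        lr = base_lr * decay ** (len(groups) - depth - 1)
        print(f'{name:45s} lr={lr:.3e}')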
def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    is_bert_based = any(
        m.name.endswith('bert') for m in config.embedder.models)
    is_trainable_elmo_based = any(
        m.name == 'elmo' and m.params['requires_grad']
        for m in config.embedder.models)

    if is_bert_based or is_trainable_elmo_based:

        def _is_pretrained_param(name):
            return 'transformer_model' in name or '_elmo_lstm' in name

        pretrained_params, non_pretrained_params = [], []
        for name, param in model.named_parameters():
            if _is_pretrained_param(name):
                logger.info('Pretrained param: %s', name)
                pretrained_params.append(param)
            else:
                logger.info('Non-pretrained param: %s', name)
                non_pretrained_params.append(param)
        optimizer = optim.AdamW([{
            'params': pretrained_params,
            'lr': config.trainer.bert_lr
        }, {
            'params': non_pretrained_params,
            'lr': config.trainer.lr
        }, {
            'params': []
        }])
        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if is_bert_based:
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=config.trainer.num_gradient_accumulation_steps)
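# Small self-contained check of the _is_pretrained_param split used above.
# The parameter names below are hypothetical examples, not taken from the
# real model.
def _example_pretrained_split():
    def _is_pretrained_param(name):
        # Anything inside the HF transformer or the ELMo LSTM counts as
        # pretrained; everything else is a task-specific parameter.
        return 'transformer_model' in name or '_elmo_lstm' in name

    example_names = [
        'text_field_embedder.transformer_model.encoder.layer.0'
        '.attention.self.query.weight',
        'text_field_embedder._elmo._elmo_lstm.forward_layer_0'
        '.input_linearity.weight',
        'classifier.bias',
    ]
    for name in example_names:
        kind = 'pretrained' if _is_pretrained_param(name) else 'non-pretrained'
        print(f'{kind:>14}: {name}')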
def train(self):
    if self.config.adjust_point:
        ram_set_flag("adjust_point")
    # ram_write('dist_reg', self.config.dist_reg)
    read_hyper_ = partial(read_hyper, self.config.task_id, self.config.arch)
    num_epochs = int(read_hyper_("num_epochs"))
    batch_size = int(read_hyper_("batch_size"))
    logger.info(f"num_epochs: {num_epochs}, batch_size: {batch_size}")

    if self.config.model_name == 'tmp':
        p = pathlib.Path('saved/models/tmp')
        if p.exists():
            shutil.rmtree(p)

    # Maybe we will do some data augmentation here.
    if self.config.aug_data != '':
        log(f'Augment data from {self.config.aug_data}')
        aug_data = auto_create(
            f"{self.config.task_id}.{self.config.arch}.aug",
            lambda: self.reader.read(self.config.aug_data),
            cache=True)
        self.train_data.instances.extend(aug_data.instances)

    # Set up the adversarial training policy.
    if self.config.arch == 'bert':
        model_vocab = embed_util.get_bert_vocab()
    else:
        model_vocab = self.vocab
    # yapf: disable
    adv_field = 'sent2' if is_sentence_pair(self.config.task_id) and self.config.arch != 'bert' else 'sent'
    policy_args = {
        "adv_iteration": self.config.adv_iter,
        "replace_num": self.config.adv_replace_num,
        "searcher": WordIndexSearcher(
            CachedWordSearcher(
                "external_data/ibp-nbrs.json" if not self.config.big_nbrs else "external_data/euc-top8.json",
                model_vocab.get_token_to_index_vocabulary("tokens"),
                second_order=False
            ),
            word2idx=model_vocab.get_token_index,
            idx2word=model_vocab.get_token_from_index,
        ),
        'adv_field': adv_field
    }
    # yapf: enable
    if self.config.adv_policy == 'hot':
        if is_sentence_pair(self.config.task_id) and self.config.arch != 'bert':
            policy_args['forward_order'] = 1
        adv_policy = adv_utils.HotFlipPolicy(**policy_args)
    elif self.config.adv_policy == 'rdm':
        adv_policy = adv_utils.RandomNeighbourPolicy(**policy_args)
    elif self.config.adv_policy == 'diy':
        adv_policy = adv_utils.DoItYourselfPolicy(self.config.adv_iter,
                                                  adv_field,
                                                  self.config.adv_step)
    else:
        adv_policy = adv_utils.NoPolicy

    # A collate_fn will do some transformation on an instance before it is
    # fed into the model. To train a model with transformations such as
    # cropping/DAE, modify the code here, e.g.
    # collate_fn = partial(transform_collate, self.vocab, self.reader, Crop(0.3))
    collate_fn = allennlp_collate
    train_data_sampler = BucketBatchSampler(
        data_source=self.train_data,
        batch_size=batch_size,
    )

    # Set callbacks.
    if self.config.task_id == 'SNLI' and self.config.arch != 'bert':
        epoch_callbacks = []
        if self.config.model_pretrain != "":
            epoch_callbacks = [WarmupCallback(2)]
            if self.config.model_pretrain == 'auto':
                self.config.model_pretrain = {
                    "biboe": "SNLI-fix-biboe-sum",
                    "datt": "SNLI-fix-datt"
                }[self.config.arch]
            logger.warning(
                f"Try loading weights from pretrained model {self.config.model_pretrain}"
            )
            pretrain_ckpter = CheckpointerX(
                f"saved/models/{self.config.model_pretrain}")
            self.model.load_state_dict(pretrain_ckpter.best_model_state())
    else:
        epoch_callbacks = []
    batch_callbacks = []

    opt = self.model.get_optimizer()
    if self.config.arch == 'bert':
        scl = SlantedTriangular(opt, num_epochs,
                                len(self.train_data) // batch_size)
    else:
        scl = None
    trainer = AdvTrainer(
        model=self.model,
        optimizer=opt,
        learning_rate_scheduler=scl,
        validation_metric='+accuracy',
        adv_policy=adv_policy,
        data_loader=DataLoader(
            self.train_data,
            batch_sampler=train_data_sampler,
            collate_fn=collate_fn,
        ),
        validation_data_loader=DataLoader(
            self.dev_data,
            batch_size=batch_size,
        ),
        num_epochs=num_epochs,
        patience=None,
        grad_clipping=1.,
        cuda_device=0,
        epoch_callbacks=epoch_callbacks,
        batch_callbacks=batch_callbacks,
        serialization_dir=f'saved/models/{self.config.model_name}',
        num_serialized_models_to_keep=20)
    trainer.train()
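# Side note on BucketBatchSampler as used above: it batches instances of
# similar length so little computation is wasted on padding. Below is a toy,
# self-contained sketch of the idea in plain Python (not the AllenNLP class;
# the lengths are synthetic).
def _example_bucketing(batch_size=4):
    import random
    random.seed(0)
    lengths = [random.randint(5, 60) for _ in range(16)]  # fake sentence lengths
    # Sort indices by length, then cut into consecutive batches: each batch
    # now holds sentences of similar length.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    for start in range(0, len(order), batch_size):
        print([lengths[i] for i in order[start:start + batch_size]])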