def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.device = get_device()
    set_seeds(self.config.seed)

    # Training state.
    self.current_epoch = 1
    self.global_step = 0
    self.best_valid_mean_iou = 0

    # Model, losses, optimizer, and logging.
    self.model = BERTModel4Pretrain(self.config)
    self.criterion1 = nn.CrossEntropyLoss(reduction='none')
    self.criterion2 = nn.CrossEntropyLoss()
    self.optimizer = optim4GPU(self.config, self.model)
    self.writer = SummaryWriter(log_dir=self.config.log_dir)

    # Datasets and loaders.
    tokenizer = FullTokenizer(self.config, do_lower_case=True)
    train_dataset = SentencePairDataset(self.config, tokenizer, 'train')
    validate_dataset = SentencePairDataset(self.config, tokenizer, 'validate')
    self.train_dataloader = DataLoader(
        train_dataset,
        batch_size=self.config.batch_size,
        num_workers=self.config.data_loader_workers,
        pin_memory=self.config.pin_memory)
    self.validate_dataloader = DataLoader(
        validate_dataset,
        batch_size=self.config.batch_size,
        num_workers=self.config.data_loader_workers,
        pin_memory=self.config.pin_memory)

    self.model = self.model.to(self.device)
    if self.config.data_parallel:
        self.model = nn.DataParallel(self.model)
    self.load_checkpoint(self.config.checkpoint_to_load)
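# get_device and set_seeds are used above but defined elsewhere in the repo.
# A minimal sketch of the assumed behavior, for reference; the actual utils
# may differ:

import random

import numpy as np
import torch


def get_device():
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seeds(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)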
def main():
    # Load configuration.
    model_cfg = configuration.model.from_json(cfg.model_cfg)  # BERT cfg
    set_seeds(cfg.seed)

    # Load data (own DataSet implementation replaces the original load_data
    # pipeline) and build the DataLoaders.
    dataset = DataSet(cfg)
    train_dataset, val_dataset, unsup_dataset = dataset.get_dataset()

    train_dataloader = DataLoader(
        train_dataset,                           # the training samples
        sampler=RandomSampler(train_dataset),    # select batches randomly
        batch_size=cfg.train_batch_size          # train with this batch size
    )
    validation_dataloader = DataLoader(
        val_dataset,                             # the validation samples
        sampler=SequentialSampler(val_dataset),  # pull out batches sequentially
        batch_size=cfg.eval_batch_size           # evaluate with this batch size
    )
    unsup_dataloader = None
    if unsup_dataset:
        unsup_dataloader = DataLoader(
            unsup_dataset,
            sampler=RandomSampler(unsup_dataset),
            batch_size=cfg.train_batch_size
        )

    if cfg.uda_mode or cfg.mixmatch_mode:
        data_iter = [train_dataloader, unsup_dataloader, validation_dataloader]
    else:
        data_iter = [train_dataloader, validation_dataloader]

    # Load model.
    ema_optimizer = None
    ema_model = None
    if cfg.model == "custom":
        model = models.Classifier(model_cfg, NUM_LABELS[cfg.task])
    elif cfg.model == "bert":
        model = BertForSequenceClassificationCustom.from_pretrained(
            "bert-base-uncased",              # 12-layer BERT model, uncased vocab
            num_labels=NUM_LABELS[cfg.task],
            output_attentions=False,          # do not return attention weights
            output_hidden_states=False,       # do not return all hidden states
        )

    # Create criteria and optimizer.
    if cfg.uda_mode:
        if cfg.unsup_criterion == 'KL':
            unsup_criterion = nn.KLDivLoss(reduction='none')
        else:
            unsup_criterion = nn.MSELoss(reduction='none')
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)
    elif cfg.mixmatch_mode:
        train_criterion = SemiLoss()
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        ema_model = models.Classifier(model_cfg, NUM_LABELS[cfg.task])
        for param in ema_model.parameters():
            param.detach_()
        ema_optimizer = WeightEMA(cfg, model, ema_model, alpha=cfg.ema_decay)
    else:
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)

    # Create trainer.
    trainer = train.Trainer(cfg, model, data_iter, optimizer, get_device(),
                            ema_model, ema_optimizer)

    # Loss functions.
    def get_sup_loss(model, sup_batch, unsup_batch, global_step):
        # Unpack the supervised batch.
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch

        # Convert label ids to one-hot vectors.
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1, 1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # Supervised mixup: sample the interpolation coefficient and a
        # permutation of the batch.
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1 - sup_l)
        sup_idx = torch.randperm(sup_size)

        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # Supervised loss.
        hidden = model(
            input_ids=input_ids,
            segment_ids=segment_ids,
            input_mask=input_mask,
            output_h=True,
            mixup=cfg.sup_mixup,
            shuffle_idx=sup_idx,
            clone_ids=c_input_ids,
            l=sup_l,
            manifold_mixup=cfg.manifold_mixup,
            simple_pad=cfg.simple_pad,
            no_grad_clone=cfg.no_grad_clone
        )
        logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)
        if cfg.tsa and cfg.tsa != "none":
            # Training signal annealing: mask out examples the model already
            # predicts with probability above the scheduled threshold.
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps,
                                        start=1. / logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh  # prob = exp(log_prob)
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (
                1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(
                torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        return sup_loss, sup_loss, sup_loss, sup_loss

    def get_loss_ict(model, sup_batch, unsup_batch, global_step):
        # Unpack the supervised and unsupervised batches.
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch
        ori_input_ids, ori_segment_ids, ori_input_mask, \
            aug_input_ids, aug_segment_ids, aug_input_mask, \
            ori_num_tokens, aug_num_tokens = unsup_batch

        # Convert label ids to one-hot vectors.
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1, 1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # Supervised mixup.
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1 - sup_l)
        sup_idx = torch.randperm(sup_size)

        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # Supervised loss.
        if cfg.model == "bert":
            logits = model(
                input_ids=input_ids,
                c_input_ids=c_input_ids,
                attention_mask=input_mask,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                l=sup_l,
                manifold_mixup=cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=input_ids,
                segment_ids=segment_ids,
                input_mask=input_mask,
                output_h=True,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                clone_ids=c_input_ids,
                l=sup_l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)
        if cfg.tsa and cfg.tsa != "none":
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps,
                                        start=1. / logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh  # prob = exp(log_prob)
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (
                1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(
                torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        if cfg.no_unsup_loss:
            return sup_loss, sup_loss, sup_loss, sup_loss

        # Unsupervised (consistency) loss: predictions on the original inputs
        # serve as the targets.
        with torch.no_grad():
            if cfg.model == "bert":
                ori_logits = model(
                    input_ids=ori_input_ids,
                    attention_mask=ori_input_mask,
                    no_pretrained_pool=cfg.no_pretrained_pool
                )
            else:
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
            ori_prob = F.softmax(ori_logits, dim=-1)  # consistency target

        # Unsupervised mixup. The permutation must cover the unsupervised
        # batch, so it is sized from ori_input_ids rather than the stale
        # supervised hidden states.
        l = np.random.beta(cfg.alpha, cfg.alpha)
        l = max(l, 1 - l)
        idx = torch.randperm(ori_input_ids.size(0))

        if cfg.mixup and 'word' in cfg.mixup:
            ori_input_ids, c_ori_input_ids = pad_for_word_mixup(
                ori_input_ids, ori_input_mask, ori_num_tokens, idx
            )
        else:
            c_ori_input_ids = None

        if cfg.model == "bert":
            logits = model(
                input_ids=ori_input_ids,
                c_input_ids=c_ori_input_ids,
                attention_mask=ori_input_mask,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                l=l,
                manifold_mixup=cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=ori_input_ids,
                segment_ids=ori_segment_ids,
                input_mask=ori_input_mask,
                output_h=True,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                clone_ids=c_ori_input_ids,
                l=l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.mixup:
            ori_prob = mixup_op(ori_prob, l, idx)

        probs_u = torch.softmax(logits, dim=1)
        unsup_loss = torch.mean((probs_u - ori_prob) ** 2)

        # Ramp the consistency weight up over the configured interval.
        w = cfg.uda_coeff * sigmoid_rampup(
            global_step, cfg.consistency_rampup_ends - cfg.consistency_rampup_starts)
        final_loss = sup_loss + w * unsup_loss
        return final_loss, sup_loss, unsup_loss, w * unsup_loss

    # Evaluation.
    def get_acc(model, batch):
        input_ids, segment_ids, input_mask, label_id = batch
        logits = model(input_ids, segment_ids, input_mask)
        _, label_pred = logits.max(1)
        result = (label_pred == label_id).float()
        accuracy = result.mean()
        return accuracy, result

    if cfg.mode == 'train':
        trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'train_eval':
        if cfg.mixmatch_mode:
            trainer.train(get_mixmatch_loss_short, get_acc,
                          cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode_two:
            trainer.train(get_loss_ict, get_acc, cfg.model_file, cfg.pretrain_file)
        else:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'eval':
        results = trainer.eval(get_acc, cfg.model_file, None)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
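# mixup_op, get_tsa_thresh, and sigmoid_rampup are imported from elsewhere in
# the repo. Sketches of the assumed formulations, for reference: mixup_op is
# the standard mixup interpolation, get_tsa_thresh follows the training signal
# annealing schedules from the UDA paper, and sigmoid_rampup is the Mean
# Teacher-style consistency ramp-up. The repo's own versions may differ.

import numpy as np
import torch


def mixup_op(x, l, idx):
    # Interpolate each row of x with the row selected by the permutation idx:
    # l * x + (1 - l) * x[idx]. Used above on one-hot labels and probabilities.
    return l * x + (1 - l) * x[idx]


def get_tsa_thresh(schedule, global_step, num_train_steps, start, end):
    # Training signal annealing: the threshold grows from `start`
    # (1 / num_classes) to `end` (1.0) as training progresses.
    training_progress = torch.tensor(float(global_step) / float(num_train_steps))
    scale = 5
    if schedule == 'linear_schedule':
        threshold = training_progress
    elif schedule == 'exp_schedule':
        threshold = torch.exp((training_progress - 1) * scale)
    elif schedule == 'log_schedule':
        threshold = 1 - torch.exp((-training_progress) * scale)
    else:
        raise ValueError('unknown TSA schedule: %s' % schedule)
    return threshold * (end - start) + start


def sigmoid_rampup(current, rampup_length):
    # Exponential ramp-up from 0 to 1 over rampup_length steps,
    # exp(-5 * (1 - t)^2), used to weight the consistency loss.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))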
def main(cfg, model_cfg):
    # Load configuration.
    cfg = configuration.params.from_json(cfg)              # train or eval cfg
    model_cfg = configuration.model.from_json(model_cfg)   # BERT cfg
    set_seeds(cfg.seed)

    # Load data and create criteria.
    data = load_data(cfg)
    if cfg.uda_mode:
        unsup_criterion = nn.KLDivLoss(reduction='none')
        data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode == 'train' \
            else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()]  # train_eval
    else:
        data_iter = [data.sup_data_iter()]
    sup_criterion = nn.CrossEntropyLoss(reduction='none')

    # Load model.
    model = models.Classifier(model_cfg, len(data.TaskDataset.labels))

    # Create trainer.
    trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model), get_device())

    # Training: logits -> prob (softmax) -> log_prob (log_softmax).
    def get_loss(model, sup_batch, unsup_batch, global_step):
        # Unpack the batches; augmented unsupervised examples are appended to
        # the supervised batch so both pass through the model together.
        input_ids, segment_ids, input_mask, label_ids = sup_batch
        if unsup_batch:
            ori_input_ids, ori_segment_ids, ori_input_mask, \
                aug_input_ids, aug_segment_ids, aug_input_mask = unsup_batch

            input_ids = torch.cat((input_ids, aug_input_ids), dim=0)
            segment_ids = torch.cat((segment_ids, aug_segment_ids), dim=0)
            input_mask = torch.cat((input_mask, aug_input_mask), dim=0)

        # Logits for the combined batch.
        logits = model(input_ids, segment_ids, input_mask)

        # Supervised loss.
        sup_size = label_ids.shape[0]
        sup_loss = sup_criterion(logits[:sup_size], label_ids)  # shape: train_batch_size
        if cfg.tsa:
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps,
                                        start=1. / logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh  # prob = exp(log_prob)
            loss_mask = torch.ones_like(label_ids, dtype=torch.float32) * (
                1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(
                torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        # Unsupervised loss.
        if unsup_batch:
            # ori: predictions on the clean inputs serve as the KL-divergence target.
            with torch.no_grad():
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
                ori_prob = F.softmax(ori_logits, dim=-1)  # KLdiv target

                # Confidence-based masking.
                if cfg.uda_confidence_thresh != -1:
                    unsup_loss_mask = torch.max(ori_prob, dim=-1)[0] > cfg.uda_confidence_thresh
                    unsup_loss_mask = unsup_loss_mask.type(torch.float32)
                else:
                    unsup_loss_mask = torch.ones(len(logits) - sup_size, dtype=torch.float32)
                unsup_loss_mask = unsup_loss_mask.to(_get_device())

            # aug: sharpen predictions with a softmax temperature.
            uda_softmax_temp = cfg.uda_softmax_temp if cfg.uda_softmax_temp > 0 else 1.
            aug_log_prob = F.log_softmax(logits[sup_size:] / uda_softmax_temp, dim=-1)

            # KL-divergence loss. nn.KLDivLoss expects log-probabilities as
            # input and probabilities as target
            # (https://pytorch.org/docs/stable/nn.html).
            # Note: unsup_loss is divided by the number of unmasked examples,
            # unlike the official Google UDA code, which divides by the total
            # (https://github.com/google-research/uda/blob/master/text/uda.py#L175).
            unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)
            unsup_loss = torch.sum(unsup_loss * unsup_loss_mask, dim=-1) / torch.max(
                torch.sum(unsup_loss_mask, dim=-1), torch_device_one())

            final_loss = sup_loss + cfg.uda_coeff * unsup_loss
            return final_loss, sup_loss, unsup_loss
        return sup_loss, None, None

    # Evaluation.
    def get_acc(model, batch):
        input_ids, segment_ids, input_mask, label_id = batch
        logits = model(input_ids, segment_ids, input_mask)
        _, label_pred = logits.max(1)
        result = (label_pred == label_id).float()
        accuracy = result.mean()
        return accuracy, result

    if cfg.mode == 'train':
        trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)
    if cfg.mode == 'train_eval':
        trainer.train(get_loss, get_acc, cfg.model_file, cfg.pretrain_file)
    if cfg.mode == 'eval':
        results = trainer.eval(get_acc, cfg.model_file, None)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
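# main(cfg, model_cfg) takes paths to two JSON configuration files. A minimal,
# assumed entry point; the default file names below are placeholders, and the
# original script may instead wire this up with a CLI helper such as
# python-fire.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='UDA training/evaluation')
    parser.add_argument('--cfg', default='config/uda.json',
                        help='path to the train/eval configuration JSON (placeholder)')
    parser.add_argument('--model_cfg', default='config/bert_base.json',
                        help='path to the BERT model configuration JSON (placeholder)')
    args = parser.parse_args()
    main(args.cfg, args.model_cfg)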