Example #1
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.device = get_device()

        set_seeds(self.config.seed)
        self.current_epoch = 1
        self.global_step = 0
        self.best_valid_mean_iou = 0

        self.model = BERTModel4Pretrain(self.config)

        self.criterion1 = nn.CrossEntropyLoss(reduction='none')
        self.criterion2 = nn.CrossEntropyLoss()
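        # (criterion1 is presumably the per-token masked-LM loss; reduction='none'
        # lets non-masked positions be zeroed out before averaging. criterion2 is
        # the standard loss for the next-sentence-prediction head.)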

        self.optimizer = optim4GPU(self.config, self.model)
        self.writer = SummaryWriter(log_dir=self.config.log_dir)

        tokenizer = FullTokenizer(self.config, do_lower_case=True)
        train_dataset = SentencePairDataset(self.config, tokenizer, 'train')
        validate_dataset = SentencePairDataset(self.config, tokenizer,
                                               'validate')

        # smoke-test the data pipeline by materializing one sample (idiomatic
        # indexing instead of calling __getitem__ directly)
        _ = train_dataset[0]

        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            num_workers=self.config.data_loader_workers,
            pin_memory=self.config.pin_memory)

        self.validate_dataloader = DataLoader(
            validate_dataset,
            batch_size=self.config.batch_size,
            num_workers=self.config.data_loader_workers,
            pin_memory=self.config.pin_memory)

        self.model = self.model.to(self.device)
        if self.config.data_parallel:
            self.model = nn.DataParallel(self.model)
        self.load_checkpoint(self.config.checkpoint_to_load)
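
A minimal sketch of how this trainer might be constructed. The SimpleNamespace
fields mirror the config attributes read in __init__ above; the class name
PretrainTrainer and the None checkpoint value are illustrative assumptions, not
part of the original source:

from types import SimpleNamespace

config = SimpleNamespace(
    seed=42,                      # passed to set_seeds()
    batch_size=32,
    data_loader_workers=4,
    pin_memory=True,
    log_dir='logs/pretrain',      # TensorBoard SummaryWriter output directory
    data_parallel=False,
    checkpoint_to_load=None,      # hypothetical "train from scratch" value
)
trainer = PretrainTrainer(config)  # hypothetical name for the class above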
Example #2
def main():
    # Load Configuration
    model_cfg = configuration.model.from_json(cfg.model_cfg)        # BERT_cfg
    set_seeds(cfg.seed)

    # Load Data & Create Criterion
    # Build the datasets directly rather than through the older load_data() pipeline
    dataset = DataSet(cfg)
    train_dataset, val_dataset, unsup_dataset = dataset.get_dataset()

    # Create the DataLoaders for our training and validation sets.
    train_dataloader = DataLoader(
        train_dataset,                           # the training samples
        sampler=RandomSampler(train_dataset),    # select batches randomly
        batch_size=cfg.train_batch_size,         # train with this batch size
    )

    validation_dataloader = DataLoader(
        val_dataset,                             # the validation samples
        sampler=SequentialSampler(val_dataset),  # pull out batches sequentially
        batch_size=cfg.eval_batch_size,          # evaluate with this batch size
    )

    unsup_dataloader = None
    if unsup_dataset:
        unsup_dataloader = DataLoader(
            unsup_dataset,
            sampler=RandomSampler(unsup_dataset),
            batch_size=cfg.train_batch_size,
        )

    if cfg.uda_mode or cfg.mixmatch_mode:
        data_iter = [train_dataloader, unsup_dataloader, validation_dataloader] 
    else:
        data_iter = [train_dataloader, validation_dataloader]

    ema_optimizer = None
    ema_model = None

    if cfg.model == "custom":
        model = models.Classifier(model_cfg, NUM_LABELS[cfg.task])
    elif cfg.model == "bert":
        model = BertForSequenceClassificationCustom.from_pretrained(
            "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
            num_labels = NUM_LABELS[cfg.task],
            output_attentions = False, # Whether the model returns attentions weights.
            output_hidden_states = False, # Whether the model returns all hidden-states.
        )


    if cfg.uda_mode:
        if cfg.unsup_criterion == 'KL':
            unsup_criterion = nn.KLDivLoss(reduction='none')
        else:
            unsup_criterion = nn.MSELoss(reduction='none')
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)
    elif cfg.mixmatch_mode:
        train_criterion = SemiLoss()
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        ema_model = models.Classifier(model_cfg,  NUM_LABELS[cfg.task])
        for param in ema_model.parameters():
            param.detach_()
        ema_optimizer= WeightEMA(cfg, model, ema_model, alpha=cfg.ema_decay)
    else:
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)
    
    # Create trainer
    trainer = train.Trainer(cfg, model, data_iter, optimizer, get_device(), ema_model, ema_optimizer)

    # loss functions
    def get_sup_loss(model, sup_batch, unsup_batch, global_step):
        # unsup_batch is unused here; the parameter is kept to match the Trainer's
        # loss-function signature.
        # batch
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch

        # convert label ids to hot vectors
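        # scatter_(1, index, 1) writes a 1 into each row at that row's label column,
        # i.e. builds a one-hot row per example; the hard-coded 2 assumes a binary
        # task (NUM_LABELS[cfg.task] would generalize it).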
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1,1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # sup mixup
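        # lambda ~ Beta(alpha, alpha); taking max(l, 1-l) keeps the mixing weight on
        # the original (unshuffled) example at >= 0.5, the usual mixup convention.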
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1-sup_l)
        sup_idx = torch.randperm(sup_size)

        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # sup loss
        hidden = model(
            input_ids=input_ids, 
            segment_ids=segment_ids, 
            input_mask=input_mask,
            output_h=True,
            mixup=cfg.sup_mixup,
            shuffle_idx=sup_idx,
            clone_ids=c_input_ids,
            l=sup_l,
            manifold_mixup=cfg.manifold_mixup,
            simple_pad=cfg.simple_pad,
            no_grad_clone=cfg.no_grad_clone
        )
        logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)

        if cfg.tsa and cfg.tsa != "none":
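            # Training Signal Annealing: anneal a probability threshold from
            # 1/num_classes toward 1 and mask out labeled examples the model already
            # predicts above it, so easy examples contribute less early in training.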
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh   # prob = exp(log_prob), prob > tsa_threshold
            # larger_than_threshold = torch.sum(  F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids]  , dim=-1) > tsa_threshold
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        # The Trainer expects a (final, sup, unsup, weighted-unsup) 4-tuple; with a
        # supervised-only loss, all four entries are the same value.
        return sup_loss, sup_loss, sup_loss, sup_loss


    def get_loss_ict(model, sup_batch, unsup_batch, global_step):
        # batch
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch
        ori_input_ids, ori_segment_ids, ori_input_mask, \
        aug_input_ids, aug_segment_ids, aug_input_mask, \
        ori_num_tokens, aug_num_tokens = unsup_batch

        # convert label ids to hot vectors
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1,1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # sup mixup
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1-sup_l)
        sup_idx = torch.randperm(sup_size)

        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # sup loss
        if cfg.model == "bert":
            logits = model(
                input_ids=input_ids,
                c_input_ids=c_input_ids,
                attention_mask=input_mask,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                l=sup_l,
                manifold_mixup = cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=input_ids, 
                segment_ids=segment_ids, 
                input_mask=input_mask,
                output_h=True,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                clone_ids=c_input_ids,
                l=sup_l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)

        if cfg.tsa and cfg.tsa != "none":
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh   # prob = exp(log_prob), prob > tsa_threshold
            # larger_than_threshold = torch.sum(  F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids]  , dim=-1) > tsa_threshold
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        if cfg.no_unsup_loss:
            return sup_loss, sup_loss, sup_loss, sup_loss

        # unsup loss
        with torch.no_grad():
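            # Teacher pass: predictions on the clean (original) inputs are computed
            # without gradients and serve as fixed targets for the consistency term.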
            if cfg.model == "bert":
                ori_logits = model(
                    input_ids = ori_input_ids,
                    attention_mask = ori_input_mask,
                    no_pretrained_pool=cfg.no_pretrained_pool
                )
            else:
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
            ori_prob   = F.softmax(ori_logits, dim=-1)    # KLdiv target


        # mixup
        l = np.random.beta(cfg.alpha, cfg.alpha)
        l = max(l, 1-l)
        # NOTE: the original used hidden.size(0), but `hidden` is never assigned on
        # the "bert" path above; the unsupervised batch size is what is needed here.
        idx = torch.randperm(ori_input_ids.size(0))

        
        if cfg.mixup and 'word' in cfg.mixup:
            ori_input_ids, c_ori_input_ids = pad_for_word_mixup(
                ori_input_ids, ori_input_mask, ori_num_tokens, idx
            )
        else:
            c_ori_input_ids = None

        
        if cfg.model == "bert":
            logits = model(
                input_ids=ori_input_ids,
                c_input_ids=c_ori_input_ids,
                attention_mask=ori_input_mask,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                l=l,
                manifold_mixup = cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=ori_input_ids, 
                segment_ids=ori_segment_ids, 
                input_mask=ori_input_mask,
                output_h=True,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                clone_ids=c_ori_input_ids,
                l=l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.mixup:
            ori_prob = mixup_op(ori_prob, l, idx)

        # Consistency loss: squared error between the student's predictions on the
        # mixed inputs and the mixup of the teacher's (no-grad) predictions.
        probs_u = torch.softmax(logits, dim=1)
        unsup_loss = torch.mean((probs_u - ori_prob)**2)

        # Ramp the consistency weight up on a sigmoid schedule so the unsupervised
        # term does not dominate before the model makes reasonable predictions.
        w = cfg.uda_coeff * sigmoid_rampup(global_step, cfg.consistency_rampup_ends - cfg.consistency_rampup_starts)
        final_loss = sup_loss + w*unsup_loss
        return final_loss, sup_loss, unsup_loss, w*unsup_loss

    # evaluation
    def get_acc(model, batch):
        # input_ids, segment_ids, input_mask, label_id, sentence = batch
        input_ids, segment_ids, input_mask, label_id = batch
        logits = model(input_ids, segment_ids, input_mask)
        _, label_pred = logits.max(1)

        result = (label_pred == label_id).float()
        accuracy = result.mean()
        # output_dump.logs(sentence, label_pred, label_id)    # output dump

        return accuracy, result

    if cfg.mode == 'train':
        # NOTE: get_loss (and get_mixmatch_loss_short below) are not defined in
        # this snippet; they are assumed to exist elsewhere in the module.
        trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'train_eval':
        if cfg.mixmatch_mode:
            trainer.train(get_mixmatch_loss_short, get_acc, cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode_two:
            trainer.train(get_loss_ict, get_acc, cfg.model_file, cfg.pretrain_file)
        else:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'eval':
        results = trainer.eval(get_acc, cfg.model_file, None)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
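
The code above calls mixup_op and sigmoid_rampup without showing them. Minimal
sketches consistent with how they are used here (the actual helpers in the source
repository may differ in detail):

import numpy as np

def mixup_op(x, l, idx):
    # Convex combination of each row with the row chosen by the shuffle index:
    # x_mixed = l * x + (1 - l) * x[idx]; works for label and probability tensors.
    return l * x + (1 - l) * x[idx]

def sigmoid_rampup(current, rampup_length):
    # Standard consistency ramp-up from the mean-teacher literature:
    # exp(-5 * (1 - p)^2) for progress p clipped to [0, 1].
    if rampup_length == 0:
        return 1.0
    p = min(max(float(current) / rampup_length, 0.0), 1.0)
    return float(np.exp(-5.0 * (1.0 - p) ** 2))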
Example #3
def main(cfg, model_cfg):
    # Load Configuration
    cfg = configuration.params.from_json(cfg)                   # Train or Eval cfg
    model_cfg = configuration.model.from_json(model_cfg)        # BERT_cfg
    set_seeds(cfg.seed)

    # Load Data & Create Criterion
    data = load_data(cfg)
    if cfg.uda_mode:
        unsup_criterion = nn.KLDivLoss(reduction='none')
        data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode=='train' \
            else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()]  # train_eval
    else:
        data_iter = [data.sup_data_iter()]
    sup_criterion = nn.CrossEntropyLoss(reduction='none')
    
    # Load Model
    model = models.Classifier(model_cfg, len(data.TaskDataset.labels))

    # Create trainer
    trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model), get_device())

    # Training
    def get_loss(model, sup_batch, unsup_batch, global_step):

        # logits -> prob(softmax) -> log_prob(log_softmax)

        # batch
        input_ids, segment_ids, input_mask, label_ids = sup_batch
        if unsup_batch:
            ori_input_ids, ori_segment_ids, ori_input_mask, \
            aug_input_ids, aug_segment_ids, aug_input_mask  = unsup_batch

            input_ids = torch.cat((input_ids, aug_input_ids), dim=0)
            segment_ids = torch.cat((segment_ids, aug_segment_ids), dim=0)
            input_mask = torch.cat((input_mask, aug_input_mask), dim=0)
            
        # logits
        logits = model(input_ids, segment_ids, input_mask)

        # sup loss
        sup_size = label_ids.shape[0]            
        sup_loss = sup_criterion(logits[:sup_size], label_ids)  # shape : train_batch_size
        if cfg.tsa:
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh   # prob = exp(log_prob), prob > tsa_threshold
            # larger_than_threshold = torch.sum(  F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids]  , dim=-1) > tsa_threshold
            loss_mask = torch.ones_like(label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        # unsup loss
        if unsup_batch:
            # ori
            with torch.no_grad():
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
                ori_prob   = F.softmax(ori_logits, dim=-1)    # KLdiv target
                # ori_log_prob = F.log_softmax(ori_logits, dim=-1)

                # confidence-based masking
                if cfg.uda_confidence_thresh != -1:
                    unsup_loss_mask = torch.max(ori_prob, dim=-1)[0] > cfg.uda_confidence_thresh
                    unsup_loss_mask = unsup_loss_mask.type(torch.float32)
                else:
                    unsup_loss_mask = torch.ones(len(logits) - sup_size, dtype=torch.float32)
                unsup_loss_mask = unsup_loss_mask.to(_get_device())
                    
            # aug
            # softmax temperature controlling
            uda_softmax_temp = cfg.uda_softmax_temp if cfg.uda_softmax_temp > 0 else 1.
            aug_log_prob = F.log_softmax(logits[sup_size:] / uda_softmax_temp, dim=-1)

            # KLdiv loss
            """
                nn.KLDivLoss (kl_div)
                input : log_prob (log_softmax)
                target : prob    (softmax)
                https://pytorch.org/docs/stable/nn.html

                unsup_loss is divided by the number of unmasked elements
                (the sum of unsup_loss_mask); this differs from Google's official
                UDA, where unsup_loss is divided by the total count:
                https://github.com/google-research/uda/blob/master/text/uda.py#L175
            """
            unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)
            unsup_loss = torch.sum(unsup_loss * unsup_loss_mask, dim=-1) / torch.max(torch.sum(unsup_loss_mask, dim=-1), torch_device_one())
            final_loss = sup_loss + cfg.uda_coeff*unsup_loss

            return final_loss, sup_loss, unsup_loss
        return sup_loss, None, None

    # evaluation
    def get_acc(model, batch):
        # input_ids, segment_ids, input_mask, label_id, sentence = batch
        input_ids, segment_ids, input_mask, label_id = batch
        logits = model(input_ids, segment_ids, input_mask)
        _, label_pred = logits.max(1)

        result = (label_pred == label_id).float()
        accuracy = result.mean()
        # output_dump.logs(sentence, label_pred, label_id)    # output dump

        return accuracy, result

    if cfg.mode == 'train':
        trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'train_eval':
        trainer.train(get_loss, get_acc, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'eval':
        results = trainer.eval(get_acc, cfg.model_file, None)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
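
All three examples rely on get_tsa_thresh for Training Signal Annealing without
showing it. A sketch matching the UDA paper's linear/log/exp schedules, annealing
from `start` to `end` over training (the repository's helper may differ slightly):

import torch

def get_tsa_thresh(schedule, global_step, num_train_steps, start, end):
    progress = torch.tensor(float(global_step) / float(num_train_steps))
    if schedule == 'linear_schedule':
        threshold = progress
    elif schedule == 'exp_schedule':
        threshold = torch.exp((progress - 1) * 5)      # ramps from e^-5 up to 1
    elif schedule == 'log_schedule':
        threshold = 1 - torch.exp(-progress * 5)       # ramps from 0 to 1 - e^-5
    else:
        threshold = torch.tensor(1.0)
    return threshold * (end - start) + start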