Example #1
    def init_optimizer(self, num_total_steps):

        num_warmup_steps = int(self.args.warmup_proportion * num_total_steps)
        logger.info('warmup steps : %d' % num_warmup_steps)

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        param_optimizer = list(self.network.named_parameters())
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': self.args.weight_decay,
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]

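        # AdamW with correct_bias=False reproduces the original BERT optimizer
        # (no Adam bias correction); WarmupLinearSchedule ramps the learning rate
        # up linearly for num_warmup_steps, then decays it linearly to zero.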
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.args.learning_rate,
                               correct_bias=False)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=num_warmup_steps,
                                              t_total=num_total_steps)
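
For reference, a minimal sketch of how init_optimizer is typically wired up before training; train_loader, args.max_train_epochs, and the surrounding script are assumptions for illustration, not part of the example above. num_total_steps should be the number of optimizer updates, i.e. the number of batches divided by gradient_accumulation_steps, times the number of epochs.

    # hypothetical setup; `model` owns the init_optimizer method shown above,
    # and `train_loader` / `args` are placeholders
    num_total_steps = (len(train_loader) // args.gradient_accumulation_steps
                       * args.max_train_epochs)
    model.init_optimizer(num_total_steps)
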
Example #2
class KeyphraseSpanExtraction(object):
    def __init__(self, args, state_dict=None):

        self.args = args
        self.updates = 0

        # select model
        network = networks.get_class(args)

        # select config
        args.num_labels = 2 if args.model_class != 'bert2tag' else len(Idx2Tag)
        logger.info('Config num_labels = %d' % args.num_labels)
        model_config = config_class[args.pretrain_model_type].from_pretrained(
            args.cache_dir, num_labels=args.num_labels)

        # load pretrained model
        self.network = network.from_pretrained(args.cache_dir,
                                               config=model_config)

        # load checkpoint
        if state_dict is not None:
            self.network.load_state_dict(state_dict)
            logger.info('loaded checkpoint state_dict')

    # -------------------------------------------------------------------------------------------
    # -------------------------------------------------------------------------------------------
    def init_optimizer(self, num_total_steps):

        num_warmup_steps = int(self.args.warmup_proportion * num_total_steps)
        logger.info('warmup steps : %d' % num_warmup_steps)

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        param_optimizer = list(self.network.named_parameters())
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': self.args.weight_decay,
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.args.learning_rate,
                               correct_bias=False)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=num_warmup_steps,
                                              t_total=num_total_steps)

    # -------------------------------------------------------------------------------------------
    # -------------------------------------------------------------------------------------------
    # train
    def update(self, step, inputs):
        # Train mode
        self.network.train()

        # forward pass: in train mode the network returns the loss
        loss = self.network(**inputs)

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel (not distributed) training
            loss = loss.mean()

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                           self.args.max_grad_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                           self.args.max_grad_norm)

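        # only step the optimizer / scheduler once every
        # gradient_accumulation_steps micro-batches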
        if (step + 1) % self.args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.scheduler.step()

            self.optimizer.zero_grad()
            self.updates += 1
        return loss.item()

    # -------------------------------------------------------------------------------------------
    # -------------------------------------------------------------------------------------------
    # test

    # bert2span
    def test_bert2span(self, inputs, lengths):
        self.network.eval()
        with torch.no_grad():
            s_logits, e_logits = self.network(**inputs)

        assert s_logits.size(0) == e_logits.size(0) == sum(lengths)
        s_logits = s_logits.data.cpu().tolist()
        e_logits = e_logits.data.cpu()

        start_lists, end_lists = [], []
        sum_len = 0
        for l in lengths:
            start_lists.append(s_logits[sum_len:(sum_len + l)])
            end_lists.append(e_logits[sum_len:(sum_len + l), :l].tolist())
            sum_len += l
        return start_lists, end_lists

    # bert2tag
    def test_bert2tag(self, inputs, lengths):
        self.network.eval()
        with torch.no_grad():
            logits = self.network(**inputs)
            logits = F.softmax(logits, dim=-1)
        logits = logits.data.cpu().tolist()
        assert len(logits) == sum(lengths)

        logit_lists = []
        sum_len = 0
        for l in lengths:
            logit_lists.append(logits[sum_len:sum_len + l])
            sum_len += l
        return logit_lists

    # bert2chunk
    def test_bert2chunk(self, inputs, lengths, max_phrase_words):
        self.network.eval()
        with torch.no_grad():
            logits = self.network(**inputs)
            logits = F.softmax(logits, dim=-1)
        logits = logits.data.cpu()[:, 1]
        logits = logits.tolist()

        logit_lists = []
        sum_len = 0
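        # logits is a flat list with one score per candidate n-gram: each document
        # of l tokens contributes (l - n) candidates of length n + 1 for every
        # n in range(max_phrase_words)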
        for l in lengths:
            batch_logit = []
            for n in range(max_phrase_words):
                batch_logit.append(logits[sum_len:sum_len + l - n])
                sum_len += (l - n)
            logit_lists.append(batch_logit)
        return logit_lists

    # bert2rank & bert2joint
    def test_bert2rank(self, inputs, numbers):
        self.network.eval()
        with torch.no_grad():
            # shape = (batch_size, max_diff_gram_num)
            logits = self.network(**inputs)

        assert logits.shape[0] == len(numbers) and logits.shape[1] == max(numbers)
        logits = logits.data.cpu().tolist()

        logit_lists = []
        for batch_id, num in enumerate(numbers):
            logit_lists.append(logits[batch_id][:num])
        return logit_lists

    # -------------------------------------------------------------------------------------------
    # -------------------------------------------------------------------------------------------
    def save_checkpoint(self, filename, epoch):
        network = (self.network.module
                   if hasattr(self.network, 'module') else self.network)
        params = {
            'args': self.args,
            'epoch': epoch,
            'state_dict': network.state_dict(),
        }
        try:
            torch.save(params, filename)
            logger.info('Successfully saved epoch_%d checkpoint.' % epoch)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load_checkpoint(filename, new_args=None):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(filename,
                                  map_location=lambda storage, loc: storage)

        args = saved_params['args']
        epoch = saved_params['epoch']
        state_dict = saved_params['state_dict']
        if new_args:
            args = override_args(args, new_args)

        model = KeyphraseSpanExtraction(args, state_dict)
        logger.info('Successfully loaded epoch_%d checkpoint from %s' %
                    (epoch, filename))
        return model, epoch

    # -------------------------------------------------------------------------------------------
    # -------------------------------------------------------------------------------------------
    def zero_grad(self):
        self.optimizer.zero_grad()
        # self.network.zero_grad()

    def set_device(self):
        self.network.to(self.args.device)

    def parallelize(self):
        """Use data parallel to copy the model across several gpus.
        This will take all gpus visible with CUDA_VISIBLE_DEVICES.
        """
        self.parallel = True
        self.network = torch.nn.DataParallel(self.network)

    def distribute(self):
        self.distributed = True
        self.network = torch.nn.parallel.DistributedDataParallel(
            self.network,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True)
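
A rough end-to-end sketch of how the class above might be driven from a training script. Every name not defined in Example #2 (args, train_loader, build_inputs, max_train_epochs, the checkpoint filename) is a placeholder for this sketch, not part of the source.

    # hypothetical driver; all names not defined in Example #2 are placeholders
    model = KeyphraseSpanExtraction(args)
    model.set_device()

    num_total_steps = (len(train_loader) // args.gradient_accumulation_steps
                       * args.max_train_epochs)
    model.init_optimizer(num_total_steps)

    if args.n_gpu > 1:
        model.parallelize()

    for epoch in range(1, args.max_train_epochs + 1):
        model.zero_grad()
        for step, batch in enumerate(train_loader):
            inputs = build_inputs(batch, args.device)  # placeholder: batch -> dict of tensors
            loss = model.update(step, inputs)
        model.save_checkpoint('epoch_%d.checkpoint' % epoch, epoch)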