Beispiel #1
0
        train_df = pd.read_csv(os.path.join(config.processed_train,
                                            file_name + '.csv'),
                               index_col=0)
        step_grad = config.gradient_accumulation_step_dict[file_name]
        epoch_total += (len(train_df) // batch_size + 1) // step_grad + 1
    print('epoch_total', epoch_total)
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=config.warmup_proportion,
                         t_total=epoch_total * config.epochs)
    # 下面这行根据实际需求是否注释
    model, optimizer = train_valid(model, criterion, optimizer)
    loss_weights = EpochLossWeight()
    loss_weights_dict = loss_weights.run()
    # 进行最后一次模型数据增强与结果测试
    # if os.path.exists(config.checkpoint_file):
    #     print('正在加载最后一个模型', config.checkpoint_file)
    #     checkpoint = torch.load(config.checkpoint_file)
    #     model.load_state_dict(checkpoint['model_state'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state'])
    train_enhance(model, criterion, optimizer, loss_weights_dict)
    test_model(model, 'last_predict', 'last_predict_submit.zip')
    # 进行最佳模型的的数据增强与结果测试
    if os.path.exists(config.best_checkpoint_file):
        print('正在加载最佳模型', config.best_checkpoint_file)
        checkpoint = torch.load(config.best_checkpoint_file)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
    train_enhance(model, criterion, optimizer, loss_weights_dict)
    test_model(model, 'best_predict', 'best_predict_submit.zip')
Beispiel #2
0
  param_optimizer = list(model.named_parameters())
  param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
  no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]
  optimizer = BertAdam(optimizer_grouped_parameters, lr=lr)

  if os.path.isfile(model_path):
    print('=' * 80)
    print('load training param, ', model_path)
    print('=' * 80)
    state = torch.load(model_path)
    model.load_state_dict(state['best_model_state'])
    optimizer.load_state_dict(state['best_opt_state'])
    epoch_list = range(state['best_epoch'] + 1, state['best_epoch'] + 1 + EPOCH)
    global_step = state['best_step']

    state = {}  ## FIXME:  临时的解决方案
  else:
    state = None
    epoch_list = range(EPOCH)
    global_step = 0

  grade = 0

  if on_windows:
    print_every = 50
    val_every = [50, 70, 50, 35]
  else:
Beispiel #3
0
class QAModel(object):
    """
    High level model that handles initializing the underlying network
    architecture, saving, updating examples, and predicting examples.

    Training is split into two calls so gradients can be accumulated over
    several batches: ``update`` runs forward/backward on one batch and
    ``take_step`` clips the gradients, steps the optimizer(s) and clears them.
    """

    def __init__(self, opt, embedding=None, state_dict=None):
        """
        Build the FlowQA network and its optimizer(s).

        :param opt: configuration dict (e.g. 'optimizer', 'learning_rate',
            'momentum', 'weight_decay', 'cuda', 'finetune_bert', 'bert_lr',
            'bert_warmup', 'bert_t_total').
        :param embedding: passed unchanged to ``FlowQA``.
        :param state_dict: optional checkpoint as written by ``save``, with
            'network', 'optimizer', 'updates' and optionally 'bertadam'.
        :raises RuntimeError: if opt['optimizer'] is not sgd/adamax/adadelta.
        """
        # Book-keeping.
        self.opt = opt
        self.updates = state_dict['updates'] if state_dict else 0
        self.eval_embed_transfer = True
        self.train_loss = AverageMeter()

        # Building network.
        self.network = FlowQA(opt, embedding)
        if state_dict:
            # Drop checkpoint entries the current network does not define so
            # that load_state_dict accepts the checkpoint.
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['network'].keys()):
                if k not in new_state:
                    del state_dict['network'][k]
            self.network.load_state_dict(state_dict['network'])

        parameters = [p for p in self.network.parameters() if p.requires_grad]
        self.total_param = sum(p.nelement() for p in parameters)

        # Building optimizer.
        if opt['finetune_bert'] != 0:
            bert_params = [
                p for p in self.network.bert.parameters() if p.requires_grad
            ]
            self.bertadam = BertAdam(bert_params,
                                     lr=opt['bert_lr'],
                                     warmup=opt['bert_warmup'],
                                     t_total=opt['bert_t_total'])
            # BERT weights are stepped by BertAdam only; the main optimizer
            # gets everything else. Match by object identity in O(n).
            bert_ids = {id(p) for p in bert_params}
            parameters = [p for p in parameters if id(p) not in bert_ids]

        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(parameters,
                                       opt['learning_rate'],
                                       momentum=opt['momentum'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = optim.Adamax(parameters,
                                          weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(parameters,
                                            rho=0.95,
                                            weight_decay=opt['weight_decay'])
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])
        if state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
            if opt['cuda']:
                self._optimizer_state_to_cuda(self.optimizer)

            if opt['finetune_bert'] != 0 and 'bertadam' in state_dict:
                self.bertadam.load_state_dict(state_dict['bertadam'])
                if opt['cuda']:
                    self._optimizer_state_to_cuda(self.bertadam)

    @staticmethod
    def _optimizer_state_to_cuda(optimizer):
        """Move every tensor in ``optimizer``'s per-parameter state to the GPU."""
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    def update(self, batch):
        """
        Run forward and backward on one training batch.

        Gradients are accumulated in the parameters; ``take_step`` applies
        them. Note that the answer/rationale target tensors are masked in
        place (on the CPU path these are the caller's batch tensors).

        :param batch: indexable batch; batch[:9] are the network inputs,
            batch[9] is the per-dialog question mask, batch[10:15] are the
            answer start/end/type and rationale start/end targets, and, when
            opt['use_bert'] is set, batch[19:23] hold the BERT indices/spans.
        :return: the loss after division by opt['aggregate_grad_steps'].
        """
        # Train mode
        self.network.train()
        torch.set_grad_enabled(True)

        if self.opt['use_bert']:
            context_bertidx = batch[19]
            context_bert_spans = batch[20]
            question_bertidx = batch[21]
            question_bert_spans = batch[22]

        # Transfer to GPU
        if self.opt['cuda']:
            inputs = [e.cuda(non_blocking=True) for e in batch[:9]]
            (overall_mask, answer_s, answer_e, answer_c, rationale_s,
             rationale_e) = [e.cuda(non_blocking=True) for e in batch[9:15]]

            if self.opt['use_bert']:
                context_bertidx = [
                    x.cuda(non_blocking=True) for x in context_bertidx
                ]
        else:
            inputs = list(batch[:9])
            (overall_mask, answer_s, answer_e, answer_c, rationale_s,
             rationale_e) = batch[9:15]

        # Run forward
        # output: [batch_size, question_num, context_len], [batch_size, question_num]
        if self.opt['use_bert']:
            score_s, score_e, score_c = self.network(*inputs, context_bertidx,
                                                     context_bert_spans,
                                                     question_bertidx,
                                                     question_bert_spans)
        else:
            score_s, score_e, score_c = self.network(*inputs)

        # Compute loss and accuracies
        if self.opt['use_elmo']:
            scalars = self.network.elmo.scalar_mix_0.scalar_parameters
            loss = self.opt['elmo_lambda'] * (
                scalars[0]**2 + scalars[1]**2 + scalars[2]**2
            )  # ELMo L2 regularization
        else:
            loss = 0

        # Only answers of type 3 (span) supervise start/end; all others get
        # -100, the default ignore_index of F.cross_entropy.
        all_no_span = (answer_c != 3)
        answer_s.masked_fill_(all_no_span, -100)
        answer_e.masked_fill_(all_no_span, -100)
        # Rationale targets are masked the same way for the (currently
        # disabled) rationale loss.
        rationale_s.masked_fill_(all_no_span, -100)
        rationale_e.masked_fill_(all_no_span, -100)

        batch_size = overall_mask.size(0)
        for i in range(batch_size):
            # The true number of questions in this sampled dialog.
            q_num = int(overall_mask[i].sum())
            num_span = q_num - int(all_no_span[i, :q_num].sum())

            # Answer-type loss, weighted by the number of questions.
            single_loss = (F.cross_entropy(score_c[i, :q_num],
                                           answer_c[i, :q_num]) *
                           q_num / 15.0)
            # Start/end losses, weighted by the number of span answers. They
            # are skipped when there are none: with every target equal to
            # ignore_index the mean cross-entropy is 0/0 = NaN, and NaN * 0 is
            # still NaN, which would poison the loss of the whole batch.
            if num_span > 0:
                single_loss = single_loss + (
                    F.cross_entropy(score_s[i, :q_num], answer_s[i, :q_num]) *
                    num_span / 12.0 +
                    F.cross_entropy(score_e[i, :q_num], answer_e[i, :q_num]) *
                    num_span / 12.0)

            loss = loss + (single_loss / batch_size)
        self.train_loss.update(loss.item(), batch_size)

        # Scale for gradient accumulation; clipping, the optimizer step and
        # zeroing the gradients happen in take_step().
        loss = loss / self.opt['aggregate_grad_steps']
        loss.backward()

        return loss

    def take_step(self):
        """Clip the accumulated gradients, apply them, and reset them."""
        # Clip Gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.opt['grad_clipping'])

        # Update parameters
        self.optimizer.step()
        if self.opt['finetune_bert']:
            self.bertadam.step()
        self.updates += 1

        # Reset any partially fixed parameters (e.g. rare words)
        self.reset_embeddings()
        self.eval_embed_transfer = True

        # Clear gradients for the next accumulation window.
        self.optimizer.zero_grad()
        if self.opt['finetune_bert']:
            self.bertadam.zero_grad()

    def predict(self, batch):
        """
        Predict an answer string for every real question in the batch.

        :param batch: same layout as for ``update``; additionally batch[15]
            holds the context texts and batch[16] the token character spans.
        :return: flat list of strings: "unknown", "Yes", "No" or a context
            substring, in dialog/question order.
        """
        # Eval mode
        self.network.eval()
        torch.set_grad_enabled(False)

        # Transfer trained embedding to evaluation embedding
        if self.eval_embed_transfer:
            self.update_eval_embed()
            self.eval_embed_transfer = False

        if self.opt['use_bert']:
            context_bertidx = batch[19]
            context_bert_spans = batch[20]
            question_bertidx = batch[21]
            question_bert_spans = batch[22]

        # Transfer to GPU
        if self.opt['cuda']:
            inputs = [e.cuda(non_blocking=True) for e in batch[:9]]
            if self.opt['use_bert']:
                context_bertidx = [
                    x.cuda(non_blocking=True) for x in context_bertidx
                ]
        else:
            inputs = [e for e in batch[:9]]

        # Run forward
        # output: [batch_size, question_num, context_len], [batch_size, question_num]
        if self.opt['use_bert']:
            score_s, score_e, score_c = self.network(*inputs, context_bertidx,
                                                     context_bert_spans,
                                                     question_bertidx,
                                                     question_bert_spans)
        else:
            score_s, score_e, score_c = self.network(*inputs)
        score_s = F.softmax(score_s, dim=2)
        score_e = F.softmax(score_e, dim=2)

        # Transfer to CPU/normal tensors for numpy ops
        score_s = score_s.data.cpu()
        score_e = score_e.data.cpu()
        score_c = score_c.data.cpu()

        # Get argmax text spans
        text = batch[15]
        spans = batch[16]
        overall_mask = batch[9]

        predictions = []
        max_len = self.opt['max_len'] or score_s.size(2)

        for i in range(overall_mask.size(0)):
            for j in range(overall_mask.size(1)):
                if overall_mask[i, j] == 0:  # this dialog has ended
                    break

                ans_type = np.argmax(score_c[i, j])

                if ans_type == 0:
                    predictions.append("unknown")
                elif ans_type == 1:
                    predictions.append("Yes")
                elif ans_type == 2:
                    predictions.append("No")
                else:
                    # Joint start/end score; keep only end >= start and
                    # span length < max_len.
                    scores = torch.ger(score_s[i, j], score_e[i, j])
                    scores.triu_().tril_(max_len - 1)
                    scores = scores.numpy()
                    s_idx, e_idx = np.unravel_index(np.argmax(scores),
                                                    scores.shape)

                    s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1]
                    predictions.append(text[i][s_offset:e_offset])

        return predictions  # list of (list of strings)

    # allow the evaluation embedding be larger than training embedding
    # this is helpful if we have pretrained word embeddings
    def setup_eval_embed(self, eval_embed, padding_idx=0):
        """Install a frozen evaluation embedding built from ``eval_embed``."""
        # eval_embed should be a supermatrix of training embedding
        self.network.eval_embed = nn.Embedding(eval_embed.size(0),
                                               eval_embed.size(1),
                                               padding_idx=padding_idx)
        self.network.eval_embed.weight.data = eval_embed
        for p in self.network.eval_embed.parameters():
            p.requires_grad = False
        self.eval_embed_transfer = True

        if hasattr(self.network, 'CoVe'):
            self.network.CoVe.setup_eval_embed(eval_embed)

    def update_eval_embed(self):
        """
        Copy the leading rows of the training embedding into eval_embed.

        Copies the first opt['tune_partial'] rows, or the first 10 rows when
        tune_partial is not positive.
        """
        # NOTE(review): the fixed 10 presumably covers special tokens — confirm.
        tune_partial = self.opt['tune_partial']
        offset = tune_partial if tune_partial > 0 else 10
        self.network.eval_embed.weight.data[0:offset] \
            = self.network.embedding.weight.data[0:offset]

    def reset_embeddings(self):
        """Reset the rows after opt['tune_partial'] to their fixed values."""
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            if offset < self.network.embedding.weight.data.size(0):
                self.network.embedding.weight.data[offset:] \
                    = self.network.fixed_embedding

    def get_pretrain(self, state_dict):
        """
        Copy matching entries of ``state_dict`` into the network in place.

        Unknown names are ignored. Entries that cannot be copied (e.g. shape
        mismatch) are reported with "Skip <name>" and skipped.
        """
        own_state = self.network.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            except (RuntimeError, TypeError):
                print("Skip", name)
                continue

    @staticmethod
    def _try_save(params, filename):
        """Best-effort ``torch.save``: log and continue if saving fails."""
        try:
            torch.save(params, filename)
            logger.info('model saved to {}'.format(filename))
        except Exception:
            logger.warning('[ WARN: Saving failed... continuing anyway. ]')

    def save(self, filename, epoch):
        """Save network, optimizer(s), update count, config and epoch."""
        params = {
            'state_dict': {
                'network': self.network.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'updates': self.updates  # how many updates
            },
            'config': self.opt,
            'epoch': epoch
        }
        if self.opt['finetune_bert']:
            params['state_dict']['bertadam'] = self.bertadam.state_dict()
        self._try_save(params, filename)

    def save_for_predict(self, filename, epoch):
        """
        Save only what inference needs: network weights without CoVe,
        eval_embed and fixed_embedding, plus the config. ``epoch`` is unused.
        """
        network_state = {
            k: v
            for k, v in self.network.state_dict().items()
            if not k.startswith('CoVe')
        }
        network_state.pop('eval_embed.weight', None)
        network_state.pop('fixed_embedding', None)
        params = {
            'state_dict': {
                'network': network_state
            },
            'config': self.opt,
        }
        self._try_save(params, filename)

    def cuda(self):
        """Move the network to the GPU."""
        self.network.cuda()
class BertTrainer:
    def __init__(self, hypers: Hypers, model_name, checkpoint,
                 **extra_model_args):
        """
        Build the model and its optimizer, with common logic for per-group
        weight decay, gradient accumulation and loss tracking.

        :param hypers: the core hyperparameters for the bert model
        :param model_name: the fully qualified name of the bert model we will train
            like pytorch_pretrained_bert.modeling.BertForQuestionAnswering
        :param checkpoint: if resuming training, either a model state_dict or
            a dict written by ``save`` that holds the optimizer state as
            checkpoint['optimizer']
        :param extra_model_args: forwarded to the model's ``from_pretrained``
        """

        self.init_time = time.time()

        self.model = self.get_model(hypers, model_name, checkpoint,
                                    **extra_model_args)

        self.step = 0
        self.hypers = hypers
        self.train_stats = TrainStats(hypers)

        self.model.train()
        logger.info('configured model for training')

        # Prepare optimizer
        if getattr(hypers, 'exclude_pooler', False):
            # module.bert.pooler.dense.weight, module.bert.pooler.dense.bias
            # see https://github.com/NVIDIA/apex/issues/131
            self.param_optimizer = [
                (n, p) for (n, p) in self.model.named_parameters()
                if '.pooler.' not in n
            ]
        else:
            self.param_optimizer = list(self.model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        decay_params = [
            p for n, p in self.param_optimizer
            if not any(nd in n for nd in no_decay)
        ]
        no_decay_params = [
            p for n, p in self.param_optimizer
            if any(nd in n for nd in no_decay)
        ]
        # Old BertAdam versions read 'weight_decay_rate'; newer BertAdam and
        # apex FusedAdam read 'weight_decay'. Set both so the per-group
        # decay is honored whichever optimizer consumes these groups.
        optimizer_grouped_parameters = [
            {'params': decay_params,
             'weight_decay': 0.01, 'weight_decay_rate': 0.01},
            {'params': no_decay_params,
             'weight_decay': 0.0, 'weight_decay_rate': 0.0},
        ]
        self.t_total = hypers.num_train_steps
        self.global_step = hypers.global_step

        if hypers.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                ) from err

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=hypers.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if hypers.loss_scale == 0:
                loss_scale_args = {'dynamic_loss_scale': True}
            else:
                loss_scale_args = {'static_loss_scale': hypers.loss_scale}
            self.optimizer = FP16_Optimizer(optimizer,
                                            verbose=(hypers.global_rank == 0),
                                            **loss_scale_args)
        else:
            self.optimizer = BertAdam(optimizer_grouped_parameters,
                                      lr=hypers.learning_rate,
                                      warmup=hypers.warmup_proportion,
                                      t_total=self.t_total)
        logger.info('created optimizer')

        # A plain dict is a multi-part checkpoint written by save() (built
        # from **kwargs); save_simple() writes a bare model state_dict.
        if checkpoint and type(
                checkpoint) is dict and 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if not hypers.fp16:
                # if we load this state, we need to set the t_total to what we passed, not what was saved
                self.optimizer.set_t_total(self.t_total)
                # BertAdam applies its schedule to the group's base lr, so
                # restore the configured base lr (checkpoints may carry an
                # already-scheduled value).
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = hypers.learning_rate
                # show state of optimizer
                lrs = self.optimizer.get_lr()
                logger.info('Min and max learn rate:    %s',
                            str([min(lrs), max(lrs)]))
                logger.info('Min and max step in state: %s',
                            str(self.optimizer.get_steps()))
            if 'seen_instances' in checkpoint:
                self.global_step = int(checkpoint['seen_instances'] /
                                       self._instances_per_step())
                self.train_stats.previous_instances = checkpoint[
                    'seen_instances']
                logger.info('got global step from checkpoint = %i',
                            self.global_step)

            logger.info('Loaded optimizer state:')
            logger.info(repr(self.optimizer))

    def _instances_per_step(self):
        """Training instances consumed per optimizer step, over all workers."""
        return (self.hypers.train_batch_size *
                self.hypers.gradient_accumulation_steps *
                self.hypers.world_size)

    def reset(self):
        """
        reset any gradient accumulation
        :return:
        """
        self.model.zero_grad()
        self.step = 0

    def should_continue(self):
        """
        :return: True if training should continue (neither the target step
            count nor the optional time limit has been reached)
        """
        if self.global_step >= self.t_total:
            logger.info(
                'stopping due to train step %i >= target train steps %i',
                self.global_step, self.t_total)
            return False
        if 0 < self.hypers.time_limit <= (time.time() - self.init_time):
            logger.info('stopping due to time out %i seconds',
                        self.hypers.time_limit)
            return False
        return True

    def save_simple(self, filename):
        """Save only the model's state_dict (global rank 0 only)."""
        if self.hypers.global_rank != 0:
            logger.info('skipping save in %i', torch.distributed.get_rank())
            return
        # Unwrap DataParallel/DDP so only the model itself is saved.
        model_to_save = getattr(self.model, 'module', self.model)
        torch.save(model_to_save.state_dict(), filename)
        logger.info(f'saved model only to {filename}')

    def save(self, filename, **extra_checkpoint_info):
        """
        save a checkpoint with the model parameters, the optimizer state and any additional checkpoint info
        :param filename: destination path; missing parent directories are created
        :param extra_checkpoint_info: extra entries stored in the checkpoint dict
        :return:
        """
        # only local_rank 0, in fact only global rank 0
        if self.hypers.global_rank != 0:
            logger.info('skipping save in %i', torch.distributed.get_rank())
            return
        start_time = time.time()
        checkpoint = extra_checkpoint_info
        # Unwrap DataParallel/DDP so only the model itself is saved.
        model_to_save = getattr(self.model, 'module', self.model)
        # os.makedirs('') raises, so only create a directory if there is one.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # also save the optimizer state, since we will likely resume from partial pre-training
        checkpoint['state_dict'] = model_to_save.state_dict()
        checkpoint['optimizer'] = self.optimizer.state_dict()
        # include world size in instances_per_step calculation
        instances_per_step = self._instances_per_step()
        checkpoint['seen_instances'] = self.global_step * instances_per_step
        checkpoint['num_instances'] = self.t_total * instances_per_step
        # CONSIDER: also save hypers?
        torch.save(checkpoint, filename)
        logger.info(
            f'saved model to {filename} in {time.time()-start_time} seconds')

    def get_instance_count(self):
        """:return: number of training instances seen so far"""
        return self.global_step * self._instances_per_step()

    def step_loss(self, loss):
        """
        accumulates the gradient, tracks the loss and applies the gradient to the model
        :param loss: the loss from evaluating the model
        """
        if self.global_step == 0:
            logger.info('first step_loss')
        if self.hypers.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        self.train_stats.note_loss(loss.item())

        if self.hypers.gradient_accumulation_steps > 1:
            loss = loss / self.hypers.gradient_accumulation_steps

        if self.hypers.fp16:
            self.optimizer.backward(loss)
        else:
            loss.backward()

        if (self.step + 1) % self.hypers.gradient_accumulation_steps == 0:
            if self.hypers.fp16:
                # FusedAdam has no built-in schedule, so apply BERT's warmup
                # by hand. BertAdam (non-fp16) already schedules from
                # warmup/t_total; overriding its lr too would apply the
                # schedule twice.
                lr_this_step = self.hypers.learning_rate * warmup_linear(
                    self.global_step / self.t_total,
                    self.hypers.warmup_proportion)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            self.optimizer.step()
            self.model.zero_grad()
            self.global_step += 1

        self.step += 1

    @classmethod
    def get_files(cls, train_file, completed_files):
        """
        List the training files still to process.

        :param train_file: a single file, or a directory searched recursively
        :param completed_files: files already trained on in this epoch
        :return: (remaining train files, completed files); when every file is
            already completed a new epoch starts and completed files is []
        """
        logger.info('completed files = %s, count = %i',
                    str(completed_files[:5]), len(completed_files))
        # multiple train files
        if not os.path.isdir(train_file):
            train_files = [train_file]
        else:
            if not train_file.endswith('/'):
                train_file = train_file + '/'
            train_files = glob.glob(train_file + '**', recursive=True)
            train_files = [f for f in train_files if not os.path.isdir(f)]

        # exclude completed files
        completed = set(completed_files)
        if set(train_files) != completed:
            train_files = [f for f in train_files if f not in completed]
        else:
            completed_files = []  # new epoch
        logger.info('train files = %s, count = %i', str(train_files[:5]),
                    len(train_files))

        return train_files, completed_files

    @classmethod
    def get_model(cls, hypers, model_name, checkpoint, **extra_model_args):
        """
        Instantiate ``model_name`` via ``from_pretrained``, optionally with
        weights from ``checkpoint``, and place it on the configured
        device(s), wrapping it for distributed or multi-GPU training.
        """
        override_state_dict = None
        if checkpoint:
            # A plain dict is a multi-part checkpoint from save(); anything
            # else is treated as the model parameters themselves.
            if type(checkpoint) is dict and 'state_dict' in checkpoint:
                logger.info('loading from multi-part checkpoint')
                override_state_dict = checkpoint['state_dict']
            else:
                logger.info('loading from saved model parameters')
                override_state_dict = checkpoint

        # create the model object by name
        # https://stackoverflow.com/questions/4821104/python-dynamic-instantiation-from-string-name-of-a-class-in-dynamically-imported
        import importlib
        clsdot = model_name.rfind('.')
        class_ = getattr(importlib.import_module(model_name[0:clsdot]),
                         model_name[clsdot + 1:])

        model_args = {
            'state_dict': override_state_dict,
            'cache_dir': PYTORCH_PRETRAINED_BERT_CACHE
        }
        model_args.update(extra_model_args)
        model = class_.from_pretrained(hypers.bert_model, **model_args)

        logger.info('built model')

        # configure model for fp16, multi-gpu and/or distributed training
        if hypers.fp16:
            model.half()
            logger.info('model halved')
        logger.info('sending model to %s', str(hypers.device))
        model.to(hypers.device)
        logger.info('sent model to %s', str(hypers.device))

        if hypers.local_rank != -1:
            if not hypers.no_apex:
                try:
                    from apex.parallel import DistributedDataParallel as DDP
                except ImportError as err:
                    raise ImportError("Please install apex") from err
                model = DDP(model)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[hypers.local_rank],
                    output_device=hypers.local_rank)
            logger.info('using DistributedDataParallel for world size %i',
                        hypers.world_size)
        elif hypers.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        return model

    @classmethod
    def get_base_parser(cls):
        """:return: an ArgumentParser with the common BERT training options"""
        parser = argparse.ArgumentParser()

        # Required parameters
        parser.add_argument(
            "--bert_model",
            default=None,
            type=str,
            required=True,
            help=
            "Bert pre-trained model selected in the list: bert-base-uncased, "
            "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
        )

        # Other parameters
        parser.add_argument(
            "--num_instances",
            default=-1,
            type=int,
            help="Total number of training instances to train over.")
        parser.add_argument(
            "--seen_instances",
            default=-1,
            type=int,
            help=
            "When resuming training, the number of instances we have already trained over."
        )
        parser.add_argument("--train_batch_size",
                            default=32,
                            type=int,
                            help="Total batch size for training.")
        parser.add_argument("--learning_rate",
                            default=5e-5,
                            type=float,
                            help="The initial learning rate for Adam.")
        parser.add_argument(
            "--warmup_proportion",
            default=0.1,
            type=float,
            help=
            "Proportion of training to perform linear learning rate warmup for. "
            "E.g., 0.1 = 10% of training.")
        parser.add_argument("--no_cuda",
                            default=False,
                            action='store_true',
                            help="Whether not to use CUDA when available")
        parser.add_argument("--no_apex",
                            default=False,
                            action='store_true',
                            help="Whether not to use apex when available")
        parser.add_argument('--seed',
                            type=int,
                            default=42,
                            help="random seed for initialization")
        parser.add_argument(
            '--gradient_accumulation_steps',
            type=int,
            default=1,
            help=
            "Number of updates steps to accumulate before performing a backward/update pass."
        )
        parser.add_argument(
            '--optimize_on_cpu',
            default=False,
            action='store_true',
            help=
            "Whether to perform optimization and keep the optimizer averages on CPU"
        )
        parser.add_argument(
            '--fp16',
            default=False,
            action='store_true',
            help="Whether to use 16-bit float precision instead of 32-bit")
        parser.add_argument(
            '--loss_scale',
            type=float,
            default=0,
            help=
            'Loss scaling, positive power of 2 values can improve fp16 convergence. '
            'Leave at zero to use dynamic loss scaling')
        return parser
def main():
    """Fine-tune the BERT-based image-to-text model for report generation or VQA.

    Steps, in order:
      1. Parse command-line arguments and derive ``max_seq_length``,
         ``src_file`` and ``file_valid_jpgs`` from ``--tasks`` /
         ``--generation_dataset``.
      2. Initialise wandb, write ``opt.json`` and configure file logging in
         ``--output_dir``.
      3. Select the device (joining an NCCL process group when
         ``--local_rank`` is set) and seed ``random``, NumPy and PyTorch.
      4. Build the s2s + bi preprocessing pipeline, the dataset and the
         data loader.
      5. Build ``BertForPreTrainingLossMask`` either with an empty state dict
         (``--model_recover_path`` is None) or from the checkpoint matched by
         the ``--model_recover_path`` glob pattern.
      6. Train with ``BertAdam``; after every epoch save ``config.json`` and
         ``model.<epoch>.bin`` into ``--output_dir`` (rank 0 / non-distributed
         only).

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1.
        FileNotFoundError: if ``--model_recover_path`` matches no file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_dataset",
                        default='openi',
                        type=str,
                        help="mimic-cxr | openi")
    parser.add_argument("--vqa_rad",
                        default="all",
                        type=str,
                        choices=["all", "chest", "head", "abd"])
    parser.add_argument("--data_set",
                        default="train",
                        type=str,
                        help="train | valid")
    parser.add_argument('--img_hidden_sz',
                        type=int,
                        default=2048,
                        help="Hidden size of the visual (image) features.")

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )

    parser.add_argument(
        "--mlm_task",
        type=str,
        default=True,
        help="The model will train only mlm task!! | True | False")
    parser.add_argument("--train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--num_train_epochs",
                        default=5,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        default=False,
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument("--img_encoding",
                        type=str,
                        default='fully_use_cnn',
                        choices=['random_sample', 'fully_use_cnn'])
    # If the fixed visual-token length is e.g. 100, the input holds 100
    # <Unknown> placeholder tokens and up to 100 words can be generated.
    parser.add_argument(
        '--len_vis_input',
        type=int,
        default=256,
        help="The length of visual token input"
    )
    parser.add_argument('--max_len_b',
                        type=int,
                        default=253,
                        help="Truncate_config: maximum length of segment B.")

    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )

    parser.add_argument('--max_pred',
                        type=int,
                        default=10,
                        help="Max tokens of prediction.")
    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help=
        "Percentage of examples that are bi-uni-directional LM (seq2seq). This must be turned off!!!!!!! because this is not for seq2seq model!!!"
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--bar', default=False, type=str, help="True or False")

    parser.add_argument("--config_path",
                        default='./pretrained_model/non_cross/config.json',
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--model_recover_path",
        default='./pretrained_model/non_cross/pytorch_model.bin',
        type=str,
        help="The file of fine-tuned pretraining model.")  # model load
    parser.add_argument(
        "--output_dir",
        default='./output_model/base_noncross_mimic_2',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The output directory where the log will be written.")

    parser.add_argument('--img_postion',
                        default=True,
                        help="It will produce img_position.")
    # NOTE: store_true with default=True means this flag is always True.
    parser.add_argument(
        "--do_train",
        action='store_true',
        default=True,
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    ############################################################################################################

    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing factor passed to the model.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        default=False,
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        default=False,
        help=
        "Whether to use 32-bit float precision instead of 32-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        default=False,
                        help="Whether to use amp for fp16")

    parser.add_argument('--new_segment_ids',
                        default=False,
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")

    parser.add_argument("--num_workers",
                        default=20,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    parser.add_argument(
        '--image_root',
        type=str,
        default='/home/mimic-cxr/dataset/image_preprocessing/re_512_3ch/Train')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'valid'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)
    parser.add_argument('--tasks',
                        default='report_generation',
                        help='report_generation | vqa')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))
    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    if args.tasks == 'vqa':
        wandb.init(config=args, project="VQA")
        wandb.config["more"] = "custom"
        args.src_file = '/home/mimic-cxr/dataset/data_RAD'
        args.file_valid_jpgs = '/home/mimic-cxr/dataset/vqa_rad_original_set.json'

    else:
        if args.generation_dataset == 'mimic-cxr':
            wandb.init(config=args, project="report_generation")
            wandb.config["more"] = "custom"
            args.src_file = '/home/mimic-cxr/new_dset/Train_253.jsonl'
            args.file_valid_jpgs = '/home/mimic-cxr/new_dset/Train_253.jsonl'

        else:
            # Any value other than 'mimic-cxr' falls back to Open-I.
            wandb.init(config=args, project="report_generation")
            wandb.config["more"] = "custom"
            args.src_file = '/home/mimic-cxr/dataset/open_i/Train_openi.jsonl'
            args.file_valid_jpgs = '/home/mimic-cxr/dataset/open_i/Valid_openi.jsonl'

    print(" # PID :", os.getpid())
    os.makedirs(args.output_dir, exist_ok=True)
    # Persist the resolved options; `with` guarantees the file is flushed/closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        print("device", device)
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("device", device)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.global_rank)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The per-step batch shrinks so that the accumulated batch matches the
    # requested --train_batch_size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    if args.do_train:
        print("args.mask_prob", args.mask_prob)
        print("args.train_batch_size", args.train_batch_size)
        # Two preprocessors that differ only in `mode` ("s2s" vs "bi"); the
        # dataset picks between them using s2s_prob / bi_prob.
        bi_uni_pipeline = [
            data_loader.Preprocess4Seq2seq(
                args,
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                args.bar,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mode="s2s",
                len_vis_input=args.len_vis_input,
                local_rank=args.local_rank,
                load_vqa_set=(args.tasks == 'vqa'))
        ]

        bi_uni_pipeline.append(
            data_loader.Preprocess4Seq2seq(
                args,
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                args.bar,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mode="bi",
                len_vis_input=args.len_vis_input,
                local_rank=args.local_rank,
                load_vqa_set=(args.tasks == 'vqa')))

        train_dataset = data_loader.Img2txtDataset(
            args,
            args.data_set,
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,  # this must be set to 1.
            bi_prob=args.bi_prob,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # Total number of optimizer updates; drives the BertAdam LR schedule.
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'report_generation' else 0

    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    # BERT model will be loaded! from scratch
    if (args.model_recover_path is None):
        # An empty state dict is always passed here, so --from_scratch has no
        # additional effect on this branch.
        _state_dict = {}
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            args=args,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)

        print("scratch model's statedict : ")
        # Build the state dict once instead of once per parameter.
        scratch_state = model.state_dict()
        for param_name, param_value in scratch_state.items():
            print(param_name, "\t", param_value.size())
        global_step = 0
        print("The model will train from scratch")

    else:
        print("Task :", args.tasks, args.s2s_prob)
        print("Recoverd model :", args.model_recover_path)
        recover_paths = glob.glob(args.model_recover_path.strip())
        if not recover_paths:
            # Without this check, execution reached an undefined
            # `model_recover` and failed with an unhelpful NameError.
            raise FileNotFoundError(
                "No checkpoint matches --model_recover_path: {}".format(
                    args.model_recover_path))
        if len(recover_paths) > 1:
            logger.warning(
                "--model_recover_path matched %d files; using the last: %s",
                len(recover_paths), recover_paths[-1])
        # Only the last match was ever used, since every earlier load was
        # overwritten. Load just that one.
        model_recover_path = recover_paths[-1]
        logger.info("***** Recover model: %s *****", model_recover_path)
        # Load onto CPU: the model is moved to `device` later, and this keeps
        # CUDA-saved checkpoints loadable on CPU-only hosts.
        model_recover = torch.load(model_recover_path, map_location='cpu')

        # Rename checkpoint keys to this model's naming: drop 'enc.' and map
        # 'mlm.' -> 'cls.'.
        for key in list(model_recover.keys()):
            model_recover[key.replace('enc.', '').replace(
                'mlm.', 'cls.')] = model_recover.pop(key)
        global_step = 0

        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            args=args,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)

        model.load_state_dict(model_recover, strict=False)

        print("The pretrained model loaded and fine-tuning.")
        del model_recover
        torch.cuda.empty_cache()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    wandb.watch(model)
    # Standard BERT recipe: no weight decay on biases and LayerNorm params.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         schedule=args.sche_mode,
                         t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        model.train()
        print("Total Parameters:",
              sum([p.nelement() for p in model.parameters()]))

        if recover_step:
            start_epoch = recover_step + 1
            print("start_epoch", start_epoch)
        else:
            start_epoch = 1

        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            if args.local_rank >= 0:
                # Reshuffle differently each epoch under DistributedSampler.
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            train_loss = []

            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, task_idx, img, vis_pe, ans_labels, ans_type, organ = batch
                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                loss_tuple = model(img,
                                   vis_pe,
                                   input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   ans_labels,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   drop_worst_ratio=args.max_drop_worst_ratio
                                   if i_epoch > args.drop_after else 0,
                                   ans_type=ans_type)

                masked_lm_loss, vqa_loss = loss_tuple

                # Only the loss of the selected task is optimised.
                if args.tasks == 'report_generation':
                    masked_lm_loss = masked_lm_loss.mean()
                    loss = masked_lm_loss
                else:
                    vqa_loss = vqa_loss.mean()
                    loss = vqa_loss

                iter_bar.set_description('Iter (loss=%5.3f)' % (loss.item()))
                train_loss.append(loss.item())

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # The manual LR is applied only on the fp16 path; otherwise
                    # BertAdam applies its own schedule internally.
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                    args.warmup_proportion)
                    if args.fp16:
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            wandb.log({"train_loss": np.mean(train_loss)})
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_config_file = os.path.join(args.output_dir, 'config.json')
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            # NOTE: optimizer state is not written here, so an
            # optim.<epoch>.bin needed for recovery must come from elsewhere.
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                # Written by a single rank so that processes do not race on
                # the same config file.
                with open(output_config_file, 'w') as f:
                    f.write(model_to_save.config.to_json_string())
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
Beispiel #6
0
def train_eval(args, train_data_path, valid_data_path):
    """Train an NER model, periodically validating and keeping the best one.

    Args:
        args: run configuration namespace. Attributes read here include
            ``index_path``, ``seed_num``, ``train_batch_size``,
            ``valid_batch_size``, ``model`` ('bert' selects ``NERBert``,
            anything else ``NERModel``), ``bert_config_path``,
            ``bert_model_path``, ``embedding``, ``embedding_data_path``,
            ``gradient_accumulation_steps``, ``epochs``, ``learning_rate``,
            ``init_checkpoint``, ``do_train``, ``do_valid`` and
            ``valid_step``. ``num_labels`` and ``vocab_size`` are written
            back onto it.
        train_data_path: path handed to ``get_dataloader`` for training data.
        valid_data_path: path handed to ``get_dataloader`` for validation data.

    Side effects:
        Whenever validation F-score reaches a new best (``>=``), the model is
        saved: for BERT into ``bert_<global_step>/`` (weights + config),
        otherwise via ``save_model``.
    """
    index = read_pickle(args.index_path)
    word2index, tag2index = index['word2id'], index['tag2id']
    args.num_labels = len(tag2index)
    # +1 presumably reserves id 0 for padding (lengths below count ids > 0).
    # TODO confirm.
    args.vocab_size = len(word2index) + 1
    set_seed(args.seed_num)
    train_dataloader, train_samples = get_dataloader(train_data_path, args.train_batch_size, True)
    valid_dataloader, _ = get_dataloader(valid_data_path, args.valid_batch_size, False)

    if args.model == 'bert':
        bert_config = BertConfig(args.bert_config_path)
        model = NERBert(bert_config, args)
        # Load onto CPU (the model is moved to `device` below), so
        # CUDA-saved weights also load on CPU-only hosts. strict=False
        # tolerates keys that differ between checkpoint and model.
        model.load_state_dict(torch.load(args.bert_model_path, map_location='cpu'), strict=False)
    else:
        if args.embedding:
            word_embedding_matrix = read_pickle(args.embedding_data_path)
            model = NERModel(args, word_embedding_matrix)
        else:
            model = NERModel(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.model == 'bert':
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # Non-BERT (head) params use a fixed lr of 5e-5; BERT params use
        # args.learning_rate, without decay on biases / LayerNorm.
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if 'bert' not in n], 'lr': 5e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.0}
        ]
        warmup_proportion = 0.1
        num_train_optimization_steps = int(
            train_samples / args.train_batch_size / args.gradient_accumulation_steps) * args.epochs

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=warmup_proportion,
                             t_total=num_train_optimization_steps)
    else:
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=current_learning_rate
        )

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'),
                                map_location=device)
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    global_step = init_step
    best_score = 0.0

    logging.info('Start Training...')
    logging.info('init_step = %d' % global_step)
    for epoch_id in range(int(args.epochs)):

        model.train()
        for step, train_batch in enumerate(train_dataloader):

            batch = tuple(t.to(device) for t in train_batch)
            _, loss = model(batch[0], batch[1])
            if n_gpu > 1:
                loss = loss.mean()  # DataParallel gathers one loss per GPU
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            updated = (step + 1) % args.gradient_accumulation_steps == 0
            if updated:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            if (step + 1) % 500 == 0:
                print(loss.item())

            # Validate only right after an optimizer update. With gradient
            # accumulation, global_step is unchanged across micro-batches and
            # the same step would otherwise be evaluated (and saved) several
            # times. Comparing against `1 % valid_step` keeps the original
            # schedule (steps 1, 1+k, ...) for valid_step > 1. For
            # valid_step == 1 it validates every update instead of never.
            if (updated and args.do_valid
                    and global_step % args.valid_step == 1 % args.valid_step):
                true_res = []
                pred_res = []
                len_res = []
                model.eval()
                for valid_step, valid_batch in enumerate(valid_dataloader):
                    valid_batch = tuple(t.to(device) for t in valid_batch)

                    with torch.no_grad():
                        logit = model(valid_batch[0])
                    if args.model == 'bert':
                        # The first token is [CLS]; drop it from lengths,
                        # labels and predictions.
                        len_res.extend(torch.sum(valid_batch[0].gt(0), dim=-1).detach().cpu().numpy()-1)
                        true_res.extend(valid_batch[1].detach().cpu().numpy()[:,1:])
                        pred_res.extend(logit.detach().cpu().numpy()[:,1:])
                    else:
                        len_res.extend(torch.sum(valid_batch[0].gt(0),dim=-1).detach().cpu().numpy())
                        true_res.extend(valid_batch[1].detach().cpu().numpy())
                        pred_res.extend(logit.detach().cpu().numpy())
                # Accuracy comes from cal_score; model selection uses f1_score.
                acc, _ = cal_score(true_res, pred_res, len_res, tag2index)
                score = f1_score(true_res, pred_res, len_res, tag2index)
                logging.info('Evaluation:epoch:{},step:{},acc:{},fscore:{}'.format(
                    str(epoch_id), str(global_step), acc, score))
                if score >= best_score:
                    best_score = score
                    if args.model == 'bert':
                        model_to_save = model.module if hasattr(model,
                                                                'module') else model  # Only save the model it-self
                        output_dir = '{}_{}'.format('bert', str(global_step))
                        # exist_ok: previously an existing directory caused
                        # the save to be skipped silently.
                        os.makedirs(output_dir, exist_ok=True)
                        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_config_file = os.path.join(output_dir, CONFIG_NAME)
                        with open(output_config_file, 'w') as f:
                            f.write(model_to_save.config.to_json_string())
                    else:
                        save_variable_list = {
                            'step': global_step,
                            'current_learning_rate': args.learning_rate,
                            'warm_up_steps': step
                        }
                        save_model(model, optimizer, save_variable_list, args)
                model.train()
Beispiel #7
0
def main():
    """Run distributed UniLM (seq2seq BERT) pre-training or fine-tuning.

    Steps:
      1. Parse command-line arguments.
      2. Build the tokenizer and the ``Preprocess4Seq2seq`` pipeline.
      3. Create or recover a ``BertForPreTrainingLossMask`` model and its
         ``BertAdam`` optimizer.
      4. Wrap the model in ``DistributedDataParallel`` and run the training loop.

    Every ``--checkpoint_steps`` loader iterations, the model and optimizer
    state (plus apex amp state when ``--fp16 --amp``) are saved as
    ``model.<k>.bin`` / ``optim.<k>.bin`` / ``amp.<k>.bin`` in
    ``--model_output_dir``, where ``k`` is the 1-based checkpoint number.

    ``args.rank`` and ``args.world_size`` are always taken from the ``RANK`` /
    ``WORLD_SIZE`` environment variables (default -1). This overrides any
    ``--rank`` / ``--world_size`` passed on the command line.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, or if none of
            ``--fine_tune`` / ``--do_train`` / ``--do_eval`` is given.
    """
    parser = argparse.ArgumentParser()
    # Path parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The raw data dir.")
    parser.add_argument("--vocab_path",
                        default=None,
                        type=str,
                        required=True,
                        help="bert vocab path")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    # Read by from_pretrained() when no checkpoint is recovered; it was
    # previously referenced without ever being declared.
    parser.add_argument("--bert_model",
                        default=None,
                        type=str,
                        help="Pre-trained BERT model name or path passed to "
                        "from_pretrained when no checkpoint is recovered.")
    parser.add_argument(
        "--model_output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The param init of pretrain or finetune")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    # Data Process Parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")
    # Model Paramters
    parser.add_argument("--sop",
                        action='store_true',
                        help="whether use sop task.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")

    # Train Eval Test Paramters

    parser.add_argument("--checkpoint_steps",
                        required=True,
                        type=int,
                        help="save model eyery checkpoint_steps")

    parser.add_argument("--total_steps",
                        required=True,
                        type=int,
                        help="all steps of training model")

    parser.add_argument("--max_checkpoint",
                        required=True,
                        type=int,
                        help="max saved model in model_output_dir")

    parser.add_argument(
        "--examples_size_once",
        type=int,
        default=1000,
        help="read how many examples every time in pretrain or finetune")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="process rank in local")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="whether debug")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--fine_tune",
                        action='store_true',
                        help="Whether to run fine_tune.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing factor passed to the model.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates   accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=str,
        default='dynamic',
        help=
        '(float or str, optional, default=None):  Optional property override.  '
        'If passed as a string,must be a string representing a number, e.g., "128.0", or the string "dynamic".'
    )
    parser.add_argument(
        '--opt_level',
        type=str,
        default='O1',
        help=
        ' (str, optional, default="O1"):  Pure or mixed precision optimization level.  '
        'Accepted values are "O0", "O1", "O2", and "O3", explained in detail above.'
    )
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )

    # Other Patameters
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help="global rank of current process")
    parser.add_argument("--world_size",
                        default=2,
                        type=int,
                        help="Number of process(显卡)")

    args = parser.parse_args()
    cur_env = os.environ
    # The launcher's environment always overrides --rank / --world_size.
    args.rank = int(cur_env.get('RANK', -1))
    args.world_size = int(cur_env.get('WORLD_SIZE', -1))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    # --train_batch_size is the effective batch; split it across accumulation steps.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    assert args.train_batch_size >= 1, 'batch_size < 1 '

    # Number of examples consumed by a single parameter update (all processes).
    examples_per_update = args.world_size * args.train_batch_size * args.gradient_accumulation_steps
    # Round the per-read example count down to a whole number of updates.
    args.examples_size_once = args.examples_size_once // examples_per_update * examples_per_update
    if args.fine_tune:
        args.examples_size_once = examples_per_update

    os.makedirs(args.model_output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Use a context manager so the config file is flushed and closed.
    with open(os.path.join(args.model_output_dir, 'unilm_config.json'),
              'w') as config_file:
        json.dump(args.__dict__, config_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = torch.cuda.device_count()
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
    logger.info(
        "world_size:{}, rank:{}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
        .format(args.world_size, args.rank, device, n_gpu,
                bool(args.world_size > 1), args.fp16))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.fine_tune and not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    if args.local_rank == 0:
        dist.barrier()
    bi_uni_pipeline = [
        Preprocess4Seq2seq(args.max_pred,
                           args.mask_prob,
                           list(tokenizer.vocab.keys()),
                           tokenizer.convert_tokens_to_ids,
                           args.max_seq_length,
                           new_segment_ids=args.new_segment_ids,
                           truncate_config={
                               'max_len_a': args.max_len_a,
                               'max_len_b': args.max_len_b,
                               'trunc_seg': args.trunc_seg,
                               'always_truncate_tail':
                               args.always_truncate_tail
                           },
                           mask_source_words=args.mask_source_words,
                           skipgram_prb=args.skipgram_prb,
                           skipgram_size=args.skipgram_size,
                           mask_whole_word=args.mask_whole_word,
                           mode="s2s",
                           has_oracle=args.has_sentence_oracle,
                           num_qkv=args.num_qkv,
                           s2s_special_token=args.s2s_special_token,
                           s2s_add_segment=args.s2s_add_segment,
                           s2s_share_segment=args.s2s_share_segment,
                           pos_shift=args.pos_shift,
                           fine_tune=args.fine_tune)
    ]

    # Prepare model. The optimizer's t_total (number of parameter updates)
    # is taken from --total_steps.
    recover_step = _get_max_epoch_model(args.model_output_dir)
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)
                       if args.new_segment_ids else 2)
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.model_output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # Checkpoint k is written after k * checkpoint_steps loader iterations.
            global_step = math.floor(recover_step * args.checkpoint_steps)
        # Initialize from pre-trained weights (e.g. chinese-bert-base).
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)

    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    logger.info("模型参数: {}".format(total_trainable_params))
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)
    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=args.total_steps)
    # NOTE(review): the model is always wrapped in DDP, so this script assumes
    # it is started by a torch.distributed launcher (local_rank != -1).
    if args.amp and args.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale=args.loss_scale)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    if recover_step:
        logger.info("** ** * Recover optimizer: %d * ** **", recover_step)
        optim_recover = torch.load(os.path.join(
            args.model_output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.fp16 and args.amp:
            amp_recover = torch.load(os.path.join(
                args.model_output_dir, "amp.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            logger.info("** ** * Recover amp: %d * ** **", recover_step)
            amp.load_state_dict(amp_recover)
    logger.info("** ** * CUDA.empty_cache() * ** **")
    torch.cuda.empty_cache()

    if args.rank == 0:
        writer = SummaryWriter(log_dir=args.log_dir)
    logger.info("***** Running training *****")
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Param Update Num = %d", args.total_steps)
    model.train()

    PRE = "rank{},local_rank {},".format(args.rank, args.local_rank)
    start = time.time()
    train_data_loader = TrainDataLoader(
        bi_uni_pipline=bi_uni_pipeline,
        examples_size_once=args.examples_size_once,
        world_size=args.world_size,
        train_batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        max_len=args.max_seq_length)
    for global_step, batch in enumerate(train_data_loader, start=global_step):
        batch = [t.to(device) if t is not None else None for t in batch]
        if args.has_sentence_oracle:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        else:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label = batch
            oracle_pos, oracle_weights, oracle_labels = None, None, None
        if not args.sop:
            # SOP task disabled: do not pass the sentence-order label.
            sop_label = None
        loss_tuple = model(input_ids,
                           segment_ids,
                           input_mask,
                           masked_lm_labels=lm_label_ids,
                           next_sentence_label=sop_label,
                           masked_pos=masked_pos,
                           masked_weights=masked_weights,
                           task_idx=task_idx,
                           masked_pos_2=oracle_pos,
                           masked_weights_2=oracle_weights,
                           masked_labels_2=oracle_labels,
                           mask_qkv=mask_qkv)
        masked_lm_loss, next_sentence_loss = loss_tuple
        # mean() to average on multi-gpu.
        if n_gpu > 1:
            masked_lm_loss = masked_lm_loss.mean()
            next_sentence_loss = next_sentence_loss.mean()
        # ensure that accumlated gradients are normalized
        if args.gradient_accumulation_steps > 1:
            masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps
            next_sentence_loss = next_sentence_loss / args.gradient_accumulation_steps
        if not args.sop:
            loss = masked_lm_loss
        else:
            loss = masked_lm_loss + next_sentence_loss
        if args.fp16 and args.amp:
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()
        if (global_step + 1) % args.gradient_accumulation_steps == 0:
            if args.rank == 0:
                writer.add_scalar('unilm/mlm_loss', masked_lm_loss,
                                  global_step)
                writer.add_scalar('unilm/sop_loss', next_sentence_loss,
                                  global_step)
            lr_this_step = args.learning_rate * warmup_linear(
                global_step / args.total_steps, args.warmup_proportion)
            if args.fp16:
                # modify learning rate with special warm up BERT uses
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            # Seconds elapsed since the previous parameter update.
            cost_time_per_update = time.time() - start
            # Rough estimate of the remaining training time, in hours.
            need_time = cost_time_per_update * (args.total_steps -
                                                global_step) / 3600.0
            cost_time_per_chectpoint = cost_time_per_update * args.checkpoint_steps / 3600.0
            start = time.time()
            if args.local_rank in [-1, 0]:
                INFO = PRE + '当前/chcklpoint_steps/total:{}/{}/{},loss{}/{},更新一次参数{}秒,checkpoint_steps {}小时,' \
                             '训练完成{}小时\n'.format(global_step, args.checkpoint_steps, args.total_steps,
                                                 round(masked_lm_loss.item(), 5),
                                                 round(next_sentence_loss.item(), 5), round(cost_time_per_update, 4),
                                                 round(cost_time_per_chectpoint, 3), round(need_time, 3))
                print(INFO)
        # Save a trained model
        if (global_step + 1) % args.checkpoint_steps == 0:
            # 1-based checkpoint number. With the former ``%`` this was always
            # 0, so every save overwrote *.0.bin and set_epoch never changed.
            checkpoint_index = (global_step + 1) // args.checkpoint_steps
            if args.rank >= 0:
                train_data_loader.train_sampler.set_epoch(checkpoint_index)
            # TODO: evaluate here (MLM for pre-training, task metric for
            # fine-tuning) and keep only the best checkpoints.
            if args.rank in [0, -1]:
                logger.info("** ** * Saving  model and optimizer * ** **")

                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.model_output_dir,
                    "model.{0}.bin".format(checkpoint_index))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.model_output_dir,
                    "optim.{0}.bin".format(checkpoint_index))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16 and args.amp:
                    logger.info("** ** * Saving  amp state  * ** **")
                    output_amp_file = os.path.join(
                        args.model_output_dir,
                        "amp.{0}.bin".format(checkpoint_index))
                    torch.save(amp.state_dict(), output_amp_file)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
    if args.rank == 0:
        writer.close()
        print('** ** * train finished * ** **')
Beispiel #8
0
def train(config, model, train_iter, dev_iter):
    """Train ``model`` with BertAdam, evaluating on ``dev_iter`` every 100 batches.

    Resuming: if ``config.save_path`` exists, model and optimizer state are
    restored from it (keys ``'model_state_dict'`` / ``'optimizer_state_dict'``).

    Checkpointing: every time the dev loss reaches a new minimum, model and
    optimizer state are saved to ``config.save_dic + str(total_batch) + '.pth'``.

    Learning rate and early stopping: an evaluation "does not improve" when its
    dev loss is not lower than the previous evaluation's. On the 1st, 3rd, 5th,
    ... consecutive non-improving evaluation, the learning rate is halved
    (``ExponentialLR`` with gamma 0.5). Training stops once the number of
    consecutive non-improving evaluations exceeds
    ``config.require_improvement``.

    Args:
        config: object providing ``save_path``, ``save_dic``,
            ``learning_rate``, ``num_epochs`` and ``require_improvement``.
        model: the ``torch.nn.Module`` to train.
        train_iter: iterable of ``(inputs, labels)`` batches; must support
            ``len()``.
        dev_iter: validation batches, passed to ``evaluate``.
    """
    start_time = time.time()

    # Deserialize the checkpoint once; it holds both model and optimizer state.
    checkpoint = None
    if os.path.exists(config.save_path):
        checkpoint = torch.load(config.save_path)
        model.load_state_dict(checkpoint['model_state_dict'])

    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=0.05,
                         t_total=len(train_iter) * config.num_epochs)
    # Halves the learning rate each time scheduler.step() is called.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        del checkpoint  # release the loaded tensors before training

    total_batch = 0
    dev_best_loss = float('inf')
    dev_last_loss = float('inf')
    no_improve = 0
    flag = False

    model.train()
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                train_loss = loss.item()
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    state = {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }
                    dev_best_loss = dev_loss

                    torch.save(state,
                               config.save_dic + str(total_batch) + '.pth')
                    improve = '*'
                    del state
                else:
                    improve = ''

                # Compare with the previous evaluation (not the best one);
                # decay the LR on every other consecutive non-improvement.
                if dev_last_loss > dev_loss:
                    no_improve = 0
                elif no_improve % 2 == 0:
                    no_improve += 1
                    scheduler.step()
                else:
                    no_improve += 1

                dev_last_loss = dev_loss

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(
                    msg.format(total_batch, train_loss, train_acc, dev_loss,
                               dev_acc, time_dif, improve))
                # evaluate() may have switched the model to eval mode.
                model.train()
            total_batch += 1
            if no_improve > config.require_improvement:
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
Beispiel #9
0
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate)

    if os.path.exists(state_save_path):
        optimizer.load_state_dict(state['opt_state'])

    device = torch.device("cuda")

    tr_total = int(train_dataset.__len__() / args.train_batch_size /
                   args.gradient_accumulation_steps * args.num_train_epochs)
    print_freq = args.print_freq
    eval_freq = len(train_dataloader) // 4
    print('Print freq:', print_freq, "Eval freq:", eval_freq)

    for epoch in range(epoch_start, int(args.num_train_epochs) + 1):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader)) as bar:
            for step, batch in enumerate(train_dataloader, start=1):
                model.train()
Beispiel #10
0
def _evaluate(model, valid_dataloader, device):
    """Score ``model`` on every batch of ``valid_dataloader`` without gradients.

    Each batch is unpacked like the training batches: ``batch[0:3]`` are the
    model inputs and ``batch[3]`` the multi-hot targets. Leaves ``model`` in
    eval mode.

    Returns:
        ``(valid_losses, true_res, pred_res)``: the per-batch
        ``binary_cross_entropy_with_logits`` losses, the gold label rows and
        the sigmoid probability rows (both as numpy arrays, one per example).
    """
    valid_losses, true_res, pred_res = [], [], []
    model.eval()
    with torch.no_grad():
        for valid_batch in valid_dataloader:
            valid_batch = tuple(t.to(device) for t in valid_batch)
            valid_x = (valid_batch[0], valid_batch[1], valid_batch[2])
            valid_y = valid_batch[3]
            valid_logit = model(valid_x)
            valid_loss = F.binary_cross_entropy_with_logits(valid_logit, valid_y)
            valid_losses.append(valid_loss.item())
            true_res.extend(valid_y.detach().cpu().numpy())
            pred_res.extend(torch.sigmoid(valid_logit).detach().cpu().numpy())
    return valid_losses, true_res, pred_res


def train_eval(args):
    """Train a multi-label classifier, validating and checkpointing periodically.

    Reads the pickled training/validation splits, builds an ``AttentiveLSTM``
    over a pre-trained embedding matrix, and optimizes it for ``args.epochs``
    epochs. When ``args.do_valid`` is set, the model is evaluated right after
    every ``args.valid_step``-th optimizer update (updates 1, 1 + valid_step,
    1 + 2 * valid_step, ...). It is saved whenever the loose micro-F1 reported
    by ``acc_hook`` does not decrease.

    Args:
        args: Namespace of training options (data paths, batch sizes,
            ``learning_rate``, ``gradient_accumulation_steps``, ``epochs``,
            ``valid_step``, ``init_checkpoint``, ``do_train``, ``do_valid``,
            ``embedding``, ``model``...). ``args.num_labels`` and
            ``args.vocab_size`` are set as a side effect.

    Returns:
        None. Also returns early (before training) when ``args.model`` is
        ``'bert'`` (that path is disabled) or when ``args.embedding`` is false.
    """
    set_seed(args.seed_num)
    train_x_left, train_x_entity, train_x_right, train_y = read_pickle(args.train_data_path)
    valid_x_left, valid_x_entity, valid_x_right, valid_y = read_pickle(args.valid_data_path)

    args.num_labels = train_y.shape[1]
    train_dataloader = get_dataloader(train_y, args.train_batch_size, True, train_x_left, train_x_entity, train_x_right)
    valid_dataloader = get_dataloader(valid_y, args.valid_batch_size, False, valid_x_left, valid_x_entity, valid_x_right)

    if args.model == 'bert':
        # The BERT path is disabled: it returns before a model is built. The
        # former construction code is kept below for reference.
        return None
        # bert_config = BertConfig(args.bert_config_path)
        # model = NERBert(bert_config, args)
        # model.load_state_dict(torch.load(args.bert_model_path), strict=False)
        # model = NERBert.from_pretrained('bert_chinese',
        #                                 # cache_dir='/home/dutir/yuetianchi/.pytorch_pretrained_bert',
        #                                 num_labels=args.num_labels)
    if not args.embedding:
        logging.error("args.embedding should be true")
        return None
    word_embedding_matrix = read_pickle(args.embedding_data_path)
    args.vocab_size = len(word_embedding_matrix)
    model = AttentiveLSTM(args, word_embedding_matrix)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.model == 'bert':
        # Unreachable while the BERT path above is disabled; kept so that
        # re-enabling it only requires restoring the model construction.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if 'bert' not in n], 'lr': 5e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and ('bert' in n)],
             'weight_decay': 0.0}
        ]
        warmup_proportion = 0.1
        # The former code used an undefined name `train_samples`; the number of
        # training examples is the number of label rows.
        num_train_optimization_steps = int(
            len(train_y) / args.train_batch_size / args.gradient_accumulation_steps) * args.epochs

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=warmup_proportion,
                             t_total=num_train_optimization_steps)
    else:
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.learning_rate
        )

    if args.init_checkpoint:
        # Restore model (and, when training, optimizer) from the checkpoint directory.
        logging.info('Loading checkpoint %s...', args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # The checkpoint's 'current_learning_rate' / 'warm_up_steps' entries
            # were never applied; the optimizer state dict carries the lr.
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...', args.model)
        init_step = 0

    global_step = init_step
    best_score = 0.0

    logging.info('Start Training...')
    logging.info('init_step = %d', global_step)
    for _epoch in range(int(args.epochs)):

        train_loss = 0

        for step, train_batch in enumerate(train_dataloader):
            model.train()
            batch = tuple(t.to(device) for t in train_batch)
            batch_x = (batch[0], batch[1], batch[2])
            batch_y = batch[3]
            loss = model(batch_x, batch_y)
            if n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            train_loss += loss.item()
            loss.backward()

            updated = (step + 1) % args.gradient_accumulation_steps == 0
            if updated:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            # Validate only right after an optimizer update: global_step does
            # not change between accumulation micro-batches, so without the
            # `updated` guard the same step was evaluated (and saved) repeatedly.
            # `(global_step - 1) % valid_step == 0` keeps the 1, 1+k, 1+2k, ...
            # cadence of `global_step % valid_step == 1` but also works for
            # valid_step == 1, where the old test was never true.
            if args.do_valid and updated and (global_step - 1) % args.valid_step == 0:
                valid_losses, true_res, pred_res = _evaluate(model, valid_dataloader, device)

                metric_res = acc_hook(pred_res, true_res)
                logging.info('Evaluation:step:{},train_loss:{},valid_loss:{},microf1:{},macrof1:{}'.
                             format(str(global_step), train_loss / args.valid_step, np.average(valid_losses),
                                    metric_res['loose_micro_f1'], metric_res['loose_macro_f1']))
                if metric_res['loose_micro_f1'] >= best_score:
                    best_score = metric_res['loose_micro_f1']
                    if args.model == 'bert':
                        # Only save the model itself, not the DataParallel wrapper.
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_dir = '{}_{}'.format('bert', str(global_step))
                        # Save even if the directory already exists (the
                        # former code skipped the save in that case).
                        os.makedirs(output_dir, exist_ok=True)
                        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_config_file = os.path.join(output_dir, CONFIG_NAME)
                        with open(output_config_file, 'w') as f:
                            f.write(model_to_save.config.to_json_string())
                    else:
                        save_variable_list = {
                            'step': global_step,
                            'current_learning_rate': args.learning_rate,
                            'warm_up_steps': step
                        }
                        save_model(model, optimizer, save_variable_list, args)
                train_loss = 0.0
Beispiel #11
0
def _checkpoint_history(checkpoint, *keys):
    """Return ``checkpoint[key]`` for the first of ``keys`` present.

    Used when resuming training. ``main`` saves the validation-loss histories
    under ``"match_valid_losses"`` / ``"mismatch_valid_losses"``, while its
    resume code used to look for ``"matched_valid_losses"`` /
    ``"mismatched_valid_losses"``. Accepting both spellings lets either kind of
    file be loaded.

    Raises:
        KeyError: if none of ``keys`` is present in ``checkpoint``.
    """
    for key in keys:
        if key in checkpoint:
            return checkpoint[key]
    raise KeyError(keys[0])


def main(train_file,
         matched_valid_file,
         mismatched_valid_file,
         test_file,
         bert_path,
         target_dir,
         hidden_size=768,
         dropout=0.5,
         num_classes=2,
         epochs=10,
         batch_size=128,
         learning_rate=5e-5,
         patience=5,
         max_grad_norm=10.0,
         checkpoint=None):
    """
    Train the BERT model on the text_similarity dataset.

    Each epoch trains once, then evaluates on the matched, mismatched and test
    sets. The learning rate and early stopping are driven by the matched
    validation accuracy.

    Args:
        train_file: Path to a pickle of preprocessed training data.
        matched_valid_file: Path to a pickle of the preprocessed "matched"
            validation data.
        mismatched_valid_file: Path to a pickle of the preprocessed
            "mismatched" validation data.
        test_file: Path to a pickle of the test data. The raw unpickled object
            is also passed to ``test``.
        bert_path: Path handed to the ``BERT`` model constructor (presumably
            a pretrained BERT directory - confirm against ``BERT``).
        target_dir: Directory where checkpoints are written; created if it
            does not exist.
        hidden_size: Hidden size passed to ``BERT``. Defaults to 768.
        dropout: Not used by this function. Defaults to 0.5.
        num_classes: Number of output classes. Defaults to 2.
        epochs: Maximum number of training epochs. Defaults to 10.
        batch_size: Batch size for every data loader. Defaults to 128.
        learning_rate: Learning rate for ``BertAdam``. Defaults to 5e-5.
        patience: Number of consecutive epochs with matched validation
            accuracy below the best so far before training stops early.
            Defaults to 5.
        max_grad_norm: Passed to ``train`` (presumably the gradient-clipping
            norm). Defaults to 10.0.
        checkpoint: Path to a checkpoint written by this function
            (``esim_*.pth.tar``, or ``best.pth.tar`` which lacks optimizer
            state) to resume from. If None, training starts from scratch.
            Defaults to None.
    """
    # NOTE(review): the GPU index is hard-coded; this fails on machines with
    # fewer than 7 GPUs.
    device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")

    print(20 * "=", " Preparing for training ", 20 * "=")

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # -------------------- Data loading ------------------- #
    print("\t* Loading training data...")
    with open(train_file, "rb") as pkl:
        train_data = TEXTSIMILARITYDataset(pickle.load(pkl))

    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    with open(matched_valid_file, "rb") as pkl:
        matched_valid_data = TEXTSIMILARITYDataset(pickle.load(pkl))

    with open(mismatched_valid_file, "rb") as pkl:
        mismatched_valid_data = TEXTSIMILARITYDataset(pickle.load(pkl))

    with open(test_file, "rb") as pkl:
        test_raw_data = pickle.load(pkl)
        test_data = TEXTSIMILARITYDataset(test_raw_data)

    matched_valid_loader = DataLoader(matched_valid_data,
                                      shuffle=False,
                                      batch_size=batch_size)
    mismatched_valid_loader = DataLoader(mismatched_valid_data,
                                         shuffle=False,
                                         batch_size=batch_size)
    test_loader = DataLoader(test_data,
                             shuffle=False,
                             batch_size=batch_size)

    # -------------------- Model definition ------------------- #
    model = BERT(bert_path,
                 hidden_size,
                 num_classes=num_classes,
                 device=device).to(device)

    # -------------------- Preparation for training  ------------------- #
    criterion = nn.CrossEntropyLoss()
    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=0.05,
                         t_total=len(train_loader) * epochs)
    # Halve the learning rate as soon as matched validation accuracy stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.5,
                                                           patience=0)

    best_score = 0.0
    start_epoch = 1

    # Data for loss curves plot.
    epochs_count = []
    train_losses = []
    matched_valid_losses = []
    mismatched_valid_losses = []
    test_losses = []

    # Continuing training from a checkpoint if one was given as argument.
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]

        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))

        model.load_state_dict(checkpoint["model"])
        # "best.pth.tar" is saved without optimizer state, so only restore it
        # when it is present.
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        else:
            print("\t* Checkpoint has no optimizer state; using a fresh optimizer.")
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        # The torch.save calls below store these histories under "match_*" /
        # "mismatch_*" keys; also accept the "matched_*" / "mismatched_*" spellings.
        matched_valid_losses = _checkpoint_history(
            checkpoint, "match_valid_losses", "matched_valid_losses")
        mismatched_valid_losses = _checkpoint_history(
            checkpoint, "mismatch_valid_losses", "mismatched_valid_losses")
        test_losses = checkpoint["test_losses"]

    # Compute loss and accuracy before starting (or resuming) training.
    _, matched_valid_loss, matched_valid_accuracy, *_ = validate(model,
                                                                 matched_valid_loader,
                                                                 criterion)
    print("\t* Validation loss before training on matched valid data: {:.4f}, accuracy: {:.4f}%"
          .format(matched_valid_loss, (matched_valid_accuracy*100)))

    _, mismatched_valid_loss, mismatched_valid_accuracy, *_ = validate(model,
                                                                       mismatched_valid_loader,
                                                                       criterion)
    print("\t* Validation loss before training on mismatched valid data: {:.4f}, accuracy: {:.4f}%"
          .format(mismatched_valid_loss, (mismatched_valid_accuracy*100)))

    # -------------------- Training epochs ------------------- #
    print("\n",
          20 * "=",
          "Training ESIM model on device: {}".format(device),
          20 * "=")

    patience_counter = 0
    for epoch in range(start_epoch, epochs+1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model,
                                                       train_loader,
                                                       optimizer,
                                                       criterion,
                                                       epoch,
                                                       max_grad_norm)

        train_losses.append(epoch_loss)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        print("* Validation for epoch {} on matched data:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, *_ = validate(model,
                                                              matched_valid_loader,
                                                              criterion)
        matched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        print("* Validation for epoch {} on mismatched data:".format(epoch))
        epoch_time, epoch_loss, mis_epoch_accuracy, *_ = validate(model,
                                                                  mismatched_valid_loader,
                                                                  criterion)
        mismatched_valid_losses.append(epoch_loss)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (mis_epoch_accuracy*100)))

        print("* Validation for epoch {} on test data:".format(epoch))
        print("test data size: ", len(test_data))
        epoch_time, epoch_loss, test_epoch_accuracy = test(model,
                                                           test_loader,
                                                           test_raw_data,
                                                           criterion)
        test_losses.append(epoch_loss)
        print("-> Test. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (test_epoch_accuracy*100)))

        # Update the optimizer's learning rate with the scheduler
        # (driven by the matched validation accuracy).
        scheduler.step(epoch_accuracy)

        # Early stopping on matched validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            # Save the best model. The optimizer is not saved to avoid having
            # a checkpoint file that is too heavy to be shared. To resume
            # training from the best model, use the 'esim_*.pth.tar'
            # checkpoints instead.
            torch.save({"epoch": epoch,
                        "model": model.state_dict(),
                        "best_score": best_score,
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "match_valid_losses": matched_valid_losses,
                        "mismatch_valid_losses": mismatched_valid_losses,
                        "test_losses": test_losses},
                       os.path.join(target_dir, "best.pth.tar"))

        # Save the model at each epoch.
        torch.save({"epoch": epoch,
                    "model": model.state_dict(),
                    "best_score": best_score,
                    "optimizer": optimizer.state_dict(),
                    "epochs_count": epochs_count,
                    "train_losses": train_losses,
                    "match_valid_losses": matched_valid_losses,
                    "mismatch_valid_losses": mismatched_valid_losses,
                    "test_losses": test_losses},
                   os.path.join(target_dir, "esim_{}.pth.tar".format(epoch)))

        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break

    # Final evaluation uses the weights from the last epoch run, not best.pth.tar.
    print("* Validation on test1 data:")
    epoch_time, epoch_loss, test_epoch_accuracy = test(model,
                                                       test_loader,
                                                       test_raw_data,
                                                       criterion)
    print("-> Test. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
          .format(epoch_time, epoch_loss, (test_epoch_accuracy*100)))

    # Plotting of the loss curves for the train and validation sets.
    plt.figure()
    plt.plot(epochs_count, train_losses, "-r")
    plt.plot(epochs_count, matched_valid_losses, "-b")
    plt.plot(epochs_count, mismatched_valid_losses, "-g")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend(["Training loss",
                "Validation loss (matched set)",
                "Validation loss (mismatched set)"])
    plt.title("Cross entropy loss")
    plt.savefig("loss.jpg")