Example #1
def validate(model, valid_src, valid_tgt, toker, vocab, device, local_rank):
    model.eval()
    val_loss = 0
    n_correct = 0
    n_word = 0
    with open(valid_src, 'r') as src_reader, \
         open(valid_tgt, 'r') as tgt_reader:
        for i, (src, tgt) in enumerate(zip(src_reader, tgt_reader)):
            if local_rank != -1:
                # shard the validation data across processes: each rank
                # evaluates every world_size-th example and the partial
                # sums are combined via all_gather_list below
                global_rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                if i % world_size != global_rank:
                    continue
            input_ids, type_ids, mask, labels = convert_raw_input_to_features(
                src, tgt, toker, vocab, device)
            prediction_scores = model(input_ids, type_ids, mask)
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                 reduction='sum')
            loss = loss_fct(prediction_scores.squeeze(0), labels.view(-1))
            val_loss += loss.item()
            n_correct += accuracy_count(prediction_scores, labels)
            n_word += (labels != -1).long().sum().item()
    if local_rank != -1:
        val_loss = sum(all_gather_list(val_loss))
        n_correct = sum(all_gather_list(n_correct))
        n_word = sum(all_gather_list(n_word))
    val_loss /= n_word
    acc = n_correct / n_word
    val_log = {'val_loss': val_loss, 'val_acc': acc}
    model.train()
    return val_log
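Every snippet on this page relies on an `all_gather_list` helper (imported from `onmt.utils.distributed` in Examples #2 and #4, and an equivalent utility in the others) that collects one picklable Python value from each process. The real helper pickles the value into a size-bounded byte buffer; the following is only a minimal sketch of the same semantics, assuming an initialized process group and PyTorch >= 1.8 for `torch.distributed.all_gather_object`.

import torch.distributed as dist

def all_gather_list(data):
    # Return [obj_from_rank_0, ..., obj_from_rank_{world_size-1}],
    # i.e. the value contributed by every process, on every process.
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, data)
    return gathered

With this behaviour, `sum(all_gather_list(val_loss))` in `validate()` above yields the same global loss sum on every rank.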
Example #2
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats
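The loop above folds the `Statistics` lists gathered from the other ranks into our own, element by element, via `update`. The class itself is not shown on this page; the sketch below is a simplified stand-in that assumes only the counters such a merge needs (loss, word and correct-token counts, plus source-word counts behind the `update_n_src_words` flag).

class Statistics:
    # Simplified stand-in for the gathered statistics objects; only the
    # counters touched by the merge loop above are modelled here.
    def __init__(self, loss=0.0, n_words=0, n_correct=0, n_src_words=0):
        self.loss = loss
        self.n_words = n_words
        self.n_correct = n_correct
        self.n_src_words = n_src_words

    def update(self, stat, update_n_src_words=False):
        # accumulate the counters contributed by another process
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct
        if update_n_src_words:
            self.n_src_words += stat.n_src_words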
Example #4
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """

        # apply https://github.com/OpenNMT/OpenNMT-py/commit/2a2621d770adb593942d7999a59401aff35d646a
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats
Example #5
    def _norm(self, batch):
        if self._norm_method == "tokens":
            norm = batch.tgt[1:].ne(self._train_loss.padding_idx).sum()
        else:
            norm = batch.batch_size
        if self.n_gpu > 1:
            norm = sum(all_gather_list(norm))
        return norm
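With token-level normalization each GPU sees a different number of non-padding target tokens, so the per-rank counts are summed across processes before being used. A minimal sketch of how such a global normalizer is typically applied (hypothetical helper, not this trainer's actual code):

def backward_with_norm(loss, norm, fp16_optimizer=None):
    # Rescale by the global token (or sentence) count so that gradient
    # magnitudes do not depend on how many GPUs share the batch.
    loss = loss / float(norm)
    if fp16_optimizer is not None:
        fp16_optimizer.backward(loss)  # apex FP16_Optimizer-style wrapper
    else:
        loss.backward()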
Example #6
def main(opts):
    if opts.local_rank == -1:
        assert torch.cuda.is_available()
        device = torch.device("cuda")
        n_gpu = 1
    else:
        torch.cuda.set_device(opts.local_rank)
        device = torch.device("cuda", opts.local_rank)
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = torch.distributed.get_world_size()
        logger.info("device: {} n_gpu: {}, distributed training: {}, "
                    "16-bits training: {}".format(
                        device, n_gpu, bool(opts.local_rank != -1), opts.fp16))
    opts.n_gpu = n_gpu

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    is_master = opts.local_rank == -1 or torch.distributed.get_rank() == 0

    if is_master:
        save_training_meta(opts)

    random.seed(opts.seed)
    np.random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(opts.seed)

    tokenizer = BertTokenizer.from_pretrained(
        opts.bert_model, do_lower_case='uncased' in opts.bert_model)

    # train_examples = None
    print("Loading Train Dataset", opts.train_file)
    vocab_dump = torch.load(opts.vocab_file)
    vocab = vocab_dump['tgt'].fields[0][1].vocab.stoi
    train_dataset = BertDataset(opts.train_file, tokenizer, vocab,
                                seq_len=opts.max_seq_length,
                                max_len=opts.max_sent_length)

    # Prepare model
    model = BertForSeq2seq.from_pretrained(opts.bert_model)
    embedding = convert_embedding(
        tokenizer, vocab, model.bert.embeddings.word_embeddings.weight)
    model.update_output_layer(embedding)
    if opts.fp16:
        model.half()
    model.to(device)
    if opts.local_rank != -1:
        # need to make sure models are the same in the beginning
        params = [p.data for p in model.parameters()]
        broadcast_tensors(params)
    for name, module in model.named_modules():
        # we might want to tune dropout for smaller dataset
        if isinstance(module, torch.nn.Dropout):
            module.p = opts.dropout

    # Prepare optimizer
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if 'pooler' not in n]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    if opts.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=opts.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if opts.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=opts.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=opts.learning_rate,
                             warmup=opts.warmup_proportion,
                             t_total=opts.num_train_steps)

    global_step = 0
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", opts.train_batch_size)
    logger.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    logger.info("  Num steps = %d", opts.num_train_steps)

    if opts.local_rank == -1:
        train_sampler = TokenBucketSampler(
            train_dataset.lens,
            bucket_size=8192,
            batch_size=opts.train_batch_size,
            droplast=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_sampler=train_sampler,
                                      num_workers=4,
                                      collate_fn=BertDataset.pad_collate)
    else:
        train_sampler = DistributedTokenBucketSampler(
            n_gpu, opts.local_rank,
            train_dataset.lens,
            bucket_size=8192,
            batch_size=opts.train_batch_size,
            droplast=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_sampler=train_sampler,
                                      num_workers=4,
                                      collate_fn=BertDataset.pad_collate)

    if is_master:
        TB_LOGGER.create(join(opts.output_dir, 'log'))
    running_loss = RunningMeter('loss')
    model.train()
    if is_master:
        pbar = tqdm(total=opts.num_train_steps)
    else:
        logger.disabled = True
        pbar = None
    n_examples = 0
    n_epoch = 0
    start = time()
    while True:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) if t is not None else t for t in batch)
            input_ids, input_mask, segment_ids, lm_label_ids = batch
            n_examples += input_ids.size(0)
            mask = lm_label_ids != -1
            loss = model(input_ids, segment_ids, input_mask,
                         lm_label_ids, mask, True)
            if opts.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            running_loss(loss.item())
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1
                if opts.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if opts.fp16 is False, BertAdam is used that handles
                    # this automatically
                    lr_this_step = opts.learning_rate * warmup_linear(
                        global_step/opts.num_train_steps,
                        opts.warmup_proportion)
                    if lr_this_step < 0:
                        # safeguard against possible miscalculation of train steps
                        lr_this_step = 1e-8
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    TB_LOGGER.add_scalar('lr',
                                         lr_this_step, global_step)

                # NOTE running loss not gathered across GPUs for speed
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                if opts.local_rank != -1:
                    # gather gradients from every process
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))
                optimizer.step()
                optimizer.zero_grad()
                if pbar is not None:
                    pbar.update(1)
                if global_step % 5 == 0:
                    torch.cuda.empty_cache()
                if global_step % 100 == 0:
                    if opts.local_rank != -1:
                        total = sum(all_gather_list(n_examples))
                    else:
                        total = n_examples
                    if is_master:
                        ex_per_sec = int(total / (time()-start))
                        logger.info(f'{total} examples trained at '
                                    f'{ex_per_sec} ex/s')
                        # ex_per_sec only exists on the master process,
                        # so log it inside this branch
                        TB_LOGGER.add_scalar('ex_per_s', ex_per_sec,
                                             global_step)

                if global_step % opts.valid_steps == 0:
                    logger.info(f"start validation at Step {global_step}")
                    with torch.no_grad():
                        val_log = validate(model,
                                           opts.valid_src, opts.valid_tgt,
                                           tokenizer, vocab, device,
                                           opts.local_rank)
                    logger.info(f"Val Acc: {val_log['val_acc']}; "
                                f"Val Loss: {val_log['val_loss']}")
                    TB_LOGGER.log_scaler_dict(val_log)
                    if is_master:
                        output_model_file = join(
                            opts.output_dir, 'ckpt',
                            f"model_step_{global_step}.pt")
                        # save cpu checkpoint
                        state_dict = {k: v.cpu() if isinstance(v, torch.Tensor)
                                      else v
                                      for k, v in model.state_dict().items()}
                        torch.save(state_dict, output_model_file)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        if is_master:
            logger.info(f"finished {n_epoch} epochs")
    if opts.num_train_steps % opts.valid_steps != 0:
        with torch.no_grad():
            val_log = validate(model, opts.valid_src, opts.valid_tgt,
                               tokenizer, vocab, device, opts.local_rank)
        TB_LOGGER.log_scaler_dict(val_log)
        if is_master:
            output_model_file = join(opts.output_dir, 'ckpt',
                                     f"model_step_{global_step}.pt")
            # save cpu checkpoint
            state_dict = {k: v.cpu() if isinstance(v, torch.Tensor)
                          else v
                          for k, v in model.state_dict().items()}
            torch.save(state_dict, output_model_file)
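Two helpers used in the training loop above are worth sketching. First, the learning-rate schedule: the `warmup_linear` call matches the classic pytorch-pretrained-bert helper, which decays linearly after warmup and goes negative once its first argument exceeds 1.0, the reason for the `lr_this_step < 0` safeguard above. A rough sketch of that helper:

def warmup_linear(x, warmup=0.002):
    # linear warmup to 1.0, then linear decay; negative once x > 1.0
    if x < warmup:
        return x / warmup
    return 1.0 - x

Second, gradient synchronization: `all_reduce_and_rescale_tensors(grads, float(1))` sums every gradient tensor across processes before the optimizer step. The real helper packs small tensors into a flat buffer to reduce the number of communication calls; stripped of that optimization, its behaviour amounts to:

import torch.distributed as dist

def all_reduce_and_rescale_tensors(tensors, rescale_denom):
    # Sum each tensor in place over all processes, then divide by
    # rescale_denom (a denominator of 1.0 keeps the plain sum).
    for t in tensors:
        dist.all_reduce(t)
        t.div_(rescale_denom)

Called with float(1) as above, the per-GPU gradients are summed rather than averaged across processes.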
Example #7
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop: iterate over the training data
        (i.e. `train_iter_fct`) and periodically run validation
        (i.e. iterate over `valid_iter_fct`).

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int): total number of training steps to run
            valid_steps(int): run validation every `valid_steps` steps

        Returns:
            total_stats(`Statistics`): accumulated training statistics
        """

        step = self.optim.training_step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats, report_stats = Statistics(), Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):

                # Pick batch for this GPU if applicable
                if self._should_skip_batch(i): continue

                _, batch_size = batch
                true_batchs.append(batch)
                normalization += batch_size
                accum += 1
                if accum == self.grad_accum_count:
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(
                            distributed.all_gather_list(normalization))

                    self._gradient_accumulation(true_batchs, normalization,
                                                total_stats, report_stats,
                                                step)

                    report_stats = self._maybe_report_training(
                        step, train_steps, self.optim.learning_rate(),
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0

                    # Save checkpoint if applicable on one GPU
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1

                    # Break loop if hit training step limit
                    if step > train_steps: break

            # Reset dataset iterator if necessary
            train_iter = train_iter_fct()

        return total_stats
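The `self._should_skip_batch(i)` call assigns batches to GPUs round-robin, which is also why `normalization` has to be summed with `all_gather_list`: every rank accumulates token counts for a different subset of batches. Only the method name appears above; the sketch below assumes the usual round-robin logic behind such a helper.

def should_skip_batch(i, n_gpu, gpu_rank):
    # Round-robin sharding: rank `gpu_rank` keeps every n_gpu-th batch
    # and skips the rest; with a single GPU nothing is skipped.
    return n_gpu > 1 and i % n_gpu != gpu_rank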