示例#1
0
def main():
    """Pretrain BERT through ``nn.Graph`` and validate/save once at the end.

    The run is driven entirely by the namespace returned from
    ``get_config()``. The function builds the train/test
    ``OfRecordDataLoader`` pair, a ``BertForPreTraining`` model, an optimizer
    and a polynomial learning-rate schedule wrapped in a linear warm-up, then
    wraps the training step and the evaluation step in two ``nn.Graph``
    subclasses. After ``args.epochs`` training epochs it runs ``validation``
    once and calls ``save_model`` with the resulting accuracy.

    Raises:
        AssertionError: if an explicitly configured global batch size is not
            a multiple of the corresponding ``*_batch_size``, or if the
            prediction decoder weight and the word-embedding weight are not
            the same object.
    """
    args = get_config()

    # When no global batch size is configured, use batch_size * world_size;
    # when one is configured it has to be a whole multiple of batch_size.
    world_size = flow.env.get_world_size()
    if args.train_global_batch_size is None:
        args.train_global_batch_size = args.train_batch_size * world_size
    else:
        assert args.train_global_batch_size % args.train_batch_size == 0, (
            "train_global_batch_size must be a multiple of train_batch_size"
        )

    if args.val_global_batch_size is None:
        args.val_global_batch_size = args.val_batch_size * world_size
    else:
        assert args.val_global_batch_size % args.val_batch_size == 0, (
            "val_global_batch_size must be a multiple of val_batch_size"
        )

    flow.boxing.nccl.set_fusion_threshold_mbytes(args.nccl_fusion_threshold_mb)
    flow.boxing.nccl.set_fusion_max_ops_num(args.nccl_fusion_max_ops)

    device = "cuda" if args.with_cuda else "cpu"

    print("Device is: ", device)

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_global_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    # The test split uses a fixed dataset size (1024) and part count (4).
    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_global_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    print("Building BERT Model")
    # Hidden width is 64 per attention head; the intermediate width is 4x that.
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    # The prediction decoder and the word embeddings must share one weight
    # object (checked by identity, not by value).
    assert id(bert_model.cls.predictions.decoder.weight) == id(
        bert_model.bert.embeddings.word_embeddings.weight
    )

    # Next-sentence loss is averaged by the criterion itself; the masked-LM
    # loss is kept per element so it can be weighted in get_masked_lm_loss.
    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    if args.use_consistent:
        # NOTE(review): the placement is always built for "cuda", regardless
        # of args.with_cuda -- confirm this is intended.
        placement = flow.env.all_device_placement("cuda")
        bert_model = bert_model.to_consistent(
            placement=placement, sbp=flow.sbp.broadcast
        )
    else:
        bert_model.to(device)
        ns_criterion.to(device)
        mlm_criterion.to(device)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    # The schedule spans every training step of every epoch; the first
    # `warmup_proportion` of those steps are linear warm-up.
    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)

    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    def get_masked_lm_loss(
        logit, masked_lm_labels, label_weights, max_predictions_per_seq,
    ):
        """Return the weight-normalised masked-LM loss as a single value.

        The per-element loss from ``mlm_criterion`` is reshaped to
        ``[-1, max_predictions_per_seq]``, multiplied by ``label_weights``,
        summed, and divided by the sum of the weights (plus ``1e-5``).
        """
        label_id = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit, label_id)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        # The epsilon keeps the division defined when every weight is zero.
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    class BertGraph(nn.Graph):
        """Training graph: one data-loader batch, forward, loss, backward."""

        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self.ns_criterion = ns_criterion
            self.masked_lm_criterion = partial(
                get_masked_lm_loss, max_predictions_per_seq=args.max_predictions_per_seq
            )
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)
            self._train_data_loader = train_data_loader
            if args.grad_acc_steps > 1:
                self.config.set_gradient_accumulation_steps(args.grad_acc_steps)
            if args.use_fp16:
                # Mixed precision is paired with a gradient scaler.
                self.config.enable_amp(True)
                grad_scaler = flow.amp.GradScaler(
                    init_scale=2 ** 30,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=2000,
                )
                self.set_grad_scaler(grad_scaler)
            self.config.allow_fuse_add_to_output(True)
            self.config.allow_fuse_model_update_ops(True)

        def build(self):
            """Run one training step and return the scores, labels and losses."""
            (
                input_ids,
                next_sentence_labels,
                input_mask,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._train_data_loader()
            input_ids = input_ids.to(device=device)
            input_mask = input_mask.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sentence_labels = next_sentence_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)
            masked_lm_weights = masked_lm_weights.to(device=device)

            # 1. forward the next_sentence_prediction and masked_lm model
            prediction_scores, seq_relationship_scores = self.bert(
                input_ids, segment_ids, input_mask, masked_lm_positions
            )

            # 2-1. loss of is_next classification result
            next_sentence_loss = self.ns_criterion(
                seq_relationship_scores.reshape(-1, 2), next_sentence_labels.reshape(-1)
            )

            # 2-2. weighted loss of the masked-token predictions
            masked_lm_loss = self.masked_lm_criterion(
                prediction_scores, masked_lm_ids, masked_lm_weights
            )

            total_loss = masked_lm_loss + next_sentence_loss

            total_loss.backward()
            return (
                seq_relationship_scores,
                next_sentence_labels,
                total_loss,
                masked_lm_loss,
                next_sentence_loss,
            )

    bert_graph = BertGraph()

    class BertEvalGraph(nn.Graph):
        """Evaluation graph: forward only, returns next-sentence scores/labels."""

        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self._test_data_loader = test_data_loader
            self.config.allow_fuse_add_to_output(True)

        def build(self):
            """Run one forward pass on a test batch without gradients."""
            (
                input_ids,
                next_sent_labels,
                input_masks,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._test_data_loader()
            input_ids = input_ids.to(device=device)
            input_masks = input_masks.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sent_labels = next_sent_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)

            with flow.no_grad():
                # 1. forward the next_sentence_prediction and masked_lm model
                # Positional order matches the training graph's call on the
                # same model: input ids, then segment ids, then input mask.
                _, seq_relationship_scores = self.bert(
                    input_ids, segment_ids, input_masks
                )

            return seq_relationship_scores, next_sent_labels

    bert_eval_graph = BertEvalGraph()

    train_total_losses = []

    for epoch in range(args.epochs):
        # A fresh Metric per epoch; its batch size accounts for gradient
        # accumulation.
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_global_batch_size * args.grad_acc_steps,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(bert_graph, args.metric_local)

            # Only rank 0 reports metrics.
            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

    # Eval
    # Validation and checkpointing happen once, after the last epoch. The
    # index of that epoch is computed explicitly so it is defined even when
    # args.epochs is 0 and the loop above never runs.
    final_epoch = max(args.epochs - 1, 0)
    bert_model.eval()
    val_acc = validation(
        final_epoch,
        len(test_data_loader),
        bert_eval_graph,
        args.val_print_every_n_iters,
        args.metric_local,
    )

    save_model(
        bert_model, args.checkpoint_path, final_epoch, val_acc, args.use_consistent
    )
def main():
    """Pretrain BERT eagerly, validating and checkpointing after every epoch.

    Settings come from ``get_config()``. The function creates the train/test
    ``OfRecordDataLoader`` pair, a ``BertForPreTraining`` model (optionally
    wrapped with ``ddp``), an optimizer and a polynomial schedule with linear
    warm-up, then for each epoch runs ``pretrain`` once per training step,
    followed by ``validation`` and ``save_model``.
    """
    args = get_config()

    device = flow.device("cuda" if args.with_cuda else "cpu")

    print("Creating Dataloader")
    # Options that are identical for the train and the test loader.
    common_loader_options = dict(
        ofrecord_dir=args.ofrecord_path,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )
    train_data_loader = OfRecordDataLoader(
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_batch_size,
        data_part_num=args.train_data_part,
        **common_loader_options,
    )
    # The test split uses a fixed dataset size (1024) and part count (4).
    test_data_loader = OfRecordDataLoader(
        mode="test",
        dataset_size=1024,
        batch_size=args.val_batch_size,
        data_part_num=4,
        **common_loader_options,
    )

    print("Building BERT Model")
    # 64 hidden units per attention head; intermediate width is 4x hidden.
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    bert_model = bert_model.to(device)
    if args.use_ddp:
        bert_model = ddp(bert_model)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    # The schedule covers every step of every epoch; the leading
    # `warmup_proportion` of them are linear warm-up.
    total_steps = args.epochs * len(train_data_loader)
    warmup_steps = int(total_steps * args.warmup_proportion)

    decay_scheduler = PolynomialLR(
        optimizer, steps=total_steps, end_learning_rate=0.0
    )
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        decay_scheduler,
        warmup_factor=0,
        warmup_iters=warmup_steps,
        warmup_method="linear",
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    def get_masked_lm_loss(
        logit_blob,
        masked_lm_positions,
        masked_lm_labels,
        label_weights,
        max_prediction_per_seq,
    ):
        """Return the weight-normalised masked-LM loss as a single value.

        Logits are gathered along dim 1 at ``masked_lm_positions``, flattened
        to ``[-1, args.vocab_size]`` and scored per element; the losses are
        then weighted by ``label_weights`` and divided by the weight sum.
        """
        # Repeat every position index across the vocabulary axis so gather
        # picks the full logit row of each selected position.
        gather_index = masked_lm_positions.unsqueeze(2).repeat(
            1, 1, args.vocab_size
        )
        selected_logits = flow.gather(logit_blob, index=gather_index, dim=1)

        flat_logits = flow.reshape(selected_logits, [-1, args.vocab_size])
        flat_labels = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        element_losses = mlm_criterion(flat_logits, flat_labels)
        element_losses = flow.reshape(
            element_losses, [-1, max_prediction_per_seq]
        )
        weighted_total = flow.sum(element_losses * label_weights)
        # The epsilon keeps the division defined when every weight is zero.
        weight_total = flow.sum(label_weights) + 1e-5
        return weighted_total / weight_total

    # The bound loss callable is the same for every step, so build it once.
    masked_lm_loss_fn = partial(
        get_masked_lm_loss,
        max_prediction_per_seq=args.max_predictions_per_seq,
    )

    train_total_losses = []
    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_batch_size,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        steps_this_epoch = len(train_data_loader)
        for step in range(steps_this_epoch):
            bert_outputs = pretrain(
                train_data_loader,
                bert_model,
                ns_criterion,
                masked_lm_loss_fn,
                optimizer,
                lr_scheduler,
            )

            # Only rank 0 reports metrics.
            if flow.env.get_rank() == 0:
                report = metric.metric_cb(step, epoch=epoch)
                report(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(
            epoch,
            test_data_loader,
            bert_model,
            args.val_print_every_n_iters,
        )

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, False)
示例#3
0
def main():
    """Fine-tune and/or evaluate a BERT-based SQuAD model through ``nn.Graph``.

    Builds a ``SQuAD`` model, loads parameters from ``args.model_load_dir``
    and moves it to ``device``. When ``args.do_train`` is set it trains for
    ``args.num_epochs`` epochs of ``epoch_size`` steps and optionally saves a
    final snapshot; when ``args.do_eval`` is set it runs ``squad_eval`` for
    ``num_eval_steps`` steps.

    NOTE(review): ``args``, ``device``, ``batch_size``, ``eval_batch_size``,
    ``epoch_size`` and ``num_eval_steps`` are not defined in this function;
    they are presumably module-level names -- confirm against the rest of
    the file.
    """

    hidden_size = 64 * args.num_attention_heads  # H = 64, size per head
    intermediate_size = hidden_size * 4

    print("Create Bert model for SQuAD")
    squad_model = SQuAD(
        args.vocab_size,
        seq_length=args.seq_length,
        hidden_size=hidden_size,
        hidden_layers=args.num_hidden_layers,
        atten_heads=args.num_attention_heads,
        intermediate_size=intermediate_size,
        hidden_act=nn.GELU(),
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        max_position_embeddings=args.max_position_embeddings,
        type_vocab_size=args.type_vocab_size,
        initializer_range=0.02,
    )

    # Load pretrain model from lazy trained model
    load_params_from_lazy(squad_model.state_dict(), args.model_load_dir)

    squad_model.to(device)

    if args.do_train:
        print("Create SQuAD training data decoders")
        train_decoders = SquadDecoder(
            args.train_data_dir, batch_size, args.train_data_part_num, args.seq_length
        )

        optimizer = build_adamW_optimizer(
            squad_model,
            args.learning_rate,
            args.weight_decay_rate,
            weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        )

        # Polynomial schedule over args.iter_num steps, ending at 0.0.
        lr_scheduler = PolynomialLR(
            optimizer, steps=args.iter_num, end_learning_rate=0.0
        )

        # The first `warmup_proportion` of the iterations are linear warm-up.
        warmup_batches = int(args.iter_num * args.warmup_proportion)
        lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
            lr_scheduler,
            warmup_factor=0,
            warmup_iters=warmup_batches,
            warmup_method="linear",
        )

        class SQuADGraph(nn.Graph):
            """Training graph: decode a batch, forward, loss, backward."""

            def __init__(self):
                super().__init__()
                self.squad_model = squad_model
                self.criterion = nn.CrossEntropyLoss()

                self.add_optimizer(optimizer, lr_sch=lr_scheduler)
                self._decoders = train_decoders
                if args.use_fp16:
                    # Mixed precision is paired with a gradient scaler.
                    self.config.enable_amp(True)
                    grad_scaler = flow.amp.GradScaler(
                        init_scale=2 ** 30,
                        growth_factor=2.0,
                        backoff_factor=0.5,
                        growth_interval=2000,
                    )
                    self.set_grad_scaler(grad_scaler)

            def build(self):
                """Run one training step and return the total loss."""
                (
                    input_ids,
                    input_mask,
                    segment_ids,
                    start_positions,
                    end_positions,
                ) = self._decoders()
                input_ids = input_ids.to(device=device)
                input_mask = input_mask.to(device=device)
                segment_ids = segment_ids.to(device=device)
                start_positions = start_positions.to(device=device)
                end_positions = end_positions.to(device=device)

                start_logits, end_logits = self.squad_model(
                    input_ids, segment_ids, input_mask
                )
                # Flatten both logit tensors to [-1, seq_length] so each row
                # is scored against one target position.
                start_logits = flow.reshape(start_logits, [-1, args.seq_length])
                end_logits = flow.reshape(end_logits, [-1, args.seq_length])

                start_loss = self.criterion(start_logits, start_positions.squeeze(1))
                end_loss = self.criterion(end_logits, end_positions.squeeze(1))
                # Total loss is the mean of the start and end losses.
                total_loss = (start_loss + end_loss) * 0.5
                total_loss.backward()

                return total_loss

        squad_graph = SQuADGraph()

        for epoch in range(args.num_epochs):
            squad_model.train()

            # A fresh Metric is created for every epoch.
            metric = Metric(
                desc="train",
                print_steps=args.loss_print_every_n_iter,
                batch_size=batch_size,
                keys=["total_loss"],
            )
            for step in range(epoch_size):
                metric.metric_cb(step, epoch=epoch)(squad_finetune(squad_graph))

        # The snapshot is only written on the training path.
        if args.save_last_snapshot:
            save_model(squad_model, args.model_save_dir, "last_snapshot")

    if args.do_eval:
        assert os.path.isdir(args.eval_data_dir)
        print("Create SQuAD testing data decoders")
        test_decoders = SquadDecoder(
            args.eval_data_dir,
            eval_batch_size,
            args.eval_data_part_num,
            args.seq_length,
            is_train=False,
        )

        squad_model.eval()

        class SQuADEvalGraph(nn.Graph):
            """Evaluation graph: forward only, returns ids and both logits."""

            def __init__(self):
                super().__init__()
                self.squad_model = squad_model
                self._decoders = test_decoders

            def build(self):
                """Run one gradient-free forward pass on an eval batch."""
                (input_ids, input_mask, segment_ids, unique_ids) = self._decoders()
                input_ids = input_ids.to(device=device)
                input_mask = input_mask.to(device=device)
                segment_ids = segment_ids.to(device=device)
                unique_ids = unique_ids.to(device=device)

                with flow.no_grad():
                    start_logits, end_logits = self.squad_model(
                        input_ids, segment_ids, input_mask
                    )

                return unique_ids, start_logits, end_logits

        squad_eval_graph = SQuADEvalGraph()

        squad_eval(num_eval_steps, squad_eval_graph, args.loss_print_every_n_iter)