def build_model_optimizer(self, Optimizer="adam"):
        """Build a sharded model/optimizer pair and a plain reference pair.

        Both networks receive the same randomly generated ``np_fc1`` and
        ``np_fc2`` arrays.

        Args:
            Optimizer: Optimizer selector forwarded to ``self.build_optimizer``.

        Returns:
            Tuple ``(model_a, optimizer_a, model_b, optimizer_b)``. The ``*_a``
            pair is built with ``is_sharding=True`` and wrapped by
            ``fleet.distributed_model`` / ``fleet.distributed_optimizer``;
            the ``*_b`` pair is built with ``is_sharding=False`` and returned
            unwrapped.
        """
        np_fc1 = np.random.random_sample((hidden_size, inner_size))
        np_fc2 = np.random.random_sample((inner_size, hidden_size))

        # Sharded side: the optimizer is built before the model is wrapped.
        model_a = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2)
        optimizer_a = self.build_optimizer(model_a,
                                           strategy=self.strategy,
                                           is_sharding=True,
                                           Optimizer=Optimizer)
        model_a = fleet.distributed_model(model_a)
        optimizer_a = fleet.distributed_optimizer(optimizer_a)

        # Reference side: same arrays, no sharding and no fleet wrappers.
        model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2)
        optimizer_b = self.build_optimizer(model_b,
                                           strategy=self.strategy,
                                           is_sharding=False,
                                           Optimizer=Optimizer)

        return model_a, optimizer_a, model_b, optimizer_b
    def test_pp_model(self):
        """Check that ``eval_batch`` and ``train_batch`` report the same loss.

        Runs five steps on random integer inputs. On every pipeline stage
        other than stage 0, the loss returned by ``train_batch`` must match
        the loss returned by ``eval_batch`` for the same batch.
        """
        hcg = fleet.get_hybrid_communicate_group()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                       values=[0.001, 0.002],
                                                       verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for _ in range(5):
            x_data = np.random.randint(0,
                                       vocab_size,
                                       size=[batch_size, length])
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True

            # Evaluate first, so both losses are computed on the same batch
            # before train_batch is called.
            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            # TODO(shenliang03) add utest for loss
            if pp_id != 0:
                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
    def test_pp_model(self):
        """Check that restoring a checkpoint reproduces the same losses.

        After two warmup steps the layer state and the optimizer state are
        saved into a temporary directory. A fixed set of batches is then
        trained once to record reference losses, the saved state is loaded
        back, and the same batches are trained again; each loss must match
        its reference.

        The temporary directory is removed even when training or an
        assertion raises.
        """
        hcg = fleet.get_hybrid_communicate_group()
        dp_id = hcg.get_data_parallel_rank()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        output_dir = tempfile.mkdtemp()
        opt_path = os.path.join(output_dir, "model_state.pdopt")
        try:
            # warmup step
            for _ in range(2):
                x_data = np.random.randint(0,
                                           vocab_size,
                                           size=[batch_size, length])
                x = paddle.to_tensor(x_data)
                x.stop_gradient = True
                model.train_batch([x, x], optimizer, scheduler)

            model._layers.save_state_dict(output_dir)
            paddle.save(optimizer.state_dict(), opt_path)

            # construct data; the same batches are replayed after restoring
            test_steps = 5
            np_data = np.random.randint(
                0, vocab_size, size=[test_steps, batch_size, length])

            origin_loss = []
            for step_id in range(test_steps):
                x = paddle.to_tensor(np_data[step_id, :])
                x.stop_gradient = True
                loss = model.train_batch([x, x], optimizer, scheduler)
                origin_loss.append(loss.numpy())

            # test step: load the saved state and replay the batches
            model._layers.set_state_dir(output_dir)
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)

            for step_id in range(test_steps):
                x = paddle.to_tensor(np_data[step_id, :])
                x.stop_gradient = True
                loss = model.train_batch([x, x], optimizer, scheduler)
                print("origin loss: ", origin_loss[step_id], "current loss: ",
                      loss.numpy())
                np.testing.assert_allclose(loss.numpy(), origin_loss[step_id])
        finally:
            # remove the model/optimizer path on success and on failure
            shutil.rmtree(output_dir)
    def test_pp_model(self):
        """Compare a plain AlexNet with its pipeline-parallel description.

        Model A's parameter values are copied into model B, then both are
        trained on the first five MNIST batches and their losses must agree
        within ``rtol=5e-5``.

        Returns:
            True once five batches have been compared; None if the reader
            yields five batches or fewer.
        """
        hcg = fleet.get_hybrid_communicate_group()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        # construct model a
        model_a = AlexNet(10)
        scheduler_a, optimizer_a = self.build_optimizer(model_a)

        parameters = [param.numpy() for param in model_a.parameters()]
        param_len = len(parameters)

        # construct model b
        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        scheduler_b, optimizer_b = self.build_optimizer(model_b)
        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)

        # Each stage copies its own slice of model A's parameters.
        # NOTE(review): the half-length offset assumes the parameters are
        # split evenly over two pipeline stages -- confirm.
        for idx, param in enumerate(model_b.parameters()):
            param.set_value(parameters[idx + pp_id * (param_len // 2)])

        # construct reader
        train_reader = paddle.batch(paddle.dataset.mnist.train(),
                                    batch_size=batch_size,
                                    drop_last=True)

        for step_id, data in enumerate(train_reader()):
            # Stop before building tensors for a batch that is never used.
            if step_id >= 5:
                return True

            x_data = np.array([x[0] for x in data]).astype('float32').reshape(
                batch_size, 1, 28, 28)
            y_data = np.array([x[1] for x in data
                               ]).astype('int64').reshape(batch_size, 1)
            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)
            img.stop_gradient = True
            label.stop_gradient = True

            loss_a = model_a(img, label)
            loss_a.backward()
            optimizer_a.step()
            optimizer_a.clear_grad()
            scheduler_a.step()

            loss_b = model_b.train_batch([img, label], optimizer_b,
                                         scheduler_b)

            print("loss: ", loss_a.numpy(), loss_b.numpy())
            np.testing.assert_allclose(loss_a.numpy(),
                                       loss_b.numpy(),
                                       rtol=5e-5)
# Example 5
def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload):
    """Wrap model, optimizer and scaler with ``group_sharded_parallel``.

    The sharding level is chosen from the module-level ``args``:
    ``"p_g_os"`` when ``args.sharding_stage`` is 3, ``"os_g"`` otherwise.

    Args:
        model: Model passed through as ``model``.
        optimizer: Optimizer passed through as ``optimizer``.
        scaler: Scaler passed through as ``scaler``.
        sharding_offload: Value passed through as ``offload``.

    Returns:
        Whatever ``group_sharded_parallel`` returns.
    """
    hcg = fleet.get_hybrid_communicate_group()
    sharding_group = hcg.get_sharding_parallel_group()

    if args.sharding_stage == 3:
        sharding_level = "p_g_os"
    else:
        sharding_level = "os_g"

    return group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level=sharding_level,
        scaler=scaler,
        group=sharding_group,
        offload=sharding_offload,
    )
 def setUp(self):
     """Initialise fleet with a two-stage pipeline-parallel strategy.

     Stores the pipeline degree in ``self.pipeline_parallel_size`` and the
     resulting hybrid communicate group in ``self.hcg``.
     """
     hybrid_strategy = fleet.DistributedStrategy()
     self.pipeline_parallel_size = 2

     # Data and model parallelism are fixed to degree 1.
     degrees = {}
     degrees["dp_degree"] = 1
     degrees["mp_degree"] = 1
     degrees["pp_degree"] = self.pipeline_parallel_size
     hybrid_strategy.hybrid_configs = degrees

     fleet.init(is_collective=True, strategy=hybrid_strategy)
     self.hcg = fleet.get_hybrid_communicate_group()
    def test_row_parallel_layer(self):
        """Compare ``RowLinearNet`` against a ``SimpleMatmul`` reference.

        The local weight of ``model_a.parallel_linear`` is gathered from
        every rank in ``range(self.model_parallel_size)`` and concatenated
        along axis 0 to build the reference model. Both models are then
        trained for five SGD steps on the same random input, and their mean
        losses must agree within ``rtol=5e-6``.
        """
        dtype_name = "float32"
        paddle.set_default_dtype(dtype_name)
        set_random_seed(1024)

        self.hcg = fleet.get_hybrid_communicate_group()
        self.word_size = self.hcg.get_model_parallel_world_size()
        self.rank_id = self.hcg.get_model_parallel_rank()

        # Feature sizes scale with the model-parallel degree.
        mp_degree = self.model_parallel_size
        in_features = 11 * mp_degree
        out_features = 10 * mp_degree
        num_rows = 4

        model_a = RowLinearNet(in_features, out_features)

        # Gather the weight from all ranks to build the full matrix.
        check_group = dist.new_group(list(range(mp_degree)))
        weight_shards = []
        local_shard = model_a.parallel_linear.weight.clone().detach()
        paddle.distributed.all_gather(weight_shards,
                                      local_shard,
                                      group=check_group)
        full_weight = paddle.concat(weight_shards, axis=0)

        model_b = SimpleMatmul(full_weight, out_features, dtype_name)

        optimizer_a = paddle.optimizer.SGD(learning_rate=0.001,
                                           parameters=model_a.parameters())
        optimizer_b = paddle.optimizer.SGD(learning_rate=0.001,
                                           parameters=model_b.parameters())

        for _ in range(5):
            batch_input = paddle.randn([num_rows, in_features], dtype_name)
            batch_input.stop_gradient = True

            loss_a = model_a(batch_input).mean()
            loss_a.backward()

            loss_b = model_b(batch_input).mean()
            loss_b.backward()

            optimizer_a.step()
            optimizer_b.step()

            np.testing.assert_allclose(loss_a.numpy(),
                                       loss_b.numpy(),
                                       rtol=5e-6)
 def build_optimizer(self,
                     model,
                     strategy=None,
                     is_sharding=True,
                     Optimizer="adam"):
     """Create an optimizer for ``model``, optionally sharded.

     Args:
         model: Object whose ``parameters()`` are optimised.
         strategy: Passed to ``DygraphShardingOptimizer`` as
             ``user_defined_strategy``; unused when ``is_sharding`` is False.
         is_sharding: When True, the inner optimizer class is wrapped in a
             ``DygraphShardingOptimizer``; otherwise it is instantiated
             directly.
         Optimizer: ``"adam"`` selects ``paddle.optimizer.AdamW`` with a
             weight decay of 1e-5; any other value selects
             ``paddle.optimizer.Momentum``.

     Returns:
         The constructed optimizer, using learning rate 0.001 and global-norm
         gradient clipping at 0.5.
     """
     clip = paddle.nn.ClipGradByGlobalNorm(0.5)

     # Pick the inner optimizer class and the keyword arguments only it needs.
     if Optimizer == "adam":
         inner_class = paddle.optimizer.AdamW
         extra_kwargs = {"weight_decay": 0.00001}
     else:
         inner_class = paddle.optimizer.Momentum
         extra_kwargs = {}

     if not is_sharding:
         return inner_class(parameters=model.parameters(),
                            learning_rate=0.001,
                            grad_clip=clip,
                            **extra_kwargs)

     return DygraphShardingOptimizer(
         hcg=fleet.get_hybrid_communicate_group(),
         user_defined_strategy=strategy,
         params=model.parameters(),
         inner_optimizer_class=inner_class,
         learning_rate=0.001,
         grad_clip=clip,
         **extra_kwargs)
# Example 9
def model_parallel_random_seed(seed=None):
    """Seed the global RNG and the per-rank model-parallel RNG tracker.

    Args:
        seed: Base seed. When given (including ``0``), the global seed is
            ``seed`` and the local seed is ``seed * 1024 + rank * 100``.
            When ``None``, both seeds are drawn from ``np.random``; the local
            seed is drawn from a range derived from the model-parallel rank.

    Side effects:
        Resets ``RNG_STATE_TRACKER``, registers ``MODEL_PARALLEL_RNG`` with
        the local seed, and calls ``paddle.seed`` with the global seed.
    """
    import paddle.distributed.fleet as fleet
    hcg = fleet.get_hybrid_communicate_group()
    rank = hcg.get_model_parallel_rank()

    # Compare against None rather than truthiness: 0 is a legitimate seed
    # and must not silently fall back to random seeding.
    if seed is not None:
        global_seed = seed
        local_seed = seed * 1024 + rank * 100
    else:
        global_seed = np.random.randint(0, 655350)
        local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)

    RNG_STATE_TRACKER.reset()
    RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
    paddle.seed(global_seed)
# Example 10
    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        """Return the mask-weighted mean of the language-model loss.

        Uses ``self.parallel_loss_func`` when the model-parallel world size
        is greater than 1 and ``self.loss_func`` otherwise. The element-wise
        loss is flattened, multiplied by the flattened ``loss_mask``, summed,
        and divided by the sum of the mask.

        Args:
            prediction_scores: First argument to the loss function.
            masked_lm_labels: Labels; ``unsqueeze(2)`` is applied before they
                are passed to the loss function.
            loss_mask: Weights applied to the flattened loss.

        Returns:
            The masked loss sum divided by ``loss_mask.sum()``.
        """
        mp_degree = (
            fleet.get_hybrid_communicate_group().get_model_parallel_world_size())

        # Only the loss function differs between the two modes.
        loss_fn = self.parallel_loss_func if mp_degree > 1 else self.loss_func
        per_token_loss = loss_fn(prediction_scores,
                                 masked_lm_labels.unsqueeze(2))

        flat_mask = loss_mask.reshape([-1])
        masked_total = paddle.sum(per_token_loss.reshape([-1]) * flat_mask)
        return masked_total / flat_mask.sum()
# Example 11
def parallel_matmul(lm_output, logit_weights, parallel_output):
    """Multiply ``lm_output`` by the transpose of ``logit_weights``.

    Args:
        lm_output: Left operand of the matmul.
        logit_weights: Right operand; multiplied with ``transpose_y=True``.
        parallel_output: Only consulted when the model-parallel world size
            is greater than 1. If truthy, the local logits are returned
            as-is; otherwise they are passed through ``_c_concat`` over the
            model-parallel group.

    Returns:
        The logits tensor.
    """
    hcg = fleet.get_hybrid_communicate_group()
    model_parallel_group = hcg.get_model_parallel_group()
    world_size = hcg.get_model_parallel_world_size()

    # Without model parallelism this is a plain matmul.
    if world_size <= 1:
        return paddle.matmul(lm_output, logit_weights, transpose_y=True)

    input_parallel = paddle.distributed.collective._c_identity(
        lm_output, group=model_parallel_group)
    logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)

    if parallel_output:
        return logits

    return paddle.distributed.collective._c_concat(
        logits, group=model_parallel_group)
    def build_model_optimizer(self):
        """Build a model-parallel model/optimizer pair and a reference pair.

        Seeds the RNG from the data-parallel rank and the global rank, then
        creates ``SimpleMPNet`` and ``SimpleDPNet`` from the same ``np_fc1``
        and ``np_fc2`` arrays.

        Returns:
            Tuple ``(model_a, optimizer_a, model_b, optimizer_b)``. The ``*_a``
            pair is the ``SimpleMPNet`` wrapped by ``fleet.distributed_model``
            / ``fleet.distributed_optimizer``; the ``*_b`` pair is the
            unwrapped ``SimpleDPNet``.
        """
        hcg = fleet.get_hybrid_communicate_group()
        mp_id = hcg.get_model_parallel_rank()
        dp_id = hcg.get_data_parallel_rank()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        np_fc1 = np.random.random_sample((hidden_size, inner_size))
        np_fc2 = np.random.random_sample((inner_size, hidden_size))

        # The model-parallel network additionally receives this rank's mp id.
        model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2, mp_id)
        optimizer_a = self.build_optimizer(model_a)
        model_a = fleet.distributed_model(model_a)
        optimizer_a = fleet.distributed_optimizer(optimizer_a)

        model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
                              np_fc1, np_fc2)
        optimizer_b = self.build_optimizer(model_b)

        return model_a, optimizer_a, model_b, optimizer_b
# Example 13
def do_train(args):
    """Run hybrid-parallel GPT pretraining with MoE support.

    The function initialises fleet with the dp/mp/pp degrees from ``args``,
    builds the model, criterion, learning-rate scheduler and AdamW optimizer,
    optionally resumes from ``args.resume_dir``, and then iterates over the
    data files in ``args.input_dir``. Each global step accumulates gradients
    over ``local_batch_size // micro_batch_size`` micro-batches, all-reduces
    the gate and data-parallel gradients, and applies the optimizer.

    Logging happens every ``args.logging_freq`` steps, checkpoints are saved
    every ``args.save_steps`` steps and at ``args.max_steps``, validation
    runs every ``args.eval_freq`` steps, and the function returns after the
    test evaluation once ``args.max_steps`` is reached.

    Args:
        args: Parsed command-line namespace. ``args.resume_dir`` may be
            ``None`` or an empty string to mean "do not resume"; it is
            normalised to ``None`` in place. ``args.decay_steps`` is set to
            ``args.max_steps`` in place when it is ``None``.
    """
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # obtain rank message of hybrid parallel
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer; an existing log directory for this run is wiped.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        # Known configuration name: build the model from its config.
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree

        # MOE config
        initialize_model_and_expert_group(hcg)

        model_config['expert_mode'] = args.expert_mode
        model_config['hcg'] = hcg
        model_config['num_experts'] = args.num_experts
        model_config['top_k'] = args.top_k
        if args.expert_mode:
            model_config['gate'] = args.gate

        if args.pp_degree == 1:
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model_config["recompute_partition"] = args.recompute_partition
            model_config["recompute_offload"] = args.recompute_offload
            if args.use_recompute and args.recompute_partition:
                raise Exception(
                    "when use_recompute is True, recompute_partition must be False in MoE."
                )

            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        # Otherwise treat model_name_or_path as something from_pretrained
        # can load.
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    # The model is decorated for pure fp16 before the parameters are
    # classified below; the scaler built here is replaced further down,
    # after any checkpoint has been loaded.
    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model = paddle.amp.decorate(models=model,
                                    optimizers=None,
                                    level='O2',
                                    save_dtype='float32')

    # Generate parameter names needed to perform weight decay.
    opt_fused_tensors, decay_fused_tensors, reduce_fused_tensors, gate_fused_tensors, \
        expert_fusion_names = parameters_classify(model)
    decay_params = [p.name for p in decay_fused_tensors]

    clip = None
    if args.grad_clip > 0:
        is_expert_param_fun = lambda param: param.name in expert_fusion_names
        clip = moe.ClipGradByGlobalNorm(clip_norm=args.grad_clip, \
                                        is_expert_param_func = is_expert_param_fun, \
                                        moe_group = hcg.get_expert_parallel_group())

    optimizer = AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=opt_fused_tensors,
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params,  #decay_params,
        multi_precision=args.use_pure_fp16)

    # Normalise "no resume directory" (None or an empty string) to None
    # before the first `is None` check, so that a fresh run started with an
    # empty resume_dir still initialises its parameters, and so that a None
    # value does not raise.
    args.resume_dir = args.resume_dir or None

    if paddle.distributed.get_world_size() > 1 and args.resume_dir is None:
        print(">> initialize....")
        initialize_mp_dp_parameters(model, hcg)

    #in order to restore reader.
    pass_num = 0
    file_id = 0
    start_epoch = 0

    if args.resume_dir is not None:
        global_step, loss_scale, data_meta = load_checkpoint(
            args, model, optimizer, lr_scheduler, tokenizer, dp_rank, mp_rank,
            pp_rank)
        pass_num = data_meta["pass_num"]
        file_id = data_meta["file_id"]
        start_epoch = data_meta["start_epoch"]

    # Rebuild the scaler, starting from the restored loss scale when
    # resuming, and decorate the model together with the optimizer.
    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=loss_scale if args.
            resume_dir is not None else args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model, optimizer = paddle.amp.decorate(models=model,
                                               optimizers=optimizer,
                                               level='O2',
                                               save_dtype='float32')

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0 if args.resume_dir is None else global_step
    timers = get_timers()
    tic_train = time.time()
    for epoch in range(start_epoch, args.num_train_epochs):
        # Every regular file in input_dir whose name lacks "npz_" is a
        # data file; they are visited in sorted order.
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(file_id, num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)

            # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
            # many times. and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                # skip the training data that has already been consumed.
                if step < global_step - pass_num: continue

                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                # Accumulate gradients over the micro-batches of this batch.
                loss = 0.0
                for i in range(accumulate_steps):
                    start_index = i * args.micro_batch_size
                    end_index = start_index + args.micro_batch_size
                    timers('forward-compute').start()
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum",
                                "c_softmax_with_cross_entropy",
                                "elementwise_div",
                            ],
                            level='O2'):
                        preds = model(tokens[start_index:end_index, :])
                        loss_mbs = criterion(
                            preds, labels[start_index:end_index, :],
                            loss_mask[start_index:end_index, :])
                    timers('forward-compute').stop()

                    # Add the weighted gate balance loss for non-naive gates.
                    if args.gate != "naive" and args.balance_loss_weight:
                        aux_loss_list = [
                            l.moe_mlp.gate.get_loss(clear=False)
                            for l in model.gpt.decoder.layers
                            if hasattr(l.moe_mlp, "gate")
                        ]
                        bal_loss = paddle.concat(aux_loss_list)
                        if bal_loss.dtype == paddle.float16:
                            bal_loss = paddle.cast(bal_loss,
                                                   dtype=paddle.float32)
                        bal_loss = bal_loss.mean()
                        loss_mbs += bal_loss * args.balance_loss_weight
                    loss_mbs = loss_mbs / accumulate_steps

                    timers('backward-compute').start()
                    if args.use_pure_fp16:
                        scaler.scale(loss_mbs).backward()
                    else:
                        loss_mbs.backward()
                    timers('backward-compute').stop()
                    loss = loss + loss_mbs

                timers('backward-params-all-reduce').start()
                all_reduce_parameters(gate_fused_tensors,
                                      hcg.get_expert_parallel_group())
                all_reduce_parameters(reduce_fused_tensors,
                                      hcg.get_data_parallel_group())
                timers('backward-params-all-reduce').stop()

                if args.use_pure_fp16:
                    scaler.minimize(optimizer, loss)
                else:
                    optimizer.step()
                # Read the rate before the scheduler advances so the logged
                # value is the one used for this step.
                learning_rate = optimizer.get_lr()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    if args.gate != "naive" and args.balance_loss_weight:
                        bal_loss = bal_loss.numpy()
                        avg_loss -= bal_loss
                    else:
                        bal_loss = -1
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, bal_loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, bal_loss, speed,
                           speed * default_global_tokens_num, learning_rate))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", learning_rate,
                                          global_step)

                    tic_train = time.time()
                    timer_log(args.logging_freq)

                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps):
                    loss_scale = scaler._scale if args.use_pure_fp16 else None
                    save_checkpoint(args, global_step, model, optimizer,
                                    lr_scheduler, tokenizer, loss_scale,
                                    dp_rank, mp_rank, pp_rank, pass_num,
                                    file_id, epoch)
                    print(
                        "save checkpoint for step_{} successfully...loss_scale = {}"
                        .format(global_step, loss_scale))

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            # to record sum of the length of train_data_loader that has been read.
            pass_num += len(train_data_loader())
            del train_data_loader
    def test_pp_model(self):
        """Compare AMP training of AlexNet with its pipeline-parallel twin.

        Model A's parameter values are copied into model B. Both models are
        then trained under ``paddle.amp.auto_cast`` with their own
        ``GradScaler`` on the first five MNIST batches, and their losses must
        agree within ``rtol=5e-5``.

        Returns:
            True once five batches have been compared; None if the reader
            yields five batches or fewer.
        """
        hcg = fleet.get_hybrid_communicate_group()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)

        # construct model a
        model_a = AlexNet(10)
        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                         values=[0.001, 0.002],
                                                         verbose=True)
        optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
                                           grad_clip=grad_clip,
                                           parameters=model_a.parameters())

        scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)

        parameters = [param.numpy() for param in model_a.parameters()]
        param_len = len(parameters)

        # construct model b
        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
                                                         values=[0.001, 0.002],
                                                         verbose=True)
        optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
                                           grad_clip=grad_clip,
                                           parameters=model_b.parameters())
        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)
        scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
        scaler_b = fleet.distributed_scaler(scaler_b)

        # Each stage copies its own slice of model A's parameters.
        # NOTE(review): the half-length offset assumes the parameters are
        # split evenly over two pipeline stages -- confirm.
        for idx, param in enumerate(model_b.parameters()):
            param.set_value(parameters[idx + pp_id * (param_len // 2)])

        # construct reader
        train_reader = paddle.batch(paddle.dataset.mnist.train(),
                                    batch_size=batch_size,
                                    drop_last=True)

        for step_id, data in enumerate(train_reader()):
            # Stop before building tensors for a batch that is never used.
            if step_id >= 5:
                return True

            x_data = np.array([x[0] for x in data]).astype('float32').reshape(
                batch_size, 1, 28, 28)
            y_data = np.array([x[1] for x in data
                               ]).astype('int64').reshape(batch_size, 1)
            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)
            img.stop_gradient = True
            label.stop_gradient = True

            with paddle.amp.auto_cast():
                loss_a = model_a(img, label)
                scaler_a.scale(loss_a).backward()
                scaler_a.minimize(optimizer_a, loss_a)
                optimizer_a.clear_grad()
                scheduler_a.step()

            with paddle.amp.auto_cast():
                loss_b = model_b.train_batch([img, label],
                                             optimizer_b,
                                             scheduler_b,
                                             scaler=scaler_b)

            print("loss: ", loss_a.numpy(), loss_b.numpy())
            np.testing.assert_allclose(loss_a.numpy(),
                                       loss_b.numpy(),
                                       rtol=5e-5)
# Example 15
def do_train(args):
    """Pretrain a GPT model with fleet hybrid parallelism (dp / mp / pp).

    The function initialises fleet with the degrees found on ``args``, builds
    the model (plain ``GPTForPretraining`` when ``args.pp_degree == 1``,
    ``GPTForPretrainingPipe`` otherwise), an AdamW optimizer with optional
    cosine-with-warmup schedule and global-norm clipping, and then iterates
    over every data file in ``args.input_dir`` for ``args.num_train_epochs``
    epochs.  It logs, evaluates and checkpoints at the frequencies configured
    on ``args`` and returns as soon as ``global_step`` reaches
    ``args.max_steps``.

    Args:
        args: Namespace-like object carrying the run configuration.  The
            attributes read here include ``device``, ``dp_degree``,
            ``mp_degree``, ``pp_degree``, ``local_batch_size``,
            ``micro_batch_size``, ``global_batch_size``, ``max_seq_len``,
            ``seed``, ``model_type``, ``model_name_or_path``, ``output_dir``,
            ``input_dir``, ``use_amp``, ``use_recompute``, ``scale_loss``,
            the optimizer / schedule hyper-parameters, and the
            ``logging_freq`` / ``eval_freq`` / ``save_steps`` / ``max_steps``
            / ``check_accuracy`` controls.  ``args.decay_steps`` is filled in
            with ``args.max_steps`` when it is ``None``.

    Returns:
        None.
    """
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

    strategy.pipeline_configs = {
        "accumulate_steps": args.local_batch_size // args.micro_batch_size,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain this process's rank along each hybrid-parallel axis.
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Seed control in hybrid parallel.
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    # Tokens consumed per global step; used only for the throughput log line.
    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define the log writer; one directory per global rank.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_amp,
            False, global_rank).lower())

    # Start from an empty log directory so stale scalars are not mixed in.
    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        # Known configuration name: build the model from its config.
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        # Otherwise treat model_name_or_path as something from_pretrained
        # can load directly.
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model.
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    # Any decay style other than "cosine" (including "none") leaves the
    # scheduler unset, in which case the optimizer uses a constant max_lr.
    lr_scheduler = None
    if args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    # A set keeps the per-parameter membership test in the lambda O(1).
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params)

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)

    if args.model_name_or_path not in pretrained_models_list:
        # Resuming from a directory: restore the optimizer state if present.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        # Every regular file in input_dir whose name does not contain "npz_".
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        for data_file in files:
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)
            # Instantiate the valid/test loaders once here. Otherwise each
            # enumerate over them would call the factory again and start a
            # new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                if args.pp_degree == 1:
                    # No pipeline: explicit forward / backward / step.
                    with paddle.amp.auto_cast(
                            args.use_amp,
                            custom_white_list=[
                                "layer_norm", "softmax", "gelu"
                            ],
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "c_embedding"
                            ]):
                        preds = model(tokens)
                        loss = criterion(preds, labels, loss_mask)

                    if args.use_amp:
                        scaler.scale(loss).backward()
                        scaler.minimize(optimizer, loss)
                    else:
                        loss.backward()
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()
                    optimizer.clear_grad()

                else:
                    # Pipeline parallel: train_batch receives the optimizer,
                    # scheduler and scaler and returns the loss.
                    data = [tokens, (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_amp,
                            custom_white_list=[
                                "layer_norm", "softmax", "gelu"
                            ],
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "c_embedding"
                            ]):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_amp else None)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, speed, speed *
                           default_global_tokens_num, optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", optimizer.get_lr(),
                                          global_step)

                    tic_train = time.time()

                # In check_accuracy mode evaluation and checkpointing are
                # skipped entirely; training just stops at max_steps.
                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we
                    # evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # Only dp_rank == 0 saves the model.
                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps) and dp_rank == 0:

                    # Under fleet.distributed_model the user model is held
                    # in the wrapper's _layers attribute.
                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        model_to_save.save_state_dict(output_dir)
                        # The tokenizer is identical on every rank, so only
                        # the first mp/pp rank writes it. Rank 0 always
                        # exists, whereas a rank-1 check would never fire
                        # when mp_degree == 1.
                        if mp_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, pp_rank)))
                    else:
                        # One sub-directory per model-parallel rank.
                        path = os.path.join(output_dir,
                                            'model_{:0>2d}'.format(mp_rank))
                        os.makedirs(path, exist_ok=True)
                        model_to_save.save_pretrained(path)

                        paddle.save(optimizer.state_dict(),
                                    os.path.join(path, "model_state.pdopt"))
                        tokenizer.save_pretrained(path)

                if global_step >= args.max_steps:
                    # Final evaluation on the test split, then stop.
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            del train_data_loader
Esempio n. 16
0
def train_mlp(model,
              sharding_stage,
              batch_size=100,
              use_pure_fp16=False,
              accumulate_grad=False,
              opt_group=False,
              save_model=False):
    """Train ``model`` for the module-level ``epoch`` count and return it.

    Args:
        model: Layer to train; it is wrapped either by ``ShardingStage2`` or
            by ``fleet.distributed_model`` before training.
        sharding_stage: ``"dp"`` selects the hybrid group's check-parallel
            group, any other value creates a new group over ranks 0 and 1.
            The value ``2`` additionally selects the stage-2 sharding
            wrappers; everything else goes through fleet.
        batch_size: Batch size handed to ``paddle.batch``.  A value of 20
            also divides the loss by 5 before the backward pass.
        use_pure_fp16: Forwarded to ``optimizer_setting``.
        accumulate_grad: When true the optimizer steps once per epoch instead
            of once per batch.
        opt_group: Forwarded to ``optimizer_setting`` only when truthy.
        save_model: Selects the return value (see below).

    Returns:
        ``(model, optimizer)`` when ``save_model`` is true, otherwise
        ``model.parameters()``.
    """
    # Pick the communication group first.
    if sharding_stage == "dp":
        group = fleet.get_hybrid_communicate_group().get_check_parallel_group()
    else:
        group = paddle.distributed.new_group([0, 1])

    # opt_group is only passed through when it is truthy, so the callee's own
    # default applies otherwise.
    optimizer_kwargs = {"model": model, "use_pure_fp16": use_pure_fp16}
    if opt_group:
        optimizer_kwargs["opt_group"] = opt_group
    optimizer = optimizer_setting(**optimizer_kwargs)

    use_stage2 = sharding_stage == 2
    if use_stage2:
        optimizer = ShardingOptimizerStage2(
            params=model.parameters(), optim=optimizer, group=group)
        model = ShardingStage2(
            model, optimizer, group=group, buffer_max_size=2**21)
    else:
        optimizer = fleet.distributed_optimizer(optimizer)
        model = fleet.distributed_model(model)

    # Data pipeline: batched sample reader fed through a generator loader.
    batched_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True)
    train_loader = paddle.io.DataLoader.from_generator(
        capacity=32,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True)
    train_loader.set_sample_list_generator(batched_reader)

    if use_stage2:
        model.to(device="gpu")

    for _ in range(epoch):
        model.train()

        for img, label in train_loader():
            label.stop_gradient = True
            img.stop_gradient = True

            logits = model(img)
            per_sample_loss = paddle.nn.functional.cross_entropy(
                input=logits, label=label)

            avg_loss = paddle.mean(
                x=per_sample_loss.cast(dtype=paddle.float32))
            if batch_size == 20:
                avg_loss = avg_loss / 5
            avg_loss.backward()

            # Per-batch update unless gradients are being accumulated.
            if not accumulate_grad:
                optimizer.step()
                optimizer.clear_grad()

        # Accumulation mode: a single update at the end of each epoch.
        if accumulate_grad:
            optimizer.step()
            optimizer.clear_grad()

    return (model, optimizer) if save_model else model.parameters()
Esempio n. 17
0
def do_train(args):
    """Pretrain a GPT model with fleet hybrid parallelism and sharding.

    The function initialises fleet with the dp / mp / pp / sharding degrees
    found on ``args``, builds the model (``GPTForPretraining`` when
    ``args.pp_degree == 1``, ``GPTForPretrainingPipe`` otherwise) and an AdamW
    optimizer (wrapped in ``DygraphShardingOptimizer`` for sharding stage 1,
    or passed through ``wrap_sharding_2_3`` for stages 2 and 3).  It then
    iterates over the files returned by ``get_train_data_file(args)`` for
    ``args.num_train_epochs`` epochs, logging, evaluating and checkpointing at
    the frequencies configured on ``args``, and returns as soon as
    ``global_step`` reaches ``args.max_steps``.

    Args:
        args: Namespace-like object carrying the run configuration.  The
            attributes read here include ``device``, ``dp_degree``,
            ``mp_degree``, ``pp_degree``, ``sharding_degree``,
            ``sharding_stage``, ``sharding_offload``, ``local_batch_size``,
            ``micro_batch_size``, ``global_batch_size``, ``max_seq_len``,
            ``seed``, ``model_type``, ``model_name_or_path``, ``output_dir``,
            ``use_pure_fp16``, ``use_recompute``, ``scale_loss``,
            ``profiler_options``, the optimizer / schedule hyper-parameters,
            and the ``logging_freq`` / ``eval_freq`` / ``save_steps`` /
            ``max_steps`` / ``check_accuracy`` controls.  ``args.decay_steps``
            is filled in with ``args.max_steps`` when it is ``None``.

    Returns:
        None.

    Raises:
        AssertionError: If sharding stage 2 or 3 is requested together with
            any of dp / mp / pp degree different from 1.
    """
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree,
        "sharding_degree": args.sharding_degree
    }

    # Number of micro batches processed per optimizer step.
    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    # Seed control in tensor parallel.
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain this process's rank along each hybrid-parallel axis.
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()

    # Sharding stage 2/3 does not support hybrid parallel.
    if args.sharding_stage in [2, 3]:
        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"

    # Data is split across both the dp and the sharding axes, so the data
    # rank / world size combine the two.
    sharding_size = hcg.get_sharding_parallel_world_size()
    data_world_rank = dp_rank * sharding_size + sharding_rank
    data_world_size = args.dp_degree * args.sharding_degree
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Seed control in hybrid parallel.
    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

    # Tokens consumed per global step; used only for the throughput log line.
    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define the log writer; one directory per global rank.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_pure_fp16,
            False, global_rank).lower())

    # Start from an empty log directory so stale scalars are not mixed in.
    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        # Known configuration name: build the model from its config.
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree
        model_config['use_recompute'] = args.use_recompute
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model = GPTForPretrainingPipe(**model_config)
    else:
        # Otherwise treat model_name_or_path as something from_pretrained
        # can load directly.
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model.
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    # A set keeps the per-parameter membership test in the lambda O(1).
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }

    if args.sharding_stage == 1 and args.sharding_degree > 1:
        optimizer = DygraphShardingOptimizer(
            hcg=fleet.get_hybrid_communicate_group(),
            user_defined_strategy=strategy,
            params=model.parameters(),
            inner_optimizer_class=paddle.optimizer.AdamW,
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params)
    else:
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params,
            # TODO: remove 'multi_precision' in definition of optimizer
            # and add it to 'paddle.amp.decorate'
            multi_precision=args.use_pure_fp16)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        # Stage 2/3 scalers are wrapped by wrap_sharding_2_3 below instead.
        if args.sharding_stage not in [2, 3]:
            scaler = fleet.distributed_scaler(scaler)
        # Level O2 means converting the network to FP16.
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')

    # Wrap sharding stage 2/3 and add the collective group.
    # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
    if args.sharding_stage in [2, 3]:
        scaler = scaler if args.use_pure_fp16 else None
        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                     args.sharding_offload)

    elif paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.model_name_or_path not in pretrained_models_list:
        # Resuming from a directory: restore the optimizer state if present.
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    for epoch in range(args.num_train_epochs):
        files = get_train_data_file(args)
        files.sort()
        for data_file in files:
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=data_world_size,
                data_world_rank=data_world_rank,
                eos_id=tokenizer.eos_token_id)
            # Instantiate the valid/test loaders once here. Otherwise each
            # enumerate over them would call the factory again and start a
            # new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # Time accounting, accumulated between two log lines.
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True
                position_ids.stop_gradient = True

                if args.pp_degree == 1:
                    # In ParallelMode of DataParallel, 'no_sync' can be used for improving
                    # performance of model by gradient accumulation.
                    # The local batch is sliced into micro batches; each one
                    # contributes loss / accumulate_steps to the gradients.
                    loss = 0.0
                    for i in range(accumulate_steps):
                        start_index = i * args.micro_batch_size
                        end_index = start_index + args.micro_batch_size
                        with paddle.amp.auto_cast(
                                args.use_pure_fp16,
                                custom_black_list=[
                                    "reduce_sum",
                                    "c_softmax_with_cross_entropy",
                                    "elementwise_div"
                                ],
                                level='O2'):
                            preds = model(
                                tokens[start_index:end_index, :],
                                position_ids[start_index:end_index, :])
                            loss_mbs = criterion(
                                preds, labels[start_index:end_index, :],
                                loss_mask[start_index:end_index, :])
                        loss_mbs = loss_mbs / accumulate_steps
                        if args.use_pure_fp16:
                            scaler.scale(loss_mbs).backward()
                        else:
                            loss_mbs.backward()
                        loss = loss + loss_mbs

                    if args.use_pure_fp16:
                        if args.sharding_stage in [2, 3]:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            scaler.minimize(optimizer, loss)
                    else:
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()

                else:
                    # Pipeline parallel: train_batch receives the optimizer,
                    # scheduler and scaler and returns the loss.
                    data = [(tokens, position_ids), (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "elementwise_div"
                            ],
                            level='O2'):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_pure_fp16 else None)

                # Sync for profile time, delete it may be a little faster
                paddle.device.cuda.synchronize()
                train_run_cost += time.time() - train_start
                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq

                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, avg_reader_cost,
                           1. / speed, speed, speed * default_global_tokens_num,
                           speed * default_global_tokens_num / nranks,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    # Restart the accumulation window for the next log line.
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                # In check_accuracy mode evaluation and checkpointing are
                # skipped entirely; training just stops at max_steps.
                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    # `continue` bypasses the timer reset at the bottom of the
                    # loop body, so restart the reader timer here as well;
                    # otherwise the next reader cost would include this
                    # step's run time.
                    reader_start = time.time()
                    continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we
                    # evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # TODO: 1. merge parameters while saving model. 2. ensure that the model is saved and loaded correctly
                # Only dp_rank == 0 saves the model.
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:

                    # Under fleet.distributed_model the user model is held in
                    # the wrapper's _layers attribute; the stage 2/3 wrapper
                    # is saved as-is.
                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 and args.sharding_stage not in [2, 3] else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        # The tokenizer is written once, by the first rank of
                        # every non-dp axis.
                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank, pp_rank)))
                    else:
                        if args.sharding_stage == 3:
                            # If parameter need to convert to cpu, please add convert2cpu=True
                            model_to_save.get_all_parameters(convert2cpu=False)
                        if mp_rank == 0 and sharding_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    # Final evaluation on the test split, then stop.
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader
    def test_pp_model(self):
        """Check that the pipelined net yields the same loss as the plain net.

        A ``SimpleNet`` and a ``SimpleNetPipe`` are built with identical SGD /
        piecewise-decay settings, selected weights of the plain net are copied
        into the pipelined one, and both are trained for five steps on the
        same random batches.  After every step the two losses must match
        under ``np.testing.assert_allclose``.
        """
        hcg = fleet.get_hybrid_communicate_group()
        _mp_world_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        set_random_seed(1024, dp_id, rank_id)

        def make_scheduler():
            # Both nets use the same learning-rate schedule.
            return paddle.optimizer.lr.PiecewiseDecay(
                boundaries=[2, 3, 4],
                values=[0.01, 0.02, 0.03, 0.04],
                verbose=True)

        # Reference: plain, non-pipelined network.
        model_a = SimpleNet()
        scheduler_a = make_scheduler()
        optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
                                           parameters=model_a.parameters())

        # Network under test: pipelined over the hybrid topology.
        model_b = SimpleNetPipe(topology=hcg.topology())
        scheduler_b = make_scheduler()
        optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
                                           parameters=model_b.parameters())
        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)

        _param_count = len(model_a.parameters())

        # Snapshot the reference weights as numpy arrays.
        reference_weights = [w.numpy() for w in model_a.parameters()]

        stage_params = model_b.parameters()

        # Local parameter 0 always receives reference weight 2; local
        # parameter 1 receives reference weight 0 on stage 0 and 1 otherwise.
        stage_params[0].set_value(reference_weights[2])
        second_source = 0 if pp_id == 0 else 1
        stage_params[1].set_value(reference_weights[second_source])

        for _ in range(5):
            x1_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
            x2_data = np.random.randint(0, vocab_size, size=[batch_size, 1])
            y1_data = np.random.randint(0, hidden_size, size=[batch_size, 1])

            x1 = paddle.to_tensor(x1_data)
            x2 = paddle.to_tensor(x2_data)
            y1 = paddle.to_tensor(y1_data)

            for tensor in (x1, x2, y1):
                tensor.stop_gradient = True

            # One manual optimisation step on the reference net.
            loss_a = model_a(x1, x2, y1)
            loss_a.backward()

            optimizer_a.step()
            optimizer_a.clear_grad()
            scheduler_a.step()

            # One pipelined step on the net under test.
            loss_b = model_b.train_batch([(x1, x2), (y1, )], optimizer_b,
                                         scheduler_b)

            print("loss", loss_a.numpy(), loss_b.numpy())
            np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
Esempio n. 19
0
def do_train(args):
    """Run the dygraph pre-training loop configured by ``args``.

    The function sets the device, initialises the (data-parallel) fleet
    environment, builds the model/criterion selected by ``args.model_type``,
    resumes ``global_step`` from ``<output_dir>/model_last/config.yml`` when
    that file exists, and then trains with AdamW until ``args.max_steps`` is
    reached. Along the way it logs every ``args.logging_freq`` steps,
    evaluates every ``args.eval_freq`` steps, writes a numbered snapshot every
    ``args.save_steps`` steps and a rolling ``model_last`` checkpoint every
    ``args.checkpoint_steps`` steps. After the final step it runs the test
    evaluation and returns.

    Args:
        args: Namespace-like object holding the run configuration. Attributes
            read here include ``device``, ``dp_degree``, ``sharding_degree``,
            ``output_dir``, ``model_type``, ``model_name_or_path``,
            ``global_batch_size``, ``hidden_dropout_prob``,
            ``attention_probs_dropout_prob``, ``decay_steps``, ``max_steps``,
            ``warmup_rate``, ``max_lr``, ``min_lr``, ``grad_clip``,
            ``adam_beta1``, ``adam_beta2``, ``adam_epsilon``,
            ``weight_decay``, ``use_amp``, ``scale_loss``, ``max_seq_len``,
            ``accumulate_steps``, ``logging_freq``, ``eval_freq``,
            ``eval_iters``, ``save_steps``, ``checkpoint_steps`` and
            ``test_iters``. The function mutates ``args``: it may overwrite
            ``dp_degree``, ``sharding_degree`` and ``decay_steps``, and it
            always sets ``warmup_steps``.

    Returns:
        None. The function returns once ``global_step >= args.max_steps``.

    Raises:
        AssertionError: If ``dp_degree * sharding_degree != worker_num``, if
            ``warmup_rate`` is outside ``[0, 1]``, or if the checkpoint's
            ``global_batch_size`` differs from ``args.global_batch_size``.
    """
    paddle.set_device(args.device)

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    if worker_num > 1:
        paddle.distributed.init_parallel_env()

    # No explicit parallel split requested: use every worker for data
    # parallelism.
    if args.dp_degree * args.sharding_degree == 1:
        args.dp_degree = worker_num
        args.sharding_degree = 1

    args_post_process(args, worker_num)

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    # Only the data-parallel degree is configurable here; model, pipeline and
    # sharding degrees are pinned to 1.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 1
    }

    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    # Create the random seed for the worker
    set_seed(args)

    assert args.dp_degree * args.sharding_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    # Create the log writer (rank 0 only; stays None on every other worker).
    log_writer = None
    if worker_index == 0:
        log_writer = LogWriter(os.path.join(args.output_dir, default_logdir()))

    # Look up the model, criterion and tokenizer classes for this model type.
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # Load the training progress recorded by the rolling checkpoint, if any.
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                # NOTE(review): config.yml is written by save_ckpt below as a
                # plain dict; yaml.safe_load would presumably suffice and is
                # the safer choice if the directory can be tampered with.
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    # A known configuration name builds a freshly initialised model; anything
    # else is handed to from_pretrained.
    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = model_class(base_class(**model_config))
    else:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    criterion = criterion_class()

    if worker_index == 0:
        # log the model config and args
        model_config_json = json.dumps(model.get_model_config(),
                                       ensure_ascii=False,
                                       indent=2)
        log_writer.add_text("model_config", model_config_json)
        args_dict = {"paddle commit id": str(paddle.version.commit)}
        for arg in vars(args):
            args_dict[arg] = str(getattr(args, arg))
        log_writer.add_text("args", json.dumps(args_dict, indent=2))

    # Create the learning-rate scheduler and optimizer
    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    assert args.warmup_rate <= 1.0 and args.warmup_rate >= 0.0, "warmup_rate should be in [0, 1]"
    args.warmup_steps = args.warmup_rate * args.max_steps

    lr_scheduler = LinearAnnealingWithWarmupDecay(
        args.max_lr,
        args.min_lr,
        warmup_step=args.warmup_steps,
        decay_step=args.decay_steps,
        last_epoch=global_step)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(
            clip_norm=args.grad_clip)

    # Weight decay is applied to every parameter whose name contains neither
    # "bias" nor "norm".
    decay_param = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    logger.info("Using paddle.optimizer.AdamW.")
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_param,
        multi_precision=args.use_amp)

    # `scaler` only exists when AMP is enabled; every later use is guarded by
    # the same flag.
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(models=model,
                                    level='O2',
                                    save_dtype='float32')

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    data_file = get_train_data_file(args)

    train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
        args,
        data_file,
        tokenizer,
        data_world_size=worker_num,
        data_world_rank=worker_index,
        max_seq_len=args.max_seq_len,
        current_step=global_step)

    # Restore optimizer and model state from the rolling checkpoint.
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            logger.info("Try to load checkpoint from %s " % checkpoint_dir)
            opt_path = os.path.join(checkpoint_dir, "model_state.pdopt")
            params_path = os.path.join(checkpoint_dir, "model_state.pdparams")

            # The model weights are only restored when the optimizer state
            # file is present as well.
            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
                model_dict = paddle.load(params_path)
                model.set_state_dict(model_dict)
            else:
                logger.warning("No optimizer checkpoint file found in %s." %
                               opt_path)
            logger.info(
                "Checkpoint loaded from global step: {}".format(global_step))

    def save_ckpt(output_dir, model, tokenizer, args, global_step):
        """Write model, tokenizer, optimizer state and progress to a folder.

        Defined once here rather than on every training step. ``optimizer``
        is taken from the enclosing scope at call time.
        """
        step_config = {
            "model_name": args.model_name_or_path,
            "global_step": global_step,
            "global_batch_size": args.global_batch_size,
            "consumed_samples": global_step * args.global_batch_size,
        }

        logger.debug("saving models to {}".format(output_dir))
        # Unwrap the data-parallel wrapper so save_pretrained is called on
        # the underlying layer.
        model_to_save = model._layers if isinstance(
            model, paddle.DataParallel) else model

        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        paddle.save(optimizer.state_dict(),
                    os.path.join(output_dir, "model_state.pdopt"))

        with open(os.path.join(output_dir, "config.yml"), "w") as f:
            yaml.dump(step_config,
                      f,
                      encoding='utf-8',
                      allow_unicode=True)

    # Running sums of the losses since the last log line.
    loss_global = {
        "loss": paddle.to_tensor(0.0),
        "lm_loss": paddle.to_tensor(0.0),
        "sop_loss": paddle.to_tensor(0.0),
    }
    tic_train = time.time()

    # Instantiate the validation/test iterators exactly once, so evaluation
    # does not start a new random dataloader each time. This must stay outside
    # the `while True` loop: the names are rebound to the call results, so a
    # second pass through the loop would call those results instead of the
    # original loaders.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()

    while True:
        # time count
        # NOTE(review): these two accumulators are never read in this
        # function, and reader_start is not refreshed per batch.
        train_reader_cost = 0.0
        train_run_cost = 0.0
        reader_start = time.time()

        for step, batch in enumerate(train_data_loader()):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            # 0. input_ids,
            # 1. segment_ids,
            # 2. input_mask,
            # 3. masked_lm_positions,
            # 4. masked_lm_labels,
            # 5. next_sentence_labels

            input_ids, segment_ids, input_mask, masked_lm_positions, \
            masked_lm_labels, next_sentence_labels = batch

            with paddle.amp.auto_cast(args.use_amp,
                                      custom_black_list=[
                                          "reduce_sum",
                                          "c_softmax_with_cross_entropy",
                                          "elementwise_div"
                                      ],
                                      level='O2'):

                # Forward pass of the pre-training model
                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    position_ids=None,
                    attention_mask=input_mask,
                    masked_positions=masked_lm_positions)

                lm_loss, sop_loss = criterion(prediction_scores,
                                              seq_relationship_score,
                                              masked_lm_labels,
                                              next_sentence_labels)
                loss = lm_loss + sop_loss

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()

            optimizer.clear_grad()
            train_run_cost += time.time() - train_start

            # Only every `accumulate_steps`-th batch counts as a global step;
            # the optimizer itself has already stepped above on every batch.
            if (step + 1) % args.accumulate_steps != 0:
                continue

            global_step += 1

            loss_global["loss"] += loss.detach()
            loss_global["lm_loss"] += lm_loss.detach()
            loss_global["sop_loss"] += sop_loss.detach()

            if global_step % args.logging_freq == 0:
                log_info_dict = dict()
                log_info_dict["global_step"] = global_step
                for k, v in loss_global.items():
                    log_info_dict[k] = all_gather(v) / args.logging_freq
                    # Reset the running sum in place for the next window.
                    v.subtract_(v)
                if worker_index == 0:
                    speed = args.logging_freq / (time.time() - tic_train)
                    log_info_dict["learning_rate"] = lr_scheduler.get_lr()
                    log_info_dict["steps_per_second"] = speed
                    log_info_dict[
                        "samples_per_second"] = speed * args.global_batch_size

                    for k, v in log_info_dict.items():
                        log_writer.add_scalar("train/%s" % k, v, global_step)

                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, log_info_dict["loss"],
                        log_info_dict["lm_loss"], log_info_dict["sop_loss"],
                        speed, log_info_dict["samples_per_second"],
                        log_info_dict["learning_rate"])

                    addition_info = ""
                    if args.use_amp:
                        amp_info = {
                            "loss_scaling": scaler._scale.item(),
                            "incr_count": scaler._incr_count,
                            "decr_count": scaler._decr_count
                        }
                        addition_info = ", ".join("%s: %d" % (k, v)
                                                  for k, v in amp_info.items())
                        addition_info = " " + addition_info
                        for k, v in amp_info.items():
                            log_writer.add_scalar("amp/%s" % k, v, global_step)

                    logger.info(common_loginfo + addition_info)

                tic_train = time.time()

            if lr_scheduler is not None:
                lr_scheduler.step()

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation

                run_evaluate(valid_data_loader,
                             model,
                             criterion,
                             args.eval_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="valid")
                # Restart the timer so evaluation time is not counted as
                # training time.
                tic_train = time.time()

            # Numbered snapshot; only rank 0 writes, all ranks synchronise.
            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if worker_index == 0:
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()
                tic_train = time.time()

            # Rolling checkpoint: the previous "model_last" is kept as
            # "model_last_bak" while the new one is being written.
            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()

            # Training budget reached: run the test evaluation and stop.
            if global_step >= args.max_steps:
                run_evaluate(test_data_loader,
                             model,
                             criterion,
                             args.test_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="test")
                del train_data_loader
                return