def test_checkpoint_recompute_checkpoint(recompute_checkpoint):
    """
    If a checkpoint is saved with `recompute_checkpoint_every_layer`
      then we should be able to restore the checkpoint in a new run
      that doesn't use `recompute_checkpoint_every_layer`, and vice versa.
    """
    args = """
    --config unit_test
    """.split()
    config1 = BertConfig(**(vars(parse_bert_args(args))))
    config1.recompute_checkpoint_every_layer = recompute_checkpoint
    model1 = PipelinedBertForPretraining(config1).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config1.checkpoint_output_dir = dir
        save_checkpoint(config1, model1, 0)

        # New model with opposite `recompute_checkpoint` to model1
        config2 = BertConfig(**(vars(parse_bert_args(args))))
        config2.recompute_checkpoint_every_layer = not recompute_checkpoint
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config2).parallelize()

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)
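

# The tests in this file take parameters such as `recompute_checkpoint` directly,
# which suggests they are driven by pytest parametrization that is not shown in
# this snippet. A minimal, self-contained sketch of that assumed pattern (the
# parameter values and the test name below are illustrative, not part of the
# original suite):
import pytest


@pytest.mark.parametrize("recompute_checkpoint", [False, True])
def test_recompute_checkpoint_flag_values(recompute_checkpoint):
    # Stand-in body: the real test above builds a config with this flag, saves a
    # checkpoint, and reloads it with the flag inverted.
    assert recompute_checkpoint in (False, True)
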
def test_checkpoint_embedding_serialization(embedding_serialization_factor):
    """
    If a checkpoint is saved with embedding_serialization_factor
      then we should be able to restore the checkpoint in a new run
      where embedding_serialization_factor isn't used.
    The reverse should also hold.
    """
    args = """
    --config unit_test
    """.split()
    config1 = BertConfig(**(vars(parse_bert_args(args))))
    config1.embedding_serialization_factor = embedding_serialization_factor
    model1 = PipelinedBertForPretraining(config1).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config1.checkpoint_output_dir = dir
        save_checkpoint(config1, model1, 0)

        # New model with a different embedding_serialization_factor to model1
        config2 = BertConfig(**(vars(parse_bert_args(args))))
        config2.embedding_serialization_factor = 5 if embedding_serialization_factor == 1 else 1
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config2).parallelize()

        assert model2.config.embedding_serialization_factor == config2.embedding_serialization_factor

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)
def test_checkpoint_embedding_serialization_qa(embedding_serialization_factor):
    """
    If a checkpoint is saved with embedding_serialization_factor
      then we should be able to restore the checkpoint in a new run
      where embedding_serialization_factor isn't used.
    The reverse should also hold.
    For PipelinedBertForQuestionAnswering we will need to call `deparallelize`
    before checkpointing.
    """
    args = """
    --config unit_test
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    config.embedding_serialization_factor = embedding_serialization_factor
    model1 = PipelinedBertForQuestionAnswering(config).parallelize()

    with tempfile.TemporaryDirectory() as dir:
        # Save checkpoint
        config.checkpoint_output_dir = dir
        model1.deparallelize()
        save_checkpoint(config, model1, 0)

        # Load the checkpoint, but don't call parallelize
        model2 = PipelinedBertForQuestionAnswering.from_pretrained(os.path.join(dir, "step_0"))

        # Models should have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)
def test_checkpoint_save_restore(recompute_checkpoint, embedding_serialization_factor):
    """
    Test that saving and restoring checkpoints works. Also test checkpointing
    with recomputation checkpoints and embedding serialization.
    """
    args = """
    --config unit_test
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    config.recompute_checkpoint_every_layer = recompute_checkpoint
    config.embedding_serialization_factor = embedding_serialization_factor
    model1 = PipelinedBertForPretraining(config).parallelize()
    model2 = PipelinedBertForPretraining(config).parallelize()

    # The two models should have different initial weights
    for name, tensor1 in model1.state_dict().items():
        tensor2 = model2.state_dict()[name]
        if (tensor1.dtype is not torch.int64) and ("LayerNorm" not in name) and ("bias" not in name):
            assert not torch.allclose(tensor1, tensor2)

    # Save and restore checkpoint
    with tempfile.TemporaryDirectory() as dir:
        config.checkpoint_output_dir = dir
        # No checkpoints should exist yet
        assert not checkpoints_exist(config.checkpoint_output_dir)

        save_checkpoint(config, model1, 0)

        # Checkpoint should now exist
        assert checkpoints_exist(config.checkpoint_output_dir)

        # Restore from checkpoint
        model2 = PipelinedBertForPretraining.from_pretrained(os.path.join(dir, "step_0"), config=config)

        # Models should now have the same weights
        for name, tensor1 in model1.state_dict().items():
            tensor2 = model2.state_dict()[name]
            assert torch.allclose(tensor1, tensor2)
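

# For reference, the round-trip these tests exercise reduces to saving and
# reloading a state_dict. A minimal, framework-only sketch in plain PyTorch
# (it reuses the torch / tempfile / os imports already used above; the function
# name and toy module are ours, not part of the original suite):
def _state_dict_roundtrip_sketch():
    model = torch.nn.Linear(4, 4)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.save(model.state_dict(), path)

        restored = torch.nn.Linear(4, 4)
        restored.load_state_dict(torch.load(path))

    # After loading, every tensor should match the saved weights exactly.
    for name, tensor in model.state_dict().items():
        assert torch.allclose(tensor, restored.state_dict()[name])
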
def wgan_gp_run(train_window=52,
                test_size=52,
                horizon=12,
                batch_size=30,
                epochs=1,
                d_iterations=3,
                time_series=True,
                BERT=True,
                load=None,
                cuda_no='cuda:0'):

    if torch.cuda.is_available():
        device = torch.device(cuda_no)
        torch.cuda.set_device(device=device)
        FT = torch.cuda.FloatTensor
        use_cuda = True
    else:
        device = torch.device('cpu')
        use_cuda = False
        FT = torch.FloatTensor

    dataloader = get_dataloader(train_window, test_size, batch_size, horizon)

    input_size = 1
    if time_series:
        input_size += 1
    if BERT:
        input_size += 768

    G = LSTMGenerator(input_size=input_size,
                      hidden_layer_size=300,
                      num_layers=2,
                      output_size=1,
                      horizon=horizon,
                      device=device)

    D = ConvDiscriminator(
        input_channels=input_size,
        output_channels=1,
    )

    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

    if use_cuda:
        G.cuda(device=device)
        D.cuda(device=device)

    if load is not None:
        load_epoch = load[1]
        load_checkpoint(G,
                        D,
                        g_optimizer,
                        d_optimizer,
                        date=load[0],
                        epoch=load_epoch,
                        name=load[2])
    else:
        load_epoch = 0

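    # Pick the concatenation helper that matches the enabled inputs: the raw time
    # series values and/or the 768-dim BERT encodings are appended to the noise input.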
    if time_series:
        cat = cat_with_seq_with_enc if BERT else cat_with_seq_no_enc
    else:
        cat = cat_no_seq_with_enc if BERT else cat_no_seq_no_enc

    for epoch in range(load_epoch + 1, epochs + load_epoch + 1):
        i = 0
        j = 0
        for di in range(d_iterations):
            for seq, encoded, labels in dataloader:
                if (i == 1) and (di == 0):
                    print('Currently at Epoch: {}, Discriminator error: {}'.
                          format(epoch, d_error))
                seq = torch.from_numpy(seq).type(FT).unsqueeze(2)
                bz = seq.size(0)
                G.clear_hidden(bz)
                labels = torch.from_numpy(labels).type(FT).unsqueeze(2)
                encoded = torch.from_numpy(encoded).type(FT)

                with torch.no_grad():
                    generated_labels = pad(
                        G(cat(noise((bz, train_window, 1), FT), seq, encoded)),
                        train_window, horizon)

                GP = gradient_penalty(D,
                                      generated_labels,
                                      labels,
                                      encoded,
                                      seq,
                                      Lambda=10,
                                      device=device,
                                      cat=cat)
                d_error = train_discriminator(
                    D, d_optimizer,
                    cat(labels, seq, encoded).to(device),
                    cat(generated_labels, seq, encoded).to(device), GP)
                del labels, generated_labels, encoded, GP
                i += 1

        print('Currently at Epoch: {}, Discriminator error: {}'.format(
            epoch, d_error))

        for seq, encoded, labels in dataloader:
            if j == 1:
                print('Currently at Epoch: {},  Generator error: {}'.format(
                    epoch, g_error))
            seq = torch.from_numpy(seq).type(FT).unsqueeze(2)
            bz = seq.size(0)
            G.clear_hidden(bz)
            labels = torch.from_numpy(labels).type(FT).unsqueeze(2)
            encoded = torch.from_numpy(encoded).type(FT)

            generated_labels = pad(
                G(cat(noise((bz, train_window, 1), FT), seq, encoded)),
                train_window, horizon)

            g_error = train_generator(
                D, g_optimizer,
                cat(generated_labels, seq, encoded).to(device))

            del generated_labels, encoded, seq
            j += 1

        print('Currently at Epoch: {},  Generator error: {}'.format(
            epoch, g_error))

        if epoch % 5000 == 0:
            save_checkpoint(
                G, D, g_optimizer, d_optimizer, epoch,
                'tw-{}_hz-{}_bs-{}_di-{}_t-{}_B-{}_id-{}_hlr_snp'.format(
                    train_window, horizon, batch_size, d_iterations,
                    time_series, BERT, D.identifier))
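

# `gradient_penalty` is called above but not defined in this snippet. A minimal
# sketch of the standard WGAN-GP penalty term (Gulrajani et al., 2017); the real
# helper also concatenates the conditioning sequence and BERT encodings via `cat`,
# which is omitted here, so treat this as an illustration rather than a drop-in:
def gradient_penalty_sketch(D, fake, real, Lambda=10, device="cpu"):
    # Random interpolation between real and generated samples.
    alpha = torch.rand(real.size(0), 1, 1, device=device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_out = D(interpolated)

    # Gradient of the critic output with respect to the interpolated input.
    grads = torch.autograd.grad(outputs=d_out,
                                inputs=interpolated,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True,
                                retain_graph=True)[0]
    grads = grads.reshape(grads.size(0), -1)

    # Penalize deviations of the gradient norm from 1.
    return Lambda * ((grads.norm(2, dim=1) - 1) ** 2).mean()
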
def do_train(args):
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

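    # Gradient accumulation: each local batch is split into micro-batches, so
    # `accumulate_steps` forward/backward passes are accumulated per optimizer step.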
    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # obtain rank message of hybrid parallel
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree

        # MOE config
        initialize_model_and_expert_group(hcg)

        model_config['expert_mode'] = args.expert_mode
        model_config['hcg'] = hcg
        model_config['num_experts'] = args.num_experts
        model_config['top_k'] = args.top_k
        if args.expert_mode:
            model_config['gate'] = args.gate

        if args.pp_degree == 1:
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model_config["recompute_partition"] = args.recompute_partition
            model_config["recompute_offload"] = args.recompute_offload
            if args.use_recompute and args.recompute_partition:
                raise Exception(
                    "when use_recompute is True, recompute_partition must be False in MoE."
                )

            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)


    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model = paddle.amp.decorate(models=model,
                                    optimizers=None,
                                    level='O2',
                                    save_dtype='float32')

    opt_fused_tensors, decay_fused_tensors, reduce_fused_tensors, gate_fused_tensors, \
        expert_fusion_names = parameters_classify(model)
    decay_params = [p.name for p in decay_fused_tensors]

    clip = None
    if args.grad_clip > 0:
        is_expert_param_fun = lambda param: param.name in expert_fusion_names
        clip = moe.ClipGradByGlobalNorm(clip_norm=args.grad_clip, \
                                        is_expert_param_func = is_expert_param_fun, \
                                        moe_group = hcg.get_expert_parallel_group())

    optimizer = AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=opt_fused_tensors,
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)

    if paddle.distributed.get_world_size() > 1 and args.resume_dir is None:
        print(">> initialize....")
        initialize_mp_dp_parameters(model, hcg)

    # State needed to restore the data reader position when resuming from a checkpoint.
    pass_num = 0
    file_id = 0
    start_epoch = 0
    args.resume_dir = None if len(args.resume_dir) <= 0 else args.resume_dir

    if args.resume_dir is not None:
        global_step, loss_scale, data_meta = load_checkpoint(
            args, model, optimizer, lr_scheduler, tokenizer, dp_rank, mp_rank,
            pp_rank)
        pass_num = data_meta["pass_num"]
        file_id = data_meta["file_id"]
        start_epoch = data_meta["start_epoch"]

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=loss_scale if args.resume_dir is not None
            else args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model, optimizer = paddle.amp.decorate(models=model,
                                               optimizers=optimizer,
                                               level='O2',
                                               save_dtype='float32')

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0 if args.resume_dir is None else global_step
    timers = get_timers()
    tic_train = time.time()
    for epoch in range(start_epoch, args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(file_id, num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)

            # Bug fix: call valid_data_loader/test_data_loader once here; otherwise each
            # enumerate over them would create a new, freshly shuffled dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                # Skip training data that has already been consumed in a previous run.
                if step < global_step - pass_num: continue

                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                loss = 0.0
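                # Manual gradient accumulation: run `accumulate_steps` micro-batches,
                # scaling each micro-batch loss by 1/accumulate_steps before backward.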
                for i in range(accumulate_steps):
                    start_index = i * args.micro_batch_size
                    end_index = start_index + args.micro_batch_size
                    timers('forward-compute').start()
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum",
                                "c_softmax_with_cross_entropy",
                                "elementwise_div",
                            ],
                            level='O2'):
                        preds = model(tokens[start_index:end_index, :])
                        loss_mbs = criterion(
                            preds, labels[start_index:end_index, :],
                            loss_mask[start_index:end_index, :])
                    timers('forward-compute').stop()

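                    # MoE auxiliary (load-balancing) loss: collect the gate losses from
                    # every MoE layer and add them to the LM loss, weighted by
                    # `balance_loss_weight`.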
                    if args.gate != "naive" and args.balance_loss_weight:
                        aux_loss_list = [
                            l.moe_mlp.gate.get_loss(clear=False)
                            for l in model.gpt.decoder.layers
                            if hasattr(l.moe_mlp, "gate")
                        ]
                        bal_loss = paddle.concat(aux_loss_list)
                        if bal_loss.dtype == paddle.float16:
                            bal_loss = paddle.cast(bal_loss,
                                                   dtype=paddle.float32)
                        bal_loss = bal_loss.mean()
                        loss_mbs += bal_loss * args.balance_loss_weight
                    loss_mbs = loss_mbs / accumulate_steps

                    timers('backward-compute').start()
                    if args.use_pure_fp16:
                        scaler.scale(loss_mbs).backward()
                    else:
                        loss_mbs.backward()
                    timers('backward-compute').stop()
                    loss = loss + loss_mbs

                timers('backward-params-all-reduce').start()
                all_reduce_parameters(gate_fused_tensors,
                                      hcg.get_expert_parallel_group())
                all_reduce_parameters(reduce_fused_tensors,
                                      hcg.get_data_parallel_group())
                timers('backward-params-all-reduce').stop()

                if args.use_pure_fp16:
                    scaler.minimize(optimizer, loss)
                else:
                    optimizer.step()
                learning_rate = optimizer.get_lr()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    if args.gate != "naive" and args.balance_loss_weight:
                        bal_loss = bal_loss.numpy()
                        avg_loss -= bal_loss
                    else:
                        bal_loss = -1
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, bal_loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, bal_loss, speed,
                           speed * default_global_tokens_num, learning_rate))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", learning_rate,
                                          global_step)

                    tic_train = time.time()
                    timer_log(args.logging_freq)

                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps):
                    loss_scale = scaler._scale if args.use_pure_fp16 else None
                    save_checkpoint(args, global_step, model, optimizer,
                                    lr_scheduler, tokenizer, loss_scale,
                                    dp_rank, mp_rank, pp_rank, pass_num,
                                    file_id, epoch)
                    print(
                        "save checkpoint for step_{} successfully...loss_scale = {}"
                        .format(global_step, loss_scale))

                if global_step % args.eval_freq == 0:
                    # Since the validation data is broadcast to all devices, we evaluate on every device.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            # Record the number of batches of train_data_loader that have been read so far.
            pass_num += len(train_data_loader())
            del train_data_loader
    start_compile = time.perf_counter()
    datum = get_generated_datum(config)
    poptorch_model.compile(*datum)
    duration_compilation = time.perf_counter() - start_compile
    logger(f"Compiled/Loaded model in {duration_compilation} secs")
    logger("-----------------------------------------------------------")

    # Save model and end here if compile only mode is enabled
    if config.compile_only:
        logger(
            "Model successfully compiled. Exiting now as '--compile-only' argument was passed."
        )
        sys.exit(0)

    # Checkpoint model at start of run
    save_checkpoint(config, model, steps_finished, optimizer)

    # Training loop
    logger("--------------------- Training Started --------------------")
    factor = config.gradient_accumulation * config.batches_per_step
    start_train = time.perf_counter()
    train_iterator = tqdm(
        range(steps_finished, config.training_steps),
        desc="Training",
        disable=config.disable_progress_bar
        or (config.use_popdist and not (config.popdist_rank == 0)))
    for step in train_iterator:
        start_step = time.perf_counter()
        outputs = poptorch_model(*next(loader))
        scheduler.step()
        poptorch_model.setOptimizer(optimizer)