Example #1
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import (group_sharded_parallel,
                                         save_group_sharded_model)

# Note: `args`, `optimizer_setting`, `reader_decorator`, `batch_size` and
# `epoch` are assumed to be defined elsewhere in the surrounding script.


def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload):
    # Use the sharding process group from fleet's hybrid parallelism setup.
    group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
    # Stage 3 shards parameters, gradients and optimizer states ("p_g_os");
    # stage 2 shards only optimizer states and gradients ("os_g").
    level = "p_g_os" if args.sharding_stage == 3 else "os_g"
    return group_sharded_parallel(model=model,
                                  optimizer=optimizer,
                                  level=level,
                                  scaler=scaler,
                                  group=group,
                                  offload=sharding_offload)


def train_mlp(model, shard_level, use_pure_fp16, output_dir):
    # Explicit two-rank process group used for the sharded communication.
    group = paddle.distributed.new_group([0, 1])

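    # Build the optimizer, decorate the model for AMP O2 (pure fp16) training
    # with fp32 checkpoints, and create a GradScaler for loss scaling.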
    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

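    # Wrap model, optimizer and scaler for group sharded training at the
    # requested level ("os", "os_g" or "p_g_os").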
    model, optimizer, scaler = group_sharded_parallel(model=model,
                                                      optimizer=optimizer,
                                                      level=shard_level,
                                                      scaler=scaler,
                                                      group=group)

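    # Batch the sample generator and feed it to a multiprocess DataLoader.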
    train_reader = paddle.batch(reader_decorator(),
                                batch_size=batch_size,
                                drop_last=True)

    train_loader = paddle.io.DataLoader.from_generator(capacity=32,
                                                       use_double_buffer=True,
                                                       iterable=True,
                                                       return_list=True,
                                                       use_multiprocess=True)
    train_loader.set_sample_list_generator(train_reader)

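    # Training loop: forward and loss are computed under AMP O2 autocast.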
    for eop in range(epoch):
        model.train()
        for batch_id, data in enumerate(train_loader()):
            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True
            with paddle.amp.auto_cast(True, level='O2'):
                out = model(img)
                loss = paddle.nn.functional.cross_entropy(input=out,
                                                          label=label)
            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))

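            # The fp32 path steps the optimizer directly; the pure fp16 path
            # scales the loss to avoid gradient underflow, then unscales,
            # steps and updates the loss scaling.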
            if not use_pure_fp16:
                avg_loss.backward()
                optimizer.step()
            else:
                scaler.scale(avg_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            optimizer.clear_grad()

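    # Persist the group-sharded model and optimizer state to output_dir.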
    save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
    return model.parameters()
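

# Minimal launch sketch (an assumption, not part of the original example): it
# presumes an `MLP` model class and the helpers referenced above are defined
# in the same script, and that it runs on two GPUs, e.g. started with
# `python -m paddle.distributed.launch --gpus "0,1" this_script.py`.
if __name__ == "__main__":
    paddle.distributed.init_parallel_env()
    mlp = MLP()
    # Train with sharding stage 2 ("os_g") in pure fp16.
    train_mlp(mlp, shard_level="os_g", use_pure_fp16=True, output_dir="./output")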