Example #1
class DenseLm12kWide41BAdam16x16(DenseLm128B16x16):
    """41B params LM model with 2D split and ADAM optimizer on v3-512."""

    # Each layer has 1.6875B parameters.
    SEQUENCE_LENGTH = 2048
    NUM_DEVICES_PER_SPLIT = 512
    BATCH_DIM_PER_DEVICE = 0.5  # Total batch size 256
    DEVICE_MESH_SHAPE = [16, 32]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [16, 16, 2])
    NUM_TRANSFORMER_LAYERS = 24
    HIDDEN_DIM = 48 * 1024
    MODEL_DIM = 12 * 1024
    NUM_HEADS = 96
    ATTENTION_KEY_VALUE_DIM = 128
    GATED_GELU = False
    POSITIONAL_EMBEDDING = True
    NUM_MICRO_BATCHES = 1

    def Task(self):
        p = super().Task()
        p.train.optimizer = ShardedAdam.Params().Set(
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-6,
            num_micro_batches=self.NUM_MICRO_BATCHES)
        return p
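
The batch and mesh settings above are coupled: with BATCH_DIM_PER_DEVICE = 0.5 on a 512-device split, the global batch is 256, and the [16, 32] logical mesh must cover exactly those 512 devices. A minimal sketch of that arithmetic (plain Python, not Lingvo API):

# Sanity check of the Example #1 config arithmetic; illustrative only.
import numpy as np

NUM_DEVICES_PER_SPLIT = 512
BATCH_DIM_PER_DEVICE = 0.5
DEVICE_MESH_SHAPE = [16, 32]

# The logical 2D mesh must enumerate every device in the split.
assert int(np.prod(DEVICE_MESH_SHAPE)) == NUM_DEVICES_PER_SPLIT

# Global batch = per-device batch fraction * devices in the split.
assert int(BATCH_DIM_PER_DEVICE * NUM_DEVICES_PER_SPLIT) == 256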
Example #2
class DenseLm12kWide41BAdam8x8(DenseLm12kWide41BAdam16x16):
    """41B params LM model with 2D split and ADAM optimizer on v3-128."""

    # If OOM, try BATCH_DIM_PER_DEVICE = 0.25 with NUM_MICRO_BATCHES = 8.
    BATCH_DIM_PER_DEVICE = 0.5  # Total micro-batch size 64
    NUM_MICRO_BATCHES = 4  # Total batch size 256
    NUM_DEVICES_PER_SPLIT = 128
    DEVICE_MESH_SHAPE = [8, 16]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [8, 8, 2])
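
Here gradient accumulation stands in for devices: each of the 4 micro-batches is 0.5 * 128 = 64 examples, giving the same effective batch of 256 as Example #1 on a quarter of the hardware. A quick check (plain Python, not Lingvo API):

# Sanity check of the Example #2 micro-batching arithmetic; illustrative only.
NUM_DEVICES_PER_SPLIT = 128
BATCH_DIM_PER_DEVICE = 0.5
NUM_MICRO_BATCHES = 4

micro_batch = int(BATCH_DIM_PER_DEVICE * NUM_DEVICES_PER_SPLIT)
assert micro_batch == 64                       # "Total micro-batch size 64"
assert micro_batch * NUM_MICRO_BATCHES == 256  # "Total batch size 256"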
Example #3
class DenseLm128B16x16(DenseLm128B8x8):
    """128B params LM model with 2D split on v3-512."""
    SEQUENCE_LENGTH = 1024
    NUM_DEVICES_PER_SPLIT = 512
    BATCH_DIM_PER_DEVICE = 0.25  # Total batch size 128
    DEVICE_MESH_SHAPE = [16, 32]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [16, 16, 2])

class DenseLm12kWide162BAdam16x16(DenseLm12kWide41BAdam16x16):
    """162B params LM model with 2D split and ADAM optimizer on v3-512."""

    BATCH_DIM_PER_DEVICE = 0.125  # Total batch size 64
    NUM_TRANSFORMER_LAYERS = 96
    DEVICE_MESH_SHAPE = [16, 32]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [16, 16, 2])
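
The size suffixes in these class names follow from the per-layer count noted in Example #1 ("Each layer has 1.6875B parameters."): 24 layers give 40.5B (rounded to "41B"), and 96 layers give exactly 162B. Checked in plain Python:

# Parameter-count arithmetic behind the 41B/162B names; illustrative only.
PARAMS_PER_LAYER = 1.6875e9
assert 24 * PARAMS_PER_LAYER == 40.5e9  # DenseLm12kWide41BAdam16x16 (~41B)
assert 96 * PARAMS_PER_LAYER == 162e9   # DenseLm12kWide162BAdam16x16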

class DenseLm175B8x8Decode2D(DenseLm175B32x32):
    """175B params LM model decoding on v3-128.

    2D logical mesh. It can load a checkpoint from DenseLm175B32x32.
    """
    BATCH_DIM_PER_DEVICE = 0.125
    NUM_DEVICES_PER_SPLIT = 128
    # NUM_HEADS is not a multiple of 128 so we use 2D sharding on M and H.
    DEVICE_MESH_SHAPE = [8, 16]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [8, 8, 2])

    def Task(self):
        p = super().Task()
        # relative_attention_use_universal_1d_position should be set to False in
        # decoding.
        p.builder.relative_attention_use_universal_1d_position = False
        p.builder.model_dim_reshape_segments = self.DEVICE_MESH_SHAPE[0]
        p.builder.emb_w_split = [1, 0]
        p.builder.emb_out_split = [-1, -1, 0]
        p.builder.blm_split = [-1, -1, 0]
        p.builder.blh_split = [-1, -1, 1]
        p.builder.qkv_split = [0, -1, 1, -1]  # [-1, -1, 1, -1] for global batch 1.
        p.builder.logits_split = [-1, -1, 1]
        return p
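
Each *_split list above annotates one tensor: entry i names the logical mesh axis that shards tensor dimension i, and -1 marks a replicated (unsharded) dimension. A hedged sketch of that convention, using illustrative dimensions (the helper below is not the Lingvo implementation; the real logic lives in the gshard builder):

# Illustrative reading of a GShard-style split annotation; not the Lingvo API.
def shard_shape(tensor_shape, split, mesh_shape):
    """Per-device shard shape implied by `split` on a logical mesh."""
    assert len(tensor_shape) == len(split)
    shard = []
    for dim, axis in zip(tensor_shape, split):
        if axis == -1:
            shard.append(dim)  # Replicated: the device holds the full dim.
        else:
            assert dim % mesh_shape[axis] == 0
            shard.append(dim // mesh_shape[axis])  # Split over that mesh axis.
    return shard

# A split of [-1, -1, 0] on the [8, 16] mesh: last dim is split 8 ways.
assert shard_shape([1, 1024, 12288], [-1, -1, 0], [8, 16]) == [1, 1024, 1536]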

class DenseLm175B32x32DP(DenseLm175B32x32):
    """175B model running on v3-2048 with 2-way data parallelism."""
    NUM_DEVICES_PER_SPLIT = 1024
    TRAIN_STEPS_PER_LOOP = 20
    DEVICE_MESH_SHAPE = [64, 16]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh([16, 64],
                                               [16, 32, 2]).transpose()
    MODEL_DIM_RESHAPE_SEGMENTS = [16]

class DenseLm128B8x8(DenseLmTemplate):
    """128B params LM model with 2D split."""
    SEQUENCE_LENGTH = 1024
    NUM_DEVICES_PER_SPLIT = 128
    BATCH_DIM_PER_DEVICE = 0.125
    NUM_TRANSFORMER_LAYERS = 64  # 64 blocks of [DecSelfAttention, DenseReluDense]
    DEVICE_MESH_SHAPE = [8, 16]
    DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [8, 8, 2])

    def Task(self):
        p = super().Task()
        p.train.tpu_device_order_mode = 2  # DeviceOrderMode.MESH
        p.builder.model_dim_reshape_segments = self.DEVICE_MESH_SHAPE[1]
        p.builder.emb_w_split = [-1, 1]
        p.builder.emb_out_split = [0, -1, 1]
        p.builder.blm_split = [0, -1, 1]
        p.builder.logits_split = [0, -1, 1]
        return p
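
In all of these configs, gshard_utils.GetNonPod2dMesh maps the physical TPU topology (e.g. [8, 8, 2]: an 8x8 grid of chips with 2 cores each on v3-128) onto the requested logical 2D mesh of device ids; the result behaves like a numpy array, which is why DenseLm175B32x32DP can build a [16, 64] mesh and .transpose() it into the [64, 16] shape it declares. A minimal usage sketch, assuming the lingvo.core import path and that device ids run 0..N-1:

# Sketch: build the DenseLm128B8x8 logical mesh and sanity-check it.
# Assumes GetNonPod2dMesh lives in lingvo.core.gshard_utils, as used above.
from lingvo.core import gshard_utils

mesh = gshard_utils.GetNonPod2dMesh([8, 16], [8, 8, 2])
assert list(mesh.shape) == [8, 16]                        # Logical 2D mesh.
assert sorted(mesh.flatten().tolist()) == list(range(128))  # Each core once.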