Example #1
def slice_state_dict(config, loaded_state_dict):
    # Slice a full-model checkpoint into the shard owned by this worker: keep
    # only the layers of this pipeline stage and, within each layer, the rows
    # or columns that belong to this model-parallel rank.
    sliced_state_dict = OrderedDict()
    # Global index of the first layer on this pipeline stage; the remainder
    # layers (n_total_layers % world_size) go to the earliest stages.
    start_layer_id = (
        config.n_total_layers // mpu.get_pipeline_parallel_world_size() *
        mpu.get_pipeline_parallel_group_rank() +
        min(mpu.get_pipeline_parallel_group_rank(),
            config.n_total_layers % mpu.get_pipeline_parallel_world_size()))
    end_layer_id = start_layer_id + config.n_layers
    for key, value in loaded_state_dict.items():
        keys = key.split('.')
        # keys[2] holds the global layer index in the full-model checkpoint.
        global_layer_id = int(keys[2])
        if start_layer_id <= global_layer_id < end_layer_id:
            local_layer_id = global_layer_id - start_layer_id
            new_key = '.'.join(keys[:2] + [str(local_layer_id)] + keys[3:])
            # in_proj (QKV) is split along its output dimension, so take this
            # rank's block of rows from both weight and bias.
            if keys[3] == 'attn' and keys[4] == 'in_proj':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[5] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() *
                                      in_size:(mpu.get_model_parallel_rank() +
                                               1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            # out_proj is split along its input dimension (weight dim 1);
            # the bias is replicated across model-parallel ranks.
            elif keys[3] == 'attn' and keys[4] == 'out_proj':
                if keys[5] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:,
                                      mpu.get_model_parallel_rank() *
                                      out_size:(mpu.get_model_parallel_rank() +
                                                1) * out_size]
                elif keys[5] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            # fc1 follows the in_proj scheme: split along the output dim.
            elif keys[3] == 'fc1':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[4] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() *
                                      in_size:(mpu.get_model_parallel_rank() +
                                               1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            # fc2 follows the out_proj scheme: split the weight along its
            # input dim and replicate the bias.
            elif keys[3] == 'fc2':
                if keys[4] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:,
                                      mpu.get_model_parallel_rank() *
                                      out_size:(mpu.get_model_parallel_rank() +
                                                1) * out_size]
                elif keys[4] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            else:
                # Everything else (e.g. layer norms) is replicated unchanged.
                new_value = value
            sliced_state_dict[new_key] = new_value
    return sliced_state_dict
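
# A minimal standalone sketch (not part of the example above) of the layer
# partitioning arithmetic that slice_state_dict relies on: each pipeline
# stage owns a contiguous block of layers, and the remainder layers are given
# to the earliest stages. The helper name is illustrative only.
def stage_layer_range(n_total_layers, n_stages, stage_rank):
    n_layers = (n_total_layers // n_stages +
                int(stage_rank < n_total_layers % n_stages))
    start = (n_total_layers // n_stages * stage_rank +
             min(stage_rank, n_total_layers % n_stages))
    return start, start + n_layers

# e.g. 10 layers over 4 stages -> (0, 3), (3, 6), (6, 8), (8, 10)
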
Example #2
def measure_iteration_time(args, n_warmup_steps=2):
    # Time full training iterations (forward + backward + optimizer step),
    # discarding the first n_warmup_steps measurements.
    config, layers, pipelined_layers = initialize_model(args)
    # The tiny learning rate keeps the weights essentially unchanged while
    # still paying for a real optimizer step.
    optimizer = torch.optim.Adam(pipelined_layers.parameters(), lr=1e-10)
    step_times = []
    for _ in range(args.n_steps + n_warmup_steps):
        start_time = time.time()
        optimizer.zero_grad()
        # Only the first pipeline stage creates real inputs; later stages
        # receive their activations from the previous stage.
        if mpu.get_pipeline_parallel_group_rank() == 0:
            x = layers.create_inputs(config.batch_size,
                                     config.seq_len,
                                     random=True)
        else:
            x = None
        try:
            y = pipelined_layers(x)
            # Only the last pipeline stage computes the loss; earlier stages
            # start their backward pass from the outputs they sent downstream.
            if (mpu.get_pipeline_parallel_group_rank() ==
                    mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except Exception:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        optimizer.step()
        step_time = time.time() - start_time
        step_times.append(step_time)
    step_times = np.array(step_times)[n_warmup_steps:]
    return np.mean(step_times), np.std(step_times)
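
# A hedged usage sketch for the benchmark above. `parse_args` is assumed to
# be this script's argument parser (providing n_steps, rank, the model name
# and slicing flags); the actual entry point may differ.
if __name__ == '__main__':
    args = parse_args()
    mean_time, std_time = measure_iteration_time(args)
    if args.rank == 0:
        print(f"iteration time: {mean_time:.3f}s +/- {std_time:.3f}s")
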
Example #3
    def forward(self, inputs=None):
        # TeraPipe forward pass: the input batch is cut into a grid of
        # (batch slice, sequence slice) micro-batches that are pushed through
        # this stage's layers one by one while the attention cache grows
        # along the sequence dimension.
        if inputs is None:
            # Later pipeline stages get no data here; they build placeholder
            # inputs of the right shape and receive the actual activations
            # via pipeline_recv below.
            assert mpu.get_pipeline_parallel_group_rank() > 0
            inputs = self.layers.create_inputs(self.batch_size, self.seq_len)
        inputs = grid_slice_batch_and_sequence(inputs,
                                               batch_slices=self.batch_slices,
                                               seq_slices=self.seq_slices,
                                               batch_dim=self.batch_dim,
                                               sequence_dim=self.sequence_dim,
                                               requires_grad=True)
        # Per-slice bookkeeping, indexed by (batch slice, sequence slice);
        # object arrays are used because the entries are tensors and lists of
        # cache tensors rather than scalars.
        cache_inputs = np.empty((self.n_batch_slices, self.n_input_slices),
                                dtype='O')
        cache_outputs = np.empty((self.n_batch_slices, self.n_input_slices),
                                 dtype='O')
        outputs = np.empty((self.n_batch_slices, self.n_input_slices),
                           dtype='O')

        for batch_id in range(self.n_batch_slices):
            slice_batch_size = inputs[batch_id, 0].size(self.batch_dim)
            cache = self.layers.create_cache(slice_batch_size, self.seq_len)
            cache_len = 0
            for input_id in range(self.n_input_slices):
                x = inputs[batch_id, input_id]
                # Receive this slice's activations from the previous
                # pipeline stage.
                x = pipeline_recv(x)
                slice_seq_len = x.size(self.sequence_dim)
                # Detach the cache so each slice has its own autograd graph;
                # the gradients are stitched together later by the backward
                # hook.
                cache = [c.detach().requires_grad_() for c in cache]
                y, cache_output = self.layers(x, cache, cache_len)
                # Forward this slice's boundary output to the next stage.
                y = pipeline_send(y)
                cache_inputs[batch_id, input_id] = cache
                cache_outputs[batch_id, input_id] = cache_output
                cache = cache_output
                outputs[batch_id, input_id] = y
                cache_len += slice_seq_len

        # Wire up the backward pass over the slice grid; only the last
        # pipeline stage concatenates the slice outputs into a full tensor.
        y = terapipe_backward_hook(
            outputs,
            cache_inputs,
            cache_outputs,
            self.batch_slices,
            self.seq_slices,
            self.batch_dim,
            self.sequence_dim,
            cat_outputs=(mpu.get_pipeline_parallel_group_rank() ==
                         mpu.get_pipeline_parallel_world_size() - 1))

        return y
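
# A plausible implementation of grid_slice_batch_and_sequence as used in the
# forward pass above. This is an assumption about the helper's behavior, not
# the repository's actual code: it splits a tensor into a
# (n_batch_slices, n_seq_slices) grid of contiguous chunks.
import numpy as np
import torch

def grid_slice_batch_and_sequence(x, batch_slices, seq_slices, batch_dim,
                                  sequence_dim, requires_grad=False):
    grid = np.empty((len(batch_slices), len(seq_slices)), dtype='O')
    for i, xb in enumerate(torch.split(x, batch_slices, dim=batch_dim)):
        for j, xs in enumerate(torch.split(xb, seq_slices, dim=sequence_dim)):
            grid[i, j] = xs.detach().requires_grad_(requires_grad)
    return grid
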
Example #4
def initialize_model(args):
    config = TransformerConfig.from_predefined_model(
        args.model, batch_size=args.batch_size)
    # Remember the full depth, then keep only this pipeline stage's share of
    # the layers (remainder layers go to the earliest stages).
    config.n_total_layers = config.n_layers
    config.n_layers = (
        config.n_total_layers // mpu.get_pipeline_parallel_world_size() +
        int(mpu.get_pipeline_parallel_group_rank() < config.n_total_layers %
            mpu.get_pipeline_parallel_world_size()))

    layers = TransformerLayers(config.n_layers,
                               config.embedding_dim,
                               config.ffn_embedding_dim,
                               config.num_attention_heads,
                               mixed_precision=args.mixed_precision)
    # Cut the sequence and batch dimensions into the micro-slices that
    # TeraPipe pipelines through the stages.
    seq_slices = uniform_slice(config.seq_len, args.n_input_slices)
    batch_slices = uniform_slice(config.batch_size, args.n_batch_slices)
    pipelined_layers = TeraPipe(layers, config.batch_size, config.seq_len,
                                batch_slices, seq_slices)
    return config, layers, pipelined_layers
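
# A possible uniform_slice implementation matching how it is used above (an
# assumption; the actual helper may differ): split `size` into `n` nearly
# equal chunk lengths, with the remainder spread over the first chunks.
def uniform_slice(size, n):
    return [size // n + int(i < size % n) for i in range(n)]

# e.g. uniform_slice(2048, 3) -> [683, 683, 682]
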
Example #5
    def __init__(self,
                 config,
                 batch_slices,
                 seq_slices,
                 distributed_init_method,
                 world_size,
                 data_parallel_size,
                 model_parallel_size,
                 pipeline_parallel_size,
                 rank,
                 local_rank,
                 mixed_precision=False,
                 use_mpi=False,
                 init_process_group=False,
                 checkpoint_gradients=False):
        self.config = config
        self.batch_slices = batch_slices
        self.seq_slices = seq_slices
        torch.cuda.set_device(local_rank)
        if init_process_group:
            dist.init_process_group(
                backend='nccl',
                init_method=distributed_init_method,
                world_size=world_size,
                rank=rank,
            )
        # Warm-up all-reduce to make sure the NCCL communicator is ready
        # before the model-parallel groups are created.
        dist.all_reduce(torch.zeros(1).cuda())
        mpu.initialize_model_parallel(model_parallel_size,
                                      pipeline_parallel_size)
        # Fixed seeds so that initialization is reproducible across runs.
        set_random_seed(0)
        mpu.model_parallel_cuda_manual_seed(0)
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.data_parallel_size = data_parallel_size
        self.model_parallel_size = model_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.pipeline_parallel_group_rank = (
            mpu.get_pipeline_parallel_group_rank())
        self.data_parallel_group = mpu.get_data_parallel_group()
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.pipeline_parallel_pred_group = (
            mpu.get_pipeline_parallel_pred_group())
        self.pipeline_parallel_succ_group = (
            mpu.get_pipeline_parallel_succ_group())
        self.model_parallel_src_rank = mpu.get_model_parallel_src_rank()
        self.model_parallel_dst_rank = mpu.get_model_parallel_dst_rank()
        self.model_parallel_next_src_rank = (
            self.model_parallel_src_rank + self.model_parallel_size if
            self.pipeline_parallel_group_rank < self.pipeline_parallel_size - 1
            else None)
        self.model_parallel_prev_dst_rank = (
            self.model_parallel_dst_rank - self.model_parallel_size
            if self.pipeline_parallel_group_rank > 0 else None)

        # This stage's share of the layers; the remainder goes to the
        # earliest pipeline stages.
        self.n_layers = (
            config.n_layers // pipeline_parallel_size +
            int(self.pipeline_parallel_group_rank <
                config.n_layers % pipeline_parallel_size))
        self.config = config
        self.mixed_precision = mixed_precision
        self.checkpoint_gradients = checkpoint_gradients

        self.layers = []
        for _ in range(self.n_layers):
            l = ModelParallelTransformerLayer(
                self.config.embedding_dim,
                self.config.ffn_embedding_dim,
                self.config.num_attention_heads,
                device="cuda",
                checkpoint_gradients=self.checkpoint_gradients)
            # In mixed precision the model weights live in fp16; fp32 master
            # copies are created below.
            self.layers.append(l.half() if self.mixed_precision else l)

        self.all_parameters = []
        for layer in self.layers:
            self.all_parameters.extend(layer.parameters())
        self.n_params = len(self.all_parameters)

        if self.mixed_precision:
            # Keep fp32 master copies of the fp16 parameters and let the
            # optimizer update the masters.
            self.master_parameters = [
                p.clone().detach().float() for p in self.all_parameters
            ]
            for p in self.master_parameters:
                p.requires_grad_()
            self.optimizer = optimizers.FusedAdam(self.master_parameters,
                                                  lr=1e-10)
        else:
            self.optimizer = torch.optim.Adam(self.all_parameters, lr=1e-10)
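
# Sketch of the parameter update that usually accompanies fp32 master weights
# (an assumption about how a training step would use the attributes above,
# not code taken from this class): copy the fp16 gradients into the fp32
# masters, step the optimizer on the masters, then copy the updated values
# back into the fp16 model parameters.
def mixed_precision_step(model_parameters, master_parameters, optimizer):
    for p, mp in zip(model_parameters, master_parameters):
        if p.grad is not None:
            mp.grad = p.grad.detach().float()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        for p, mp in zip(model_parameters, master_parameters):
            p.copy_(mp)
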
Example #6
def verify_one_step(args):
    # Two-phase correctness check: "save" runs a single-GPU reference step
    # and dumps the input, weights and gradients; "load" re-runs the same
    # step under model/pipeline parallelism and compares the sliced
    # gradients against the reference.
    if args.verify == "save":
        assert dist.get_world_size() == 1
        assert mpu.get_pipeline_parallel_world_size() == 1
        assert mpu.get_model_parallel_world_size() == 1
        assert args.n_input_slices == 1
        assert args.n_batch_slices == 1
        os.makedirs(args.verify_path, exist_ok=True)
        config, layers, pipelined_layers = initialize_model(args)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            x = layers.create_inputs(config.batch_size,
                                     config.seq_len,
                                     random=True)
            torch.save(x, os.path.join(args.verify_path, 'input.pt'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank() ==
                    mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except Exception:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        torch.save(pipelined_layers.state_dict(),
                   os.path.join(args.verify_path, 'model.ckpt'))
        grad_dic = OrderedDict(
            (name, param.grad)
            for name, param in pipelined_layers.named_parameters())
        torch.save(grad_dic, os.path.join(args.verify_path, 'model.grad.ckpt'))
    else:
        assert args.verify == "load"
        config, layers, pipelined_layers = initialize_model(args)
        with FileLock(os.path.join(args.verify_path, 'model.ckpt.lock')):
            loaded_state_dict = torch.load(
                os.path.join(args.verify_path, 'model.ckpt'),
                map_location=torch.device('cuda'))
        sliced_state_dict = slice_state_dict(config, loaded_state_dict)
        pipelined_layers.load_state_dict(sliced_state_dict)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            with FileLock(os.path.join(args.verify_path, 'input.pt.lock')):
                x = torch.load(os.path.join(args.verify_path, 'input.pt'),
                               map_location=torch.device('cuda'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank() ==
                    mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except Exception:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        grad_dic = OrderedDict(
            (name, param.grad)
            for name, param in pipelined_layers.named_parameters())
        with FileLock(os.path.join(args.verify_path, 'model.grad.ckpt.lock')):
            loaded_grad_dic = torch.load(
                os.path.join(args.verify_path, 'model.grad.ckpt'),
                map_location=torch.device('cuda'))
        sliced_grad_dic = slice_state_dict(config, loaded_grad_dic)
        # The reference gradients are sliced exactly like the weights, then
        # compared shard by shard.
        assert grad_dic.keys() == sliced_grad_dic.keys()
        for k in grad_dic.keys():
            assert torch.allclose(grad_dic[k], sliced_grad_dic[k])
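
# A small hedged variant (not in the original) of the final comparison:
# report the offending key and maximum absolute error instead of a bare
# assert. Default tolerances mirror torch.allclose.
def compare_grads(grad_dic, ref_grad_dic, rtol=1e-5, atol=1e-8):
    assert grad_dic.keys() == ref_grad_dic.keys()
    for k in grad_dic:
        if not torch.allclose(grad_dic[k], ref_grad_dic[k], rtol=rtol,
                              atol=atol):
            max_err = (grad_dic[k] - ref_grad_dic[k]).abs().max().item()
            raise AssertionError(
                f"gradient mismatch for {k}: max abs err {max_err:.3e}")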