def slice_state_dict(config, loaded_state_dict):
    """Slice a full (single-process) state dict down to the shard owned by
    this pipeline stage and model-parallel rank."""
    mp_rank = mpu.get_model_parallel_rank()
    mp_world_size = mpu.get_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_parallel_group_rank()
    pp_world_size = mpu.get_pipeline_parallel_world_size()
    sliced_state_dict = OrderedDict()
    # Layers are split across pipeline stages as evenly as possible; the first
    # (n_total_layers % pp_world_size) stages own one extra layer.
    start_layer_id = (config.n_total_layers // pp_world_size * pp_rank
                      + min(pp_rank, config.n_total_layers % pp_world_size))
    end_layer_id = start_layer_id + config.n_layers
    for key, value in loaded_state_dict.items():
        keys = key.split('.')
        global_layer_id = int(keys[2])
        if not (start_layer_id <= global_layer_id < end_layer_id):
            continue
        # Remap the global layer index to a stage-local index.
        local_layer_id = global_layer_id - start_layer_id
        new_key = '.'.join(keys[:2] + [str(local_layer_id)] + keys[3:])
        if keys[3] == 'attn' and keys[4] == 'in_proj':
            # Column-parallel projection: shard weight and bias along dim 0.
            in_size = mpu.divide(value.size(0), mp_world_size)
            if keys[5] in ('weight', 'bias'):
                new_value = value[mp_rank * in_size:(mp_rank + 1) * in_size]
            else:
                raise NotImplementedError(f"Unknown key {key}")
        elif keys[3] == 'attn' and keys[4] == 'out_proj':
            # Row-parallel projection: shard weight along dim 1, replicate bias.
            if keys[5] == 'weight':
                out_size = mpu.divide(value.size(1), mp_world_size)
                new_value = value[:, mp_rank * out_size:(mp_rank + 1) * out_size]
            elif keys[5] == 'bias':
                new_value = value
            else:
                raise NotImplementedError(f"Unknown key {key}")
        elif keys[3] == 'fc1':
            # Column-parallel FFN input projection: shard along dim 0.
            in_size = mpu.divide(value.size(0), mp_world_size)
            if keys[4] in ('weight', 'bias'):
                new_value = value[mp_rank * in_size:(mp_rank + 1) * in_size]
            else:
                raise NotImplementedError(f"Unknown key {key}")
        elif keys[3] == 'fc2':
            # Row-parallel FFN output projection: shard along dim 1, replicate bias.
            if keys[4] == 'weight':
                out_size = mpu.divide(value.size(1), mp_world_size)
                new_value = value[:, mp_rank * out_size:(mp_rank + 1) * out_size]
            elif keys[4] == 'bias':
                new_value = value
            else:
                raise NotImplementedError(f"Unknown key {key}")
        else:
            new_value = value
        sliced_state_dict[new_key] = new_value
    return sliced_state_dict
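
# The start/end-layer arithmetic above implements a balanced split of the full
# layer stack across pipeline stages. A minimal standalone sketch of the same
# computation (helper name and concrete numbers are illustrative only, not part
# of the codebase): with 10 layers and 4 stages, the stages own layers
# [0, 3), [3, 6), [6, 8), [8, 10).
def stage_layer_range(n_total_layers, n_stages, stage_rank):
    n_layers = (n_total_layers // n_stages
                + int(stage_rank < n_total_layers % n_stages))
    start = (n_total_layers // n_stages * stage_rank
             + min(stage_rank, n_total_layers % n_stages))
    return start, start + n_layers

assert [stage_layer_range(10, 4, r) for r in range(4)] == \
    [(0, 3), (3, 6), (6, 8), (8, 10)]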
def measure_iteration_time(args, n_warmup_steps=2):
    """Time `args.n_steps` training iterations (after `n_warmup_steps` warmup
    iterations) and return the mean and standard deviation per step."""
    config, layers, pipelined_layers = initialize_model(args)
    optimizer = torch.optim.Adam(pipelined_layers.parameters(), lr=1e-10)
    step_times = []
    for _ in range(args.n_steps + n_warmup_steps):
        start_time = time.time()
        optimizer.zero_grad()
        if mpu.get_pipeline_parallel_group_rank() == 0:
            # Only the first pipeline stage creates real inputs; later stages
            # receive activations from their predecessors.
            x = layers.create_inputs(config.batch_size, config.seq_len,
                                     random=True)
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank()
                    == mpu.get_pipeline_parallel_world_size() - 1):
                # Only the last pipeline stage computes the loss.
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        optimizer.step()
        step_times.append(time.time() - start_time)
    # Drop the warmup iterations before computing statistics.
    step_times = np.array(step_times)[n_warmup_steps:]
    return np.mean(step_times), np.std(step_times)
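
# `loss_func` is referenced above but not defined in this section; the
# benchmark only needs some scalar to call .backward() on, not a real training
# objective. A placeholder with roughly that shape (the repository's actual
# definition may differ):
def example_loss_func(y):
    return torch.mean(y.float())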
def forward(self, inputs=None):
    if inputs is None:
        # Only non-first pipeline stages may be called without inputs; they
        # receive activations from the previous stage instead.
        assert mpu.get_pipeline_parallel_group_rank() > 0
        inputs = self.layers.create_inputs(self.batch_size, self.seq_len)
    # Split the batch and the sequence into a 2-D grid of micro-slices.
    inputs = grid_slice_batch_and_sequence(
        inputs,
        batch_slices=self.batch_slices,
        seq_slices=self.seq_slices,
        batch_dim=self.batch_dim,
        sequence_dim=self.sequence_dim,
        requires_grad=True)
    cache_inputs = np.empty((self.n_batch_slices, self.n_input_slices), dtype='O')
    cache_outputs = np.empty((self.n_batch_slices, self.n_input_slices), dtype='O')
    outputs = np.empty((self.n_batch_slices, self.n_input_slices), dtype='O')
    for batch_id in range(self.n_batch_slices):
        slice_batch_size = inputs[batch_id, 0].size(self.batch_dim)
        cache = self.layers.create_cache(slice_batch_size, self.seq_len)
        cache_len = 0
        for input_id in range(self.n_input_slices):
            x = inputs[batch_id, input_id]
            x = pipeline_recv(x)
            slice_seq_len = x.size(self.sequence_dim)
            # Detach the attention cache so each sequence slice forms its own
            # autograd subgraph; gradients are stitched back together in the
            # TeraPipe backward hook below.
            cache = [c.detach().requires_grad_() for c in cache]
            y, cache_output = self.layers(x, cache, cache_len)
            y = pipeline_send(y)
            cache_inputs[batch_id, input_id] = cache
            cache_outputs[batch_id, input_id] = cache_output
            cache = cache_output
            outputs[batch_id, input_id] = y
            cache_len += slice_seq_len
    y = terapipe_backward_hook(
        outputs, cache_inputs, cache_outputs, self.batch_slices,
        self.seq_slices, self.batch_dim, self.sequence_dim,
        cat_outputs=(mpu.get_pipeline_parallel_group_rank()
                     == mpu.get_pipeline_parallel_world_size() - 1))
    return y
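
# Conceptual sketch (not the repository's grid_slice_batch_and_sequence) of how
# one activation tensor of shape (seq_len, batch, hidden) can be cut into the
# 2-D grid of micro-slices consumed by the loop above. The dimension order and
# helper name are assumptions for illustration.
def grid_slice_sketch(x, batch_sizes, seq_sizes, batch_dim=1, sequence_dim=0):
    grid = np.empty((len(batch_sizes), len(seq_sizes)), dtype='O')
    for i, xb in enumerate(torch.split(x, batch_sizes, dim=batch_dim)):
        for j, xbs in enumerate(torch.split(xb, seq_sizes, dim=sequence_dim)):
            # Each slice becomes its own autograd leaf, mirroring
            # requires_grad=True in the call above.
            grid[i, j] = xbs.detach().requires_grad_()
    return grid

_x = torch.randn(8, 4, 16)                     # (seq_len, batch, hidden)
_grid = grid_slice_sketch(_x, [2, 2], [4, 4])  # 2 batch slices x 2 sequence slices
assert _grid[0, 0].shape == (4, 2, 16)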
def initialize_model(args):
    config = TransformerConfig.from_predefined_model(
        args.model, batch_size=args.batch_size)
    # Split the layers across pipeline stages as evenly as possible; the first
    # (n_total_layers % world_size) stages get one extra layer.
    config.n_total_layers = config.n_layers
    config.n_layers = (
        config.n_total_layers // mpu.get_pipeline_parallel_world_size()
        + int(mpu.get_pipeline_parallel_group_rank()
              < config.n_total_layers % mpu.get_pipeline_parallel_world_size()))
    layers = TransformerLayers(
        config.n_layers, config.embedding_dim, config.ffn_embedding_dim,
        config.num_attention_heads, mixed_precision=args.mixed_precision)
    seq_slices = uniform_slice(config.seq_len, args.n_input_slices)
    batch_slices = uniform_slice(config.batch_size, args.n_batch_slices)
    pipelined_layers = TeraPipe(layers, config.batch_size, config.seq_len,
                                batch_slices, seq_slices)
    return config, layers, pipelined_layers
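
# `uniform_slice` is used above but not defined in this section. The sketch
# below captures the assumed behavior -- splitting a length into n near-equal
# slice sizes, with the first (length % n) slices one element longer -- and is
# not necessarily the actual implementation.
def uniform_slice_sketch(length, n):
    return [length // n + int(i < length % n) for i in range(n)]

assert uniform_slice_sketch(2048, 3) == [683, 683, 682]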
def __init__(self, config, batch_slices, seq_slices, distributed_init_method,
             world_size, data_parallel_size, model_parallel_size,
             pipeline_parallel_size, rank, local_rank, mixed_precision=False,
             use_mpi=False, init_process_group=False,
             checkpoint_gradients=False):
    self.config = config
    self.batch_slices = batch_slices
    self.seq_slices = seq_slices

    torch.cuda.set_device(local_rank)
    if init_process_group:
        dist.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
    # A trivial all-reduce to make sure the process group is up before the
    # parallel groups are created.
    dist.all_reduce(torch.zeros(1).cuda())
    mpu.initialize_model_parallel(model_parallel_size, pipeline_parallel_size)
    set_random_seed(0)
    mpu.model_parallel_cuda_manual_seed(0)

    self.rank = rank
    self.local_rank = local_rank
    self.world_size = world_size
    self.data_parallel_size = data_parallel_size
    self.model_parallel_size = model_parallel_size
    self.pipeline_parallel_size = pipeline_parallel_size

    self.pipeline_parallel_group_rank = mpu.get_pipeline_parallel_group_rank()
    self.data_parallel_group = mpu.get_data_parallel_group()
    self.model_parallel_group = mpu.get_model_parallel_group()
    self.pipeline_parallel_pred_group = mpu.get_pipeline_parallel_pred_group()
    self.pipeline_parallel_succ_group = mpu.get_pipeline_parallel_succ_group()
    self.model_parallel_src_rank = mpu.get_model_parallel_src_rank()
    self.model_parallel_dst_rank = mpu.get_model_parallel_dst_rank()
    # Matching source/destination ranks in the neighboring pipeline stages
    # (None at the pipeline boundaries).
    self.model_parallel_next_src_rank = (
        self.model_parallel_src_rank + self.model_parallel_size
        if self.pipeline_parallel_group_rank < self.pipeline_parallel_size - 1
        else None)
    self.model_parallel_prev_dst_rank = (
        self.model_parallel_dst_rank - self.model_parallel_size
        if self.pipeline_parallel_group_rank > 0
        else None)

    # Number of transformer layers owned by this pipeline stage.
    self.n_layers = (config.n_layers // pipeline_parallel_size
                     + int(rank < config.n_layers % pipeline_parallel_size))

    self.config = config
    self.mixed_precision = mixed_precision
    self.checkpoint_gradients = checkpoint_gradients

    self.layers = []
    for _ in range(self.n_layers):
        l = ModelParallelTransformerLayer(
            self.config.embedding_dim,
            self.config.ffn_embedding_dim,
            self.config.num_attention_heads,
            device="cuda",
            checkpoint_gradients=self.checkpoint_gradients)
        self.layers.append(l.half() if self.mixed_precision else l)

    self.all_parameters = []
    for layer in self.layers:
        self.all_parameters.extend(layer.parameters())
    self.n_params = len(self.all_parameters)

    if self.mixed_precision:
        # Keep an fp32 master copy of the (fp16) parameters and optimize that
        # copy instead of the model parameters directly.
        self.master_parameters = [
            p.clone().detach().float() for p in self.all_parameters
        ]
        for p in self.master_parameters:
            p.requires_grad_()
        self.optimizer = optimizers.FusedAdam(self.master_parameters, lr=1e-10)
    else:
        self.optimizer = torch.optim.Adam(self.all_parameters, lr=1e-10)
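
# `set_random_seed` is called above but not shown in this section. A typical
# definition (assumed here, not copied from the repository) seeds every RNG the
# program touches so runs are reproducible across ranks:
import random

def set_random_seed_sketch(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)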
def verify_one_step(args):
    if args.verify == "save":
        # Run a single step on one GPU (no model or pipeline parallelism) and
        # save the inputs, the weights, and the resulting gradients.
        assert dist.get_world_size() == 1
        assert mpu.get_pipeline_parallel_world_size() == 1
        assert mpu.get_model_parallel_world_size() == 1
        assert args.n_input_slices == 1
        assert args.n_batch_slices == 1
        os.makedirs(args.verify_path, exist_ok=True)
        config, layers, pipelined_layers = initialize_model(args)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            x = layers.create_inputs(config.batch_size, config.seq_len,
                                     random=True)
            torch.save(x, os.path.join(args.verify_path, 'input.pt'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank()
                    == mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        torch.save(pipelined_layers.state_dict(),
                   os.path.join(args.verify_path, 'model.ckpt'))
        grad_dic = OrderedDict(
            (name, p.grad) for name, p in pipelined_layers.named_parameters())
        torch.save(grad_dic, os.path.join(args.verify_path, 'model.grad.ckpt'))
    else:
        assert args.verify == "load"
        # Reload the saved weights and inputs into the distributed model,
        # rerun one step, and check that the sharded gradients match the
        # single-GPU reference.
        config, layers, pipelined_layers = initialize_model(args)
        with FileLock(os.path.join(args.verify_path, 'model.ckpt.lock')):
            loaded_state_dict = torch.load(
                os.path.join(args.verify_path, 'model.ckpt'),
                map_location=torch.device('cuda'))
        sliced_state_dict = slice_state_dict(config, loaded_state_dict)
        pipelined_layers.load_state_dict(sliced_state_dict)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            with FileLock(os.path.join(args.verify_path, 'input.pt.lock')):
                x = torch.load(os.path.join(args.verify_path, 'input.pt'),
                               map_location=torch.device('cuda'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank()
                    == mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        grad_dic = OrderedDict(
            (name, p.grad) for name, p in pipelined_layers.named_parameters())
        with FileLock(os.path.join(args.verify_path, 'model.grad.ckpt.lock')):
            loaded_grad_dic = torch.load(
                os.path.join(args.verify_path, 'model.grad.ckpt'),
                map_location=torch.device('cuda'))
        sliced_grad_dic = slice_state_dict(config, loaded_grad_dic)
        assert grad_dic.keys() == sliced_grad_dic.keys()
        for k in grad_dic.keys():
            assert torch.allclose(grad_dic[k], sliced_grad_dic[k])
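
# Hypothetical driver showing how the verification and timing functions above
# could be wired to a CLI. Every flag is inferred from the `args.*` attributes
# referenced in this section (model, batch_size, n_steps, n_batch_slices,
# n_input_slices, mixed_precision, rank, verify, verify_path); the flag names,
# defaults, and the way `rank` is obtained are assumptions, and the
# repository's real entry point may differ.
import argparse

def parse_args_sketch():
    parser = argparse.ArgumentParser(
        description='TeraPipe benchmark / gradient verification (sketch)')
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--n-steps', type=int, default=10)
    parser.add_argument('--n-batch-slices', type=int, default=1)
    parser.add_argument('--n-input-slices', type=int, default=1)
    parser.add_argument('--mixed-precision', action='store_true')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--verify', choices=['save', 'load'], default=None)
    parser.add_argument('--verify-path', type=str, default='verify_ckpt')
    return parser.parse_args()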