        optim.step()
        if torch.distributed.get_rank() == 0:
            print("finished iter", i)
        runtime += time.time() - since
    # Average wall-clock time over the three timed iterations.
    return runtime / 3.0


if __name__ == "__main__":
    init = True
    for async_comm in (False, True):
        global fancy_data
        global effective_length
        if init:
            init = False
            global_vars.set_global_variables()
            fancy_data = download_fancy_data()
            args = global_vars.get_args()
            # Largest start offset at which a full seq_length window still fits.
            effective_length = fancy_data.size(0) - args.seq_length
            initialize_distributed("nccl")
            world_size = torch.distributed.get_world_size()
            failure = None
            args.padded_vocab_size = 128
            batch_size = args.global_batch_size
            micro_batch_size = args.micro_batch_size
            setup_microbatch_calculator(
                args.rank,
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level
from apex.transformer.testing.commons import model_provider_func
from apex.transformer._data import MegatronPretrainingRandomSampler
from apex.transformer._data import MegatronPretrainingSampler

# note(mkozuki): To see warmup, steady, cooldown iterations, uncomment the line below
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
# note(mkozuki): To see if local batch size increases, uncomment the line below
# _logger.setLevel("INFO")

# rampup_batch_size = [start, increment, ramp-up samples]: the global batch
# size starts at 64 and grows in steps of 64 over the first 1000 samples,
# until it reaches global_batch_size (512).
global_vars.set_global_variables(
    args_defaults={
        "global_batch_size": 512,
        "rampup_batch_size": [64, 64, 1000],
    },
    ignore_unknown_args=True,
)

RAMPUP_BATCH_SIZE = []
NUM_ITERATIONS = 20
NUM_SAMPLES = 16384 // 2
batch_size, micro_batch_size = None, None
HIDDEN_SIZE = 16


def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # Toy dataset: pairs of random (input, label-like) tensors.
    return [
        (
            torch.randn(HIDDEN_SIZE, HIDDEN_SIZE),
            torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2),
        )
        for _ in range(num_samples)
    ]
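

# A minimal sketch of wiring a toy dataset like the one above into a
# micro-batch DataLoader. `build_dataloader` is a hypothetical helper, not
# part of this test, and the sampler arguments assume the Megatron-LM style
# signature (total_samples, consumed_samples, micro_batch_size,
# data_parallel_rank, data_parallel_size) that apex.transformer._data ports.
def build_dataloader(dataset, micro_batch_size: int):
    from apex.transformer import parallel_state

    sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=micro_batch_size,
        data_parallel_rank=parallel_state.get_data_parallel_rank(),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    # The sampler yields lists of indices, one micro batch at a time, so it
    # is passed as a batch_sampler rather than a sampler.
    return torch.utils.data.DataLoader(dataset, batch_sampler=sampler)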