def initialize_smp(smp_args, training_args):
    """Initialize SageMaker model parallelism (SMP) and seed the RNGs.

    Builds the SMP configuration from ``smp_args`` and passes it to
    ``smp.init``. On global rank 0 it prints the arguments, library versions
    and the final config. It then calls ``set_seed(training_args.seed)`` on
    every rank.

    Args:
        smp_args: parsed SMP options. Integer-valued switches
            (``shard_optimizer_state``, ``prescaled_batch``, ``match_weights``,
            ``offload_activations``, ``static_mode``, ``fast_mode``) count as
            enabled when strictly positive.
        training_args: object providing ``seed``.
    """

    def flag(value):
        # Integer CLI switches: anything > 0 means "on".
        return value > 0

    # Keys are inserted in the same order as before, so the printed config
    # below looks identical.
    smp_config = {}
    smp_config["ddp"] = smp_args.ddp
    smp_config["pipeline_parallel_degree"] = smp_args.pipeline_parallel_degree
    smp_config["microbatches"] = smp_args.microbatches
    smp_config["shard_optimizer_state"] = flag(smp_args.shard_optimizer_state)
    smp_config["prescaled_batch"] = flag(smp_args.prescaled_batch)
    smp_config["_match_weights"] = flag(smp_args.match_weights)
    smp_config["offload_activations"] = flag(smp_args.offload_activations)
    smp_config["optimize"] = smp_args.optimize
    smp_config["auto_partition"] = True
    smp_config["default_partition"] = 0
    smp_config["static_mode"] = flag(smp_args.static_mode)
    smp_config["fast_mode"] = flag(smp_args.fast_mode)

    # Only forward active_microbatches when it was explicitly provided.
    active_microbatches = smp_args.active_microbatches
    if active_microbatches is not None:
        smp_config["active_microbatches"] = active_microbatches

    smp.init(smp_config)

    is_global_rank_zero = smp.rank() == 0
    if is_global_rank_zero:
        print("Arguments:", smp_args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    set_seed(training_args.seed)
def setup_training(args):
    """Prepare the CUDA device and validate/normalize training arguments.

    Initializes SageMaker model parallelism (SMP) and pins this process to
    its local GPU. It then validates the gradient-accumulation settings and
    turns ``args.train_batch_size`` into a per-accumulation-step batch size.
    Finally it checks ``args.output_dir`` and creates it if needed.

    Args:
        args: parsed training arguments. Mutated in place: ``n_gpu``,
            ``train_batch_size`` and, when ``gradient_accumulation_steps == 1``,
            ``allreduce_post_accumulation`` / ``allreduce_post_accumulation_fp16``.

    Returns:
        ``(device, args)``, where ``device`` is ``cuda:<smp.local_rank()>``.

    Raises:
        ValueError: if CUDA is unavailable, or if SMP is disabled
            (``args.smp <= 0``; no other backend is implemented here).
            Also raised if ``gradient_accumulation_steps`` is < 1 or does not
            divide ``train_batch_size``, or if ``output_dir`` already contains
            ``ckpt*`` entries and we are not resuming.
    """
    # Raise (not assert) so the check survives `python -O`; same message as
    # the other entry points in this project.
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    if args.smp <= 0:
        # The former non-SMP code path was removed. Without this check
        # `device` would be unbound further down (UnboundLocalError).
        raise ValueError(
            "setup_training requires SageMaker model parallelism "
            "(args.smp > 0); no other distributed backend is supported")

    # Initialize SMP. The configuration is obtained from the parameters
    # passed to the SageMaker PyTorch estimator.
    smp.init()
    # SMP: bind this process to its local GPU; input tensors should be moved
    # to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda", smp.local_rank())
    args.n_gpu = 1

    if args.gradient_accumulation_steps == 1:
        # No accumulation, so no post-accumulation allreduce is needed.
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False

    print(
        "device: {} n_gpu: {}, mp_rank: {}, rank: {}, distributed training: {}, 16-bits training: {}"
        .format(device, args.n_gpu, smp.mp_rank(), smp.rank(),
                bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(args.gradient_accumulation_steps, args.train_batch_size))

    # From here on train_batch_size is the per-accumulation-step batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    output_exists = os.path.exists(args.output_dir)
    if not args.resume_from_checkpoint and output_exists:
        # Refuse to overwrite existing checkpoints unless explicitly resuming.
        if any(name.startswith("ckpt") for name in os.listdir(args.output_dir)):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    args.output_dir))
    if (not args.resume_from_checkpoint or not output_exists) and is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    return device, args
def smp_init(args):
    """Initialize SMP from ``args`` and publish rank information.

    Disables DeepSpeed and calls ``smp.init`` with a config built from
    ``args``. It stores the data-parallel rank, global rank and world size
    on ``args`` and mirrors them into the ``RANK`` / ``WORLD_SIZE`` /
    ``LOCAL_RANK`` environment variables. Finally it pins this process to
    its local GPU.

    Returns:
        The same ``args`` object, mutated in place.
    """
    args.deepspeed = False  # SMP is used instead of DeepSpeed on this path.

    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": args.placement_strategy,
        "pipeline": args.pipeline,
        "optimize": args.optimize,
        "partitions": args.num_partitions,
        "horovod": args.horovod,
        "ddp": args.ddp,
    })

    args.rank = smp.dp_rank()
    args.global_rank = smp.rank()
    args.world_size = smp.size()

    # NOTE(review): RANK is the *data-parallel* rank while WORLD_SIZE is the
    # *global* process count; confirm downstream consumers expect this mix.
    os.environ.update({
        'RANK': str(args.rank),
        'WORLD_SIZE': str(args.world_size),
        'LOCAL_RANK': str(smp.local_rank()),
        # Explicitly "0"; presumably "1" would skip SMP graph validation.
        'SMP_SKIP_GRAPH_VALIDATION': "0",
    })

    local_rank = smp.local_rank()
    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank
    return args
def dist_init(fn, args):
    """Seed RNGs, derive distributed-mode flags on ``args``, then run ``fn``.

    Sets ``args.is_distributed``, ``args.is_multigpus`` and
    ``args.multigpus_distributed``. How ``fn`` is invoked depends on the mode:

    * multi-host or multi-GPU with ``args.apex``: started via
      ``mp.spawn(fn, nprocs=args.num_gpus, args=(args,))``;
    * other multi-host or multi-GPU runs: smdistributed data parallelism or
      model parallelism is initialized if needed, then ``fn(None, args)``
      runs in this process. Model-parallel runs finish with ``smp.barrier()``;
    * otherwise: ``fn(0, args)``.
    """
    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn.deterministic = True
        if cudnn.deterministic:
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

    args.is_distributed = len(args.hosts) > 1 and args.backend is not None
    args.is_multigpus = args.num_gpus > 1
    args.multigpus_distributed = args.is_distributed or args.is_multigpus
    logger.debug(f"multigpus_distributed - {args.multigpus_distributed}")
    logger.debug(f"Number of gpus available - {args.num_gpus}")

    # Single GPU on a single host: run directly in this process.
    if not args.multigpus_distributed:
        fn(0, args)
        return

    # Apex: one spawned worker process per GPU.
    if args.apex:
        mp.spawn(fn, nprocs=args.num_gpus, args=(args, ))
        return

    # smdistributed: initialize the selected backend once, then run in-process.
    if args.data_parallel and not sdp.is_initialized():
        sdp.init_process_group()
    elif args.model_parallel and not smp.is_initialized():
        smp.init()
    fn(None, args)
    if args.model_parallel:
        smp.barrier()
def main():
    """Train and evaluate GroupedNet on MNIST with SageMaker model parallelism.

    Reads options from ``get_parser()`` and initializes SMP with the
    pipeline, DDP and Horovod settings they select. It can resume from a
    partial or full SMP checkpoint, then trains for ``args.epochs`` epochs.
    Afterwards it can save a partial and/or full checkpoint and, with
    ``--assert_losses``, check the final test loss against fixed thresholds.

    Raises:
        ValueError: if CUDA is unavailable, or if ``--assert_losses`` is
            requested with ``epochs < 1`` (no test loss would be produced).
    """
    parser = get_parser()
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    if args.assert_losses and args.epochs < 1:
        # Fail fast: without at least one epoch `test_loss` is never assigned
        # and the loss assertions at the end would hit an unbound name.
        raise ValueError(
            f"--assert_losses requires at least one epoch (got epochs={args.epochs})")

    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    cfg = {
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    }
    smp.init(cfg)

    # SM Distributed: set the device to the GPU used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    kwargs = {"batch_size": args.batch_size}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: download only on one process per instance. Otherwise
    # several processes download/extract concurrently and corrupt the files.
    if smp.local_rank() == 0:
        datasets.MNIST("../data", train=True, download=True,
                       transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data", train=True, download=False,
                              transform=transform)

    # With data parallelism, give each data-parallel rank an equal shard.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    # SMP moves parameters to the right device; no explicit model.to(device).
    model = GroupedNet()
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: from here on use the DistributedModel/Optimizer
    # wrappers in place of the raw model and optimizer.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # Optional resume; a partial checkpoint takes precedence over a full one.
    checkpoint_path, partial = None, False
    if args.partial_checkpoint:
        checkpoint_path, partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        checkpoint_path = args.full_checkpoint
    if checkpoint_path:
        checkpoint = smp.load(checkpoint_path, partial=partial)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    # NOTE(review): partial and full saves use the same path; if both flags
    # are set the full save is written last. Confirm whether
    # smp.save(partial=True) adds a per-rank suffix before relying on both.
    checkpoint_file = "./pt_mnist_checkpoint.pt"
    if args.save_partial_model and smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            checkpoint_file,
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        model_dict = model.state_dict()
        opt_dict = optimizer.state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            checkpoint_file,
            partial=False,
        )

    # Wait for checkpoint saving to finish before running another
    # allgather_object.
    smp.barrier()

    if args.assert_losses:
        if use_horovod or use_ddp:
            # SM Distributed: with data parallelism, gather the losses from
            # every model replica and check that they all match.
            losses = smp.allgather(test_loss, smp.DP_GROUP)
            for replica_loss in losses:
                assert math.isclose(replica_loss, losses[0])
            assert test_loss < 0.18
        else:
            assert test_loss < 0.08
def __post_init__(self):
    """Run the parent post-init hook, then start SMP if model parallelism is configured.

    ``smp.init()`` is called only when ``is_smdistributed_available()`` is
    true and ``self.mp_parameters`` is not the empty string.
    """
    super().__post_init__()
    # Guard clause: nothing more to do without smdistributed or mp parameters.
    if not is_smdistributed_available() or self.mp_parameters == "":
        return
    smp.init()
def main():
    """Train GroupedNet on MNIST for one epoch with SMP, then checkpoint it to S3.

    Uses fixed settings: seed 1, batch size 64, Adadelta with lr 4.0, StepLR
    with gamma 0.7, and DDP enabled. After training, each data-parallel
    rank 0 writes a partial SMP checkpoint into
    ``/opt/ml/local_checkpoints``. The local-rank-0 process on each host then
    syncs that directory to ``<two levels above SM_MODULE_DIR>/checkpoints/
    <SM_CURRENT_HOST>/``.

    Raises:
        ValueError: if CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs.
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: set the device to the GPU used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: download only on one process per instance. Otherwise
    # several processes download/extract concurrently and corrupt the files.
    if smp.local_rank() == 0:
        datasets.MNIST("../data", train=True, download=True,
                       transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data", train=True, download=False,
                              transform=transform)

    # With data parallelism, give each data-parallel rank an equal shard.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    # SMP moves parameters to the right device; no explicit model.to(device).
    model = GroupedNet()
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: from here on use the DistributedModel/Optimizer
    # wrappers in place of the raw model and optimizer.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    local_ckpt_dir = "/opt/ml/local_checkpoints"
    # The checkpoint directory is on each host's local disk and is synced
    # per host below. The dp_rank-0 savers may run on any host, so create it
    # once per host (local rank 0) rather than only on global rank 0.
    if smp.local_rank() == 0:
        if os.path.exists(local_ckpt_dir):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs(local_ckpt_dir, exist_ok=True)
            print("-INFO- PATH DO NOT EXIST")
    # Ensure the directory exists everywhere before any rank saves into it.
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            os.path.join(local_ckpt_dir, "pt_mnist_checkpoint.pt"),
            partial=True,
        )
    # Wait for all saves to finish before syncing.
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path=local_ckpt_dir,
                                     s3_path=full_s3_path)
        print("Finished syncing")
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .utils import logging

# Optional backends are imported only when the corresponding availability
# check passes, so this module still imports without them.
if is_torch_available():
    import torch

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm

if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as sm_dist

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    # NOTE(review): SMP is initialized as an import-time side effect of this
    # module whenever SageMaker model parallelism is enabled.
    smp.init()


logger = logging.get_logger(__name__)


def default_logdir() -> str:
    """
    Return the default log directory, using the same scheme as PyTorch.

    The path is ``runs/<Mon><DD>_<HH>-<MM>-<SS>_<hostname>`` (for example
    ``runs/Jan05_13-45-10_myhost``). The timestamp is built from
    ``datetime.now()``, i.e. naive local time.
    """
    import socket
    from datetime import datetime

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    return os.path.join("runs", current_time + "_" + socket.gethostname())
def main():
    """Train a GPT-2 causal LM with SageMaker model parallelism (SMP).

    Steps, in order:

    1. Parse arguments and validate option combinations.
    2. Build the SMP config and call ``smp.init``.
    3. Create the model and wrap it in ``smp.DistributedModel``, with optional
       manual layer-to-partition assignment and activation checkpointing.
    4. Build and wrap the optimizer, optionally restoring a partial or full
       checkpoint.
    5. Run ``train(...)``. In CI mode, check time, throughput and loss
       thresholds.
    6. Optionally save the final full model.

    Raises:
        ValueError: if optimizer-state sharding is requested without
            ``skip_full_optimizer``.
        AssertionError: on an inconsistent ``partition_assignment``, or when
            CI thresholds are not met.
    """
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    # An explicit partition_assignment implies manual partitioning.
    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # Any value here is overridden by the config set in the notebook when
    # launching the SageMaker job.
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing true checkpoints transformer layers below
        "checkpoint_attentions": False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    # Config keys differ by SMP version: before 1.10 (smp_version < 110)
    # fp16 is "fp16_params"; newer versions accept extra options.
    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config[
            "activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    # partition_assignment is a comma-separated layer count per pipeline
    # stage. It is only bound when the option is non-empty; the manual
    # partitioning code below reads it under the same condition.
    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    # Create output dirs on global rank 0. Also do it on each host's local
    # rank 0 unless use_fsx is set (presumably a shared filesystem, where one
    # creator suffices).
    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # The following improves start-up time by skipping proper initialization
    # of weights in the original model. This is not a problem because
    # DistributedModel will override those weights anyway when
    # tensor_parallel_degree > 1. NOTE: this monkey-patches the class globally.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    # Model creation API differs by SMP version. For < 1.10, fp16 is enabled
    # by changing torch's global default dtype before construction.
    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16 if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    # NOTE(review): `device` is not used later in this function.
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on
        # different tp_ranks.
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    # Locate the container of transformer blocks. Under tensor parallelism it
    # is `seq_layers`, otherwise HF GPT-2's `h`. For smp_version < 110 the
    # underlying module sits under three `.module` wrappers.
    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        # get_num_layers(pp_rank) -> number of transformer layers for a stage.
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # Evenly distribute layers across all partitions. The last `rem`
            # stages each get one extra layer.
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)

        # assignments[i] is the pipeline stage for the i-th transformer block.
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])

    # Build parameter groups (weight decay and non-decay) from the unwrapped
    # model.
    if args.smp_version < 110:
        iter_model = model
        # Peel off DDP / FP16 wrappers to reach the underlying module.
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        # Extra options are passed only when the layer container is an
        # nn.Sequential.
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    # Wrap the optimizer. For smp_version < 110 this goes through
    # FP16_Optimizer, and master params are re-initialized after each step
    # via a post-step hook. Newer SMP handles loss scaling inside
    # DistributedOptimizer.
    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )
        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )

    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # Load only after the model and optimizer have been wrapped by the SMP
    # Distributed* classes.
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            # NOTE: the two adjacent literals are concatenated without a space.
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        # Full checkpoints come from model_dir (with HF translation); partial
        # ones come from checkpoint_dir and also restore the optimizer.
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start

    # CI mode: emit metrics and, for fresh (non-resumed) runs, enforce the
    # configured time/throughput/loss thresholds.
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # Save the full model at the end. Only reduced-data-parallel rank 0
        # processes write; the others wait at the barrier.
        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)
        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )
        smp.barrier()

    if smp.rank() == 0:
        print("SMP training finished successfully")
def init_train():
    """
    Train the PyTorch model.

    Builds a TabularNet on the CSV training data at ``args.train``, where
    ``args`` is a module-level namespace. The data is sharded across SMP
    data-parallel ranks and trained with ``train(...)`` under SageMaker
    model parallelism. The full state dict is then written to
    ``<args.model_dir>/model.pth`` from global rank 0.

    ``args.batch_size`` and ``args.learning_rate`` are honoured. They used to
    be logged but overridden by hard-coded values (64 and 4.0).
    """
    import os  # function-scope import; this file's header is not visible here

    # Presumably marks categorical input columns (9 True, matching n_cat=9;
    # 9 False, matching n_cont=9) — confirm against TabularNet.
    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False
    ]
    train_ds = CsvDatasetSimple(args.train)
    # NOTE(review): test_ds is constructed but not used in this function.
    test_ds = CsvDatasetSimple(args.test)

    batch_size = args.batch_size
    # NOTE(review): epochs is only logged; train(...) below does not receive it.
    epochs = args.epochs
    learning_rate = args.learning_rate
    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend.
    smp.init()

    # smdistributed: set the device to the GPU used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    dataset = train_ds

    # smdistributed: shard the dataset across data-parallel ranks.
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: drop_last=True keeps every batch at the full batch size.
    # The batch size itself must still be divisible by the number of
    # microbatches.
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000,
                           50000, 50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)
    logger.debug(model)
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    # smdistributed: use the DistributedModel container to provide the model
    # to be partitioned across ranks. From here on, use the returned
    # DistributedModel/Optimizer in place of the originals.
    model = smp.DistributedModel(model)
    optimizer = smp.DistributedOptimizer(optimizer)

    train(model, device, train_loader, optimizer)

    # SMP documents DistributedModel.state_dict() as gathering parameters from
    # all model-parallel ranks, so every rank calls it. Only global rank 0
    # writes the file, so processes don't race on the same path.
    state_dict = model.state_dict()
    if smp.rank() == 0:
        torch.save(state_dict, os.path.join(args.model_dir, "model.pth"))
def _restore_original_forwards(root):
    """Rebind ``forward`` on every submodule of *root* to its original method.

    For each submodule, ``smp_state.patch_manager.get_original_method`` is looked
    up by module class and bound to the instance. The deep-copied (unpartitioned)
    decoder/codebook can then run a plain local forward pass. Presumably this
    undoes the ``forward`` patching applied by ``smp.DistributedModel`` -- TODO
    confirm.
    """
    from functools import partial

    for module in root.modules():
        method = smp_state.patch_manager.get_original_method(
            "forward", type(module))
        module.forward = partial(method, module)


def _save_smp_checkpoint(model, optimizer, filename, save_full_model):
    """Write an SMP checkpoint from data-parallel rank 0, then barrier.

    Call this on EVERY rank: only ``smp.dp_rank() == 0`` gathers and saves
    state, but all ranks must take part in the trailing ``smp.barrier()``.

    Args:
        model: the ``smp.DistributedModel``.
        optimizer: the ``smp.DistributedOptimizer``.
        filename: checkpoint path passed to ``smp.save``.
        save_full_model: if truthy, save the full (gathered) state with
            ``partial=False``; otherwise save this rank's local partition with
            ``partial=True``.
    """
    if smp.dp_rank() == 0:
        if save_full_model:
            model_dict = model.state_dict()
            opt_dict = optimizer.state_dict()
        else:
            model_dict = model.local_state_dict()
            opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            filename,
            partial=not save_full_model,
        )
    # Wait until the checkpoint has been written before any rank moves on.
    smp.barrier()


def main():
    """Train a DiscreteVAE on an image folder.

    The execution mode is selected by ``args.model_parallel``:

    * SageMaker model parallelism (``smp``): the VAE is wrapped in
      ``smp.DistributedModel``. Training steps run through ``@smp.step``
      functions, and checkpoints are written with ``smp.save`` to
      ``<model_dir>/vae.pt`` every 10 iterations.
    * Otherwise, DeepSpeed or plain PyTorch through ``deepspeed_utils``.
      ``<model_dir>/vae.pt`` is written every 10 iterations and
      ``<model_dir>/vae-final.pt`` at the end.

    Raises:
        ValueError: if CUDA is not available.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        # In model-parallel mode `args.rank` holds the DATA-parallel rank.
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        if cudnn.deterministic:
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    # SageMaker: take the model and training-data paths from the container env.
    if os.environ.get('SM_MODEL_DIR') is not None:
        args.model_dir = os.environ.get('SM_MODEL_DIR')
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(IMAGE_PATH, transform=transform)

    # With SMP data parallelism, give each data-parallel rank its own shard.
    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        dp_size = smp.dp_size()
        partitions_dict = {f"{i}": 1 / dp_size for i in range(dp_size)}
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    assert len(ds) > 0, 'folder does not contain any images'

    # drop_last in model-parallel mode keeps every batch at full size so it
    # splits evenly into microbatches.
    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        # Unpartitioned copies, used only to render hard reconstructions locally.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Save hparams and weights to *path* (rank 0 only, non-SMP mode)."""
        if not args.rank == 0:
            return
        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    if args.model_parallel:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)
            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            """Decode code indices with the local, unpartitioned module copies."""
            _restore_original_forwards(dummy_codebook)
            image_embeds = dummy_codebook.forward(codebook_indices)
            _, n, _ = image_embeds.shape
            h = w = int(sqrt(n))
            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            _restore_original_forwards(dummy_decoder)
            return dummy_decoder.forward(image_embeds)

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                elif len(list(vae.local_parameters())) > 0:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    opt.step()
            elif args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            else:
                loss.backward()
                opt.step()

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy current weights into the unpartitioned
                            # copies: strip the "decoder."/"codebook." prefixes;
                            # strict=False skips keys of the other submodule.
                            model_dict_updated = {}
                            for key, val in vae.state_dict().items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val
                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            # NOTE(review): reduce_mean() averages integer code
                            # indices across microbatches; concatenation is
                            # probably what was intended.
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Build sample grids under new names so `images`, which the
                    # loss bookkeeping below uses, is not overwritten.
                    # The grids are not consumed yet (experiment tracking is
                    # disabled); they are ready to hand to a logger.
                    sample_images, sample_recons = (
                        t[:k] for t in (images, recons))
                    sample_images, sample_recons, hard_recons, codes = (
                        t.detach().cpu()
                        for t in (sample_images, sample_recons, hard_recons,
                                  codes))
                    sample_images, sample_recons, hard_recons = (
                        make_grid(t.float(),
                                  nrow=int(sqrt(k)),
                                  normalize=True,
                                  range=(-1, 1))
                        for t in (sample_images, sample_recons, hard_recons))

                    if not args.model_parallel:
                        save_model(f'{args.model_dir}/vae.pt')

                if args.model_parallel:
                    # Called on every rank: the helper's barrier must be
                    # reached by all ranks, not only data-parallel rank 0.
                    _save_smp_checkpoint(vae, opt, f'{args.model_dir}/vae.pt',
                                         args.save_full_model)

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Loss value reported in the logs.
            if args.model_parallel:
                # NOTE(review): this only divides this rank's loss by the
                # world size; no cross-rank reduction is performed here.
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0 and i % 100 == 0:
                lr = sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

            global_step += 1

            # Every log_interval iterations, check the loss and speed. Doing
            # this every iteration would add host<->device syncs.
            if args.rank == 0 and i % args.log_interval == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                # Restart the timing window for the next log_interval batches.
                start = time.time()
                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses
                # across different model replicas and check if losses match.
                replica_losses = smp.allgather(loss, smp.DP_GROUP)
                for replica_loss in replica_losses:
                    print(replica_loss)
                    assert math.isclose(replica_loss, replica_losses[0])
                assert loss < 0.18
            else:
                assert loss < 0.08
        smp.barrier()
        print("SMP training finished successfully")