def smp_init(model, optimizer, args):
    """Wrap the model/optimizer for SMP and optionally restore a checkpoint.

    The model is wrapped in ``smp.DistributedModel``, the optimizer in
    ``smp.DistributedOptimizer``, and a fresh ``smp.amp.GradScaler`` is stored
    on ``args.scaler``.  When ``args.partial_checkpoint`` is truthy it takes
    precedence over ``args.full_checkpoint``; the loaded checkpoint is kept on
    ``args.checkpoint`` and its ``model_state_dict`` / ``optimizer_state_dict``
    entries are loaded into the wrapped objects.

    Returns:
        The tuple ``(model, optimizer, args)`` with the wrapped objects.
    """
    model = smp.DistributedModel(model)
    args.scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # Pick the checkpoint source; a partial checkpoint wins over a full one.
    if args.partial_checkpoint:
        ckpt_path, is_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        ckpt_path, is_partial = args.full_checkpoint, False
    else:
        # Nothing to restore.
        return model, optimizer, args

    args.checkpoint = smp.load(ckpt_path, partial=is_partial)
    model.load_state_dict(args.checkpoint["model_state_dict"])
    optimizer.load_state_dict(args.checkpoint["optimizer_state_dict"])
    return model, optimizer, args
def load_and_verify_ckptsum(args, model, optimizer, filename):
    """Register hooks that compare live state against saved checksums.

    ``filename`` is read with ``smp.load`` and is expected to contain
    ``"model"`` and ``"optimizer"`` entries, each holding ``"tensors"``
    (expected element sums) and ``"scalars"`` (expected exact values).  When
    ``args.shard_optimizer_state`` is set, the optimizer entry is indexed by
    ``smp.rdp_rank()``.

    Nothing is verified immediately: the model check is registered as a
    post-partition hook and the optimizer check as a post-step hook on
    ``model``.  Each hook raises ``AssertionError`` on the first mismatch.
    The ``optimizer`` argument itself is not used here.
    """
    saved = smp.load(filename)
    if args.shard_optimizer_state:
        expected_opt = saved["optimizer"][smp.rdp_rank()]
    else:
        expected_opt = saved["optimizer"]
    expected_model = saved["model"]

    def opt_check_fn(mod, opt):
        # Sharded optimizers expose their state through orig_state_dict().
        if args.shard_optimizer_state:
            live_states = opt.orig_state_dict()["state"]
        else:
            live_states = opt.local_state_dict()["state"]

        for param_idx, state in live_states.items():
            for key, val in state.items():
                lookup = f"{param_idx}_{key}"
                if not isinstance(val, torch.Tensor):
                    assert (
                        val == expected_opt["scalars"][lookup]
                    ), f"mismatch for param_idx: {param_idx}, key is {key}"
                    continue
                # Tensors are compared through the sum of their elements.
                assert torch.isclose(
                    torch.sum(val), expected_opt["tensors"][lookup]
                ), f"mismatch for param_idx: {param_idx}, key is {key}"
        print("Optimizer save/load check passed successfully")

    def model_check_fn(mod, opt):
        for param_name, param in mod.local_state_dict().items():
            if not isinstance(param, torch.Tensor):
                assert (
                    param == expected_model["scalars"][param_name]
                ), f"mismatch for param_name: {param_name}"
                continue
            assert torch.isclose(
                torch.sum(param), expected_model["tensors"][param_name]
            ), f"mismatch for param_name: {param_name}"
        print("Model save/load check passed successfully")

    model.register_post_partition_hook(model_check_fn)
    model.register_post_step_hook(opt_check_fn)
def main():
    """Train and evaluate the MNIST model with SageMaker model parallelism.

    Parses the command line, initialises SMP, builds the (optionally
    data-parallel split) MNIST loaders, wraps model/optimizer with SMP,
    optionally restores a checkpoint, runs ``args.epochs`` epochs of
    train/test, optionally saves a partial and/or full checkpoint and finally
    checks the test loss when ``args.assert_losses`` is set.

    Raises:
        ValueError: if CUDA is not available.
    """
    args = get_parser().parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    smp.init(
        {
            "microbatches": args.num_microbatches,
            "placement_strategy": "spread",
            "pipeline": args.pipeline,
            "optimize": "speed",
            "partitions": args.num_partitions,
            "horovod": use_horovod,
            "ddp": use_ddp,
        }
    )

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes
    # trying to download and extract at the same time.
    if smp.local_rank() == 0:
        datasets.MNIST("../data", train=True, download=True,
                       transform=transform)
    smp.barrier()
    train_set = datasets.MNIST("../data", train=True, download=False,
                               transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        # Give every data-parallel rank an equally sized slice of the data.
        share = 1 / smp.dp_size()
        train_set = SplitDataset(
            train_set,
            partitions={f"{i}": share for i in range(smp.dp_size())},
        )
        train_set.select(f"{smp.dp_rank()}")

    # Create dataloaders for train and test dataset
    test_set = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # A partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        resume_from, resume_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        resume_from, resume_partial = args.full_checkpoint, False
    else:
        resume_from, resume_partial = None, False
    if resume_from is not None:
        checkpoint = smp.load(resume_from, partial=resume_partial)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    if args.save_partial_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.local_state_dict(),
                "optimizer_state_dict": optimizer.local_state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=False,
        )

    # Wait for the checkpoint save to finish before running another
    # allgather_object.
    smp.barrier()

    if not args.assert_losses:
        return

    if use_horovod or use_ddp:
        # SM Distributed: If using data parallelism, gather all losses across
        # different model replicas and check if losses match.
        gathered = smp.allgather(test_loss, smp.DP_GROUP)
        for replica_loss in gathered:
            assert math.isclose(replica_loss, gathered[0])
        assert test_loss < 0.18
    else:
        assert test_loss < 0.08
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, optimizer, scheduler and criterion.

    Steps, in order:

    1. Read the BERT config from ``args.config_file`` and pad the vocabulary
       size up to a multiple of 8.
    2. Create ``BertForPreTraining`` and, when ``args.smp > 0``, wrap it in
       ``smp.DistributedModel``.
    3. When ``args.resume_from_checkpoint`` is set, load the model weights
       either from ``args.init_checkpoint`` or from ``ckpt_<step>.pt`` in
       ``args.output_dir`` (synced from ``args.s3_checkpoint_uri`` first).
    4. Create ``FusedLAMB`` (wrapped in ``smp.DistributedOptimizer`` for SMP),
       the ``PolyWarmUpScheduler`` and, for fp16, initialise apex AMP.
    5. When resuming, restore the optimizer state (and AMP master params).

    Args:
        args: Parsed command-line namespace; ``args.resume_step`` is updated
            in place when it is ``-1`` and no ``init_checkpoint`` is given.
        device: Device the model is moved to.

    Returns:
        ``(model, optimizer, lr_scheduler, checkpoint, global_step,
        criterion)``; ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: if resuming without ``init_checkpoint`` and without
            ``s3_checkpoint_uri``, or if no checkpoint file is found in
            ``args.output_dir`` while auto-detecting the resume step.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    config.use_sequential = args.use_sequential > 0

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            # Only one process per instance downloads; the rest wait.
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()

        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            if not model_names:
                # max() on an empty list would raise an unhelpful ValueError.
                raise ValueError(
                    f'No ".pt" checkpoint files found in "{args.output_dir}"'
                )
            # File names look like "ckpt_<step>.pt[...]": take the largest step.
            args.resume_step = max(
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            )

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]

    # Parameters whose name contains a no_decay marker get no weight decay.
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer
        # state for a distributed model. Also provides APIs to obtain local
        # optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        # loss_scale == 0 selects dynamic loss scaling.
        loss_scale = "dynamic" if args.loss_scale == 0 else args.loss_scale
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level="O2",
            loss_scale=loss_scale,
            cast_model_outputs=torch.float16,
        )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            # Override hyperparameters from previous checkpoint
            saved_opt = checkpoint["optimizer"]
            for param_state in saved_opt["state"].values():
                param_state["step"] = global_step
            for group in saved_opt["param_groups"]:
                group["step"] = global_step
                group["t_total"] = args.max_steps
                group["warmup"] = args.warmup_proportion
                group["lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            # The state is loaded a second time after AMP has lazily created
            # its master weights, then the saved master params are copied in.
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # if args.local_rank != -1:
    #     if not args.allreduce_post_accumulation:
    #         model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #     else:
    #         flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    # elif args.n_gpu > 1:
    #     model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def load_model_and_optimizer(
    output_dir,
    model,
    optimizer,
    lr_scheduler,
    partial,
    args,
    translate_from_hf=False,
    seq_length=1024,
    load_model=True,
    load_optimizer=True,
    num_params=0,
):
    """Restore model / scheduler / optimizer state from the newest checkpoint.

    Checkpoints are files in ``output_dir`` named
    ``trained_gpt_nparams-<num_params>_steps-<N>.pt`` (followed by
    ``_<rank>...`` for partial checkpoints); the one with the largest ``N`` is
    used.  With ``args.gather_if_shard > 0`` the checkpoint is read through
    ``smp.load``; otherwise the per-rank files are read with ``torch.load``:
    the ``..._0`` file for the model and, on ``rdp_rank != 0``, an extra
    rank-specific file for the optimizer state.

    Optimizer restore (``load_optimizer``):
      * ``args.smp_version < 110``: a post-step hook loads the fp16 optimizer
        state (or only reloads master params when
        ``not partial and args.skip_full_optimizer``).
      * otherwise, when ``not partial and args.skip_full_optimizer``: for fp16
        a post-step hook reloads master params from the model params.
      * otherwise ``optimizer.load_optimizer_backcompat`` is called directly.

    Returns:
        ``(model, optimizer, total_steps, curr_train_path_index, batch_idx)``
        read from the model checkpoint; ``batch_idx`` defaults to 0.

    Raises:
        Exception: if no matching checkpoint exists in ``output_dir``.
    """
    # Find longest-trained checkpoint.  Raw strings keep "\d" / "\." as regex
    # escapes rather than (invalid) Python string escapes.
    re_pattern = rf"trained_gpt_nparams-{num_params}_steps-(?P<total_steps>\d+)\.pt"
    re_pattern += r"_(?P<rank>\d+)" if partial else "$"
    ckpt_regex = re.compile(re_pattern)

    ckpt_paths = []
    for name in os.listdir(output_dir):
        found = ckpt_regex.match(name)
        if found:
            ckpt_paths.append(
                (int(found.group("total_steps")), os.path.join(output_dir, name))
            )
    ckpt_paths.sort(reverse=True)

    if not ckpt_paths:
        raise Exception(
            f'No checkpoints could be found in "{output_dir}".  Candidates: {os.listdir(output_dir)}'
        )

    local_ckpt_path = ckpt_paths[0][1]

    if partial:
        # need to pass prefix without ranks to smp
        local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"

    opt_checkpoint = None
    if args.gather_if_shard > 0:
        # Should expect v2 checkpoint here
        checkpoint = smp.load(local_ckpt_path, partial=partial)
    else:
        # Loading separately for model and opt
        checkpoint = torch.load(
            f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0")
        if smp.rdp_rank() != 0:
            opt_checkpoint = torch.load(
                f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}"
            )

    if load_model:
        checkpointed_model = (
            translate_hf_state_dict_to_smdistributed(checkpoint["model"],
                                                     seq_length)
            if translate_from_hf else checkpoint["model"]
        )
        model.load_state_dict(
            checkpointed_model,
            same_partition_load=args.same_partition_load > 0)
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    if load_optimizer:
        # Pick the checkpoint that holds this rank's optimizer state.  This is
        # resolved here, in the enclosing scope: assigning to ``checkpoint``
        # inside the hook would make it a local of the hook and fail with
        # UnboundLocalError when the hook runs.
        if args.gather_if_shard > 0 or smp.rdp_rank() == 0:
            opt_source = checkpoint
        else:
            opt_source = opt_checkpoint

        def opt_load_hook(mod, opt):
            if not partial and args.skip_full_optimizer:
                print(
                    "Skipping loading the final optimizer state, and reloading master_params from model_params"
                )
                opt.reload_model_params()
            else:
                load_fp16_optimizer(args, mod, opt, opt_source,
                                    partial=partial)

        def reload_params_hook(mod, opt):
            # Post-step hooks are invoked as hook(mod, opt), as the other
            # hooks in this function are.
            opt.reload_model_params()

        if args.smp_version < 110:
            model.register_post_step_hook(opt_load_hook)
        elif not partial and args.skip_full_optimizer:
            print(
                "Skipping loading the final optimizer state, and reloading master_params from model_params for fp16"
            )
            if args.fp16:
                model.register_post_step_hook(reload_params_hook)
        else:
            optimizer.load_optimizer_backcompat(opt_source["optimizer"],
                                                args.gather_if_shard)

    print(f'Loaded model from "{local_ckpt_path}"')

    batch_idx = checkpoint.get("batch_idx", 0) if isinstance(checkpoint, dict) \
        else (checkpoint["batch_idx"] if "batch_idx" in checkpoint else 0)

    return (
        model,
        optimizer,
        checkpoint["total_steps"],
        checkpoint["curr_train_path_index"],
        batch_idx,
    )
def load_model_and_optimizer(
    output_dir,
    model,
    optimizer,
    lr_scheduler,
    partial,
    args,
    translate_from_hf=False,
    seq_length=1024,
    load_model=True,
    load_optimizer=True,
    num_params=0,
):
    """Restore model / scheduler / optimizer state from the newest checkpoint.

    Checkpoints are files in ``output_dir`` named
    ``trained_gpt_nparams-<num_params>_steps-<N>.pt`` (followed by
    ``_<rank>...`` for partial checkpoints); the one with the largest ``N`` is
    loaded through ``smp.load``.

    When ``load_model`` is set, the model weights (optionally translated from
    the Hugging Face layout) and, if given, the LR scheduler state are
    restored immediately.  When ``load_optimizer`` is set, the loss-scaler
    fields are copied eagerly onto ``optimizer`` (non-megatron only) and the
    remaining optimizer state is restored by a post-step hook registered on
    ``model``.

    Returns:
        ``(model, optimizer, total_steps, curr_train_path_index, batch_idx)``
        read from the checkpoint; ``batch_idx`` defaults to 0 when absent.

    Raises:
        Exception: if no matching checkpoint exists in ``output_dir``.
    """
    # Find longest-trained checkpoint.  Raw strings keep "\d" / "\." as regex
    # escapes rather than (invalid) Python string escapes.
    re_pattern = rf"trained_gpt_nparams-{num_params}_steps-(?P<total_steps>\d+)\.pt"
    re_pattern += r"_(?P<rank>\d+)" if partial else "$"
    ckpt_regex = re.compile(re_pattern)

    # Match each candidate once and keep (steps, path) for the hits.
    ckpt_paths = []
    for name in os.listdir(output_dir):
        found = ckpt_regex.match(name)
        if found:
            ckpt_paths.append(
                (int(found.group("total_steps")), os.path.join(output_dir, name))
            )
    ckpt_paths.sort(reverse=True)

    if not ckpt_paths:
        raise Exception(
            f'No checkpoints could be found in "{output_dir}".  Candidates: {os.listdir(output_dir)}'
        )

    local_ckpt_path = ckpt_paths[0][1]

    if partial:
        # need to pass prefix without ranks to smp
        local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"

    checkpoint = smp.load(local_ckpt_path, partial=partial)

    if load_model:
        checkpointed_model = (
            translate_hf_state_dict_to_smdistributed(checkpoint["model"],
                                                     seq_length)
            if translate_from_hf else checkpoint["model"]
        )
        model.load_state_dict(
            checkpointed_model,
            same_partition_load=args.same_partition_load > 0)
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    if load_optimizer:
        # Loading loss scale eagerly
        if not args.megatron:
            opt_state_dict = checkpoint["optimizer"]
            optimizer.loss_scaler = opt_state_dict["loss_scaler"]
            optimizer.loss_scaler.model = model
            optimizer.dynamic_loss_scale = opt_state_dict["dynamic_loss_scale"]
            optimizer.overflow = opt_state_dict["overflow"]
            optimizer.first_closure_call_this_step = opt_state_dict[
                "first_closure_call_this_step"]

        def opt_load_hook(mod, opt):
            skip_full = not partial and args.skip_full_optimizer
            if args.fp16:
                if skip_full:
                    print(
                        "Skipping loading the final optimizer state, and reloading master_params from model_params"
                    )
                    opt.reload_model_params()
                else:
                    load_fn = (load_fp16_optimizer_megatron
                               if args.megatron else load_fp16_optimizer)
                    load_fn(args, mod, opt, checkpoint, partial=partial)
            else:
                # fp32
                if skip_full:
                    print("Skipping loading the final optimizer state")
                else:
                    opt.load_state_dict(checkpoint["optimizer"])

        model.register_post_step_hook(opt_load_hook)

    print(f'Loaded model from "{local_ckpt_path}"')

    batch_idx = 0
    if "batch_idx" in checkpoint:
        batch_idx = checkpoint["batch_idx"]

    return (
        model,
        optimizer,
        checkpoint["total_steps"],
        checkpoint["curr_train_path_index"],
        batch_idx,
    )
# Device selection and (optional) SMP wrapping for the VAE.
# NOTE(review): this fragment's enclosing definition is not visible here; the
# nesting below is reconstructed from the statement order - confirm against
# the original file.
logger.debug(f"args.local_rank : {args.local_rank}")

# Fall back to GPU 0 when no local rank was provided.
torch.cuda.set_device(args.local_rank if args.local_rank is not None else 0)

if not args.multigpus_distributed:
    vae = vae.cuda()
else:
    vae.cuda(args.local_rank)
    if not args.model_parallel:
        vae = vae.cuda()
    else:
        vae = smp.DistributedModel(vae)
        args.scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        # A partial checkpoint takes precedence over a full one.
        if args.partial_checkpoint:
            resume_from, resume_partial = args.partial_checkpoint, True
        elif args.full_checkpoint:
            resume_from, resume_partial = args.full_checkpoint, False
        else:
            resume_from, resume_partial = None, False

        if resume_from is not None:
            args.checkpoint = smp.load(resume_from, partial=resume_partial)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

assert len(ds) > 0, 'folder does not contain any images'

# Only the deepspeed root worker reports the dataset size.
if (not args.model_parallel) and deepspeed_utils.is_root_worker():
    print(f'{len(ds)} images found for training')
def main():
    """Train a DiscreteVAE with either SMP model parallelism or deepspeed.

    Parses the command line, initialises SMP (``args.model_parallel``) or
    deepspeed, seeds the RNGs, builds an ``ImageFolder`` data loader, creates
    the VAE / Adam optimizer / exponential LR schedule, optionally restores an
    SMP checkpoint, and runs ``args.EPOCHS`` epochs.  Every 10 iterations a
    preview grid is built on rank 0, a checkpoint is written to
    ``<args.model_dir>/vae.pt``, the temperature is annealed and the LR
    schedule is stepped.

    Fixes relative to the previous revision:
      * ``smp.dp_rank`` is now *called*; comparing the function object with 0
        was always false, so the model-parallel checkpoint was never saved.
      * The ``smp.step`` functions are defined only in model-parallel mode
        instead of inside a bare ``try/except: pass`` that hid real errors.
      * The loss meter is weighted by the real batch size (``images`` is
        rebound to a preview grid every 10th iteration).
      * The step timer is reset after every measurement instead of growing
        for the whole epoch.

    NOTE: wandb experiment tracking is disabled in this script; ``logs`` is
    kept as the hook point for it.

    Raises:
        ValueError: if CUDA is not available.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    # os.environ.get never raises, so a plain branch replaces the old
    # try/except whose handler could not run.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB')
                 if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        # Give every data-parallel rank an equally sized slice of the data.
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy

        # Plain (unwrapped) copies used to build hard reconstructions outside
        # of smp.step.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        # A partial checkpoint takes precedence over a full one.
        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        # Only rank 0 writes the (non model-parallel) checkpoint.
        if args.rank != 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    if args.model_parallel:
        # These helpers are only needed (and smp is only initialised) in
        # model-parallel mode; failures here now surface immediately.

        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial

            # Rebind every module's forward to the original method obtained
            # from the SMP patch manager before calling the plain copies.
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            # Remember the batch size now: ``images`` is rebound to a preview
            # grid below on every 10th iteration.
            num_samples = images.size(0)

            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy decoder/codebook weights into the plain
                            # copies, stripping the sub-module prefix.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # dp_rank must be called; only one data-parallel replica
                    # writes the checkpoint.
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    # All ranks wait for the save to finish.
                    smp.barrier()
                else:
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i,
                          f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

            global_step += 1

            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), num_samples)

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                # Restart the timer so the next reading covers one interval
                # only instead of the whole epoch so far.
                start = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses
                # across different model replicas and check if losses match.
                gathered = smp.allgather(loss, smp.DP_GROUP)
                for replica_loss in gathered:
                    print(replica_loss)
                    assert math.isclose(replica_loss, gathered[0])
                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")