Beispiel #1
0
def smp_init(model, optimizer, args):
    """Wrap *model* and *optimizer* for SageMaker model parallelism.

    A ``smp.amp.GradScaler`` is created and stored on ``args.scaler``.  If
    ``args.partial_checkpoint`` is set it is used; otherwise
    ``args.full_checkpoint`` is consulted.  When one of them is set, the
    checkpoint is read with ``smp.load`` (``partial=True`` / ``False``
    respectively), kept on ``args.checkpoint``, and its
    ``"model_state_dict"`` / ``"optimizer_state_dict"`` entries are loaded
    into the wrapped model and optimizer.

    Returns:
        ``(model, optimizer, args)`` where model and optimizer are the
        ``smp`` distributed wrappers.
    """
    dist_model = smp.DistributedModel(model)
    args.scaler = smp.amp.GradScaler()
    dist_optimizer = smp.DistributedOptimizer(optimizer)

    # A partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        ckpt_source, as_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        ckpt_source, as_partial = args.full_checkpoint, False
    else:
        ckpt_source, as_partial = None, False

    if ckpt_source is not None:
        args.checkpoint = smp.load(ckpt_source, partial=as_partial)
        dist_model.load_state_dict(args.checkpoint["model_state_dict"])
        dist_optimizer.load_state_dict(args.checkpoint["optimizer_state_dict"])

    return dist_model, dist_optimizer, args
def load_and_verify_ckptsum(args, model, optimizer, filename):
    """Register hooks on *model* that compare live state with saved checksums.

    *filename* is read with ``smp.load`` and must provide a ``"model"`` entry
    and an ``"optimizer"`` entry, each holding ``"tensors"`` and ``"scalars"``
    sub-dicts.  Model entries are keyed by parameter name; optimizer entries
    by ``"{param_idx}_{key}"``.  When ``args.shard_optimizer_state`` is set,
    the ``"optimizer"`` entry is indexed by ``smp.rdp_rank()`` first.

    Tensor values are checked with ``torch.isclose`` against the sum of the
    live tensor; other values are checked with ``==``.  Any mismatch raises
    ``AssertionError`` from inside the hook.

    The model check runs as a post-partition hook and the optimizer check as
    a post-step hook.  *optimizer* is not used directly; the hooks receive
    the optimizer from ``smp``.
    """
    saved = smp.load(filename)
    if args.shard_optimizer_state:
        saved_opt = saved["optimizer"][smp.rdp_rank()]
    else:
        saved_opt = saved["optimizer"]
    saved_model = saved["model"]

    def _verify_optimizer(mod, opt):
        # Sharded optimizers expose the unsharded view via orig_state_dict().
        if args.shard_optimizer_state:
            live_states = opt.orig_state_dict()["state"]
        else:
            live_states = opt.local_state_dict()["state"]
        for param_idx, per_param in live_states.items():
            for key, val in per_param.items():
                entry = f"{param_idx}_{key}"
                if isinstance(val, torch.Tensor):
                    matches = torch.isclose(torch.sum(val), saved_opt["tensors"][entry])
                else:
                    matches = val == saved_opt["scalars"][entry]
                assert matches, f"mismatch for param_idx: {param_idx}, key is {key}"
        print("Optimizer save/load check passed successfully")

    def _verify_model(mod, opt):
        for param_name, param in mod.local_state_dict().items():
            if isinstance(param, torch.Tensor):
                matches = torch.isclose(torch.sum(param), saved_model["tensors"][param_name])
            else:
                matches = param == saved_model["scalars"][param_name]
            assert matches, f"mismatch for param_name: {param_name}"
        print("Model save/load check passed successfully")

    model.register_post_partition_hook(_verify_model)
    model.register_post_step_hook(_verify_optimizer)
Beispiel #3
0
def main():
    """Train and evaluate ``GroupedNet`` on MNIST with SageMaker model parallelism.

    Steps:
    1. Parse CLI arguments and seed every RNG.
    2. Initialise ``smp`` from a config built out of the arguments.
    3. Build the MNIST loaders. With DDP/Horovod and more than one
       data-parallel rank, the training set is split into ``dp_size`` equal
       partitions and this rank selects its own.
    4. Wrap model and optimizer in the ``smp`` containers and optionally
       resume from a partial or full checkpoint.
    5. Train for ``args.epochs`` epochs, optionally save a checkpoint, and
       optionally assert on the final test loss.

    Raises:
        ValueError: if CUDA is unavailable, or if ``args.assert_losses`` is
            set but no epoch ran (so there is no test loss to check).
    """
    parser = get_parser()
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    cfg = {
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    }

    smp.init(cfg)

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: download only on a single process per instance.
    # Otherwise several processes download and extract into the same
    # directory at the same time and corrupt the files.  Every process then
    # opens the already-downloaded data after the barrier.
    if smp.local_rank() == 0:
        datasets.MNIST("../data", train=True, download=True, transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    # With data parallelism, split the training set into dp_size partitions
    # of 1 / dp_size each and keep the one named after this dp_rank.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Create the test dataset and the dataloaders for both splits.
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device, so
    # model.to(device) is intentionally not called here.
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # Optionally resume; a partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        ckpt_path, ckpt_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        ckpt_path, ckpt_partial = args.full_checkpoint, False
    else:
        ckpt_path, ckpt_partial = None, False
    if ckpt_path:
        checkpoint = smp.load(ckpt_path, partial=ckpt_partial)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    test_loss = None  # stays None when args.epochs < 1
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    checkpoint_path = "./pt_mnist_checkpoint.pt"

    # Only data-parallel rank 0 writes; other replicas hold the same weights.
    if args.save_partial_model and smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            checkpoint_path,
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        model_dict = model.state_dict()
        opt_dict = optimizer.state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            checkpoint_path,
            partial=False,
        )

    # Wait for the checkpoint save to finish before running another
    # collective (the allgather below).
    smp.barrier()

    if args.assert_losses:
        if args.epochs < 1:
            raise ValueError(
                "assert_losses needs at least one epoch to produce a test loss")
        if use_horovod or use_ddp:
            # SM Distributed: If using data parallelism, gather all losses
            # across different model replicas and check that they match.
            losses = smp.allgather(test_loss, smp.DP_GROUP)
            for replica_loss in losses:
                assert math.isclose(replica_loss, losses[0])

            assert test_loss < 0.18
        else:
            assert test_loss < 0.08
def prepare_model_and_optimizer(args, device):
    """Build the BERT pretraining model, optimizer, LR scheduler and criterion.

    The model config is read from ``args.config_file`` and the vocabulary
    size is padded up to a multiple of 8.  When ``args.smp > 0`` the model
    and the FusedLAMB optimizer are wrapped in the ``smp`` distributed
    containers.

    If ``args.resume_from_checkpoint`` is set, state is restored either from
    ``args.init_checkpoint`` or from ``ckpt_<step>.pt`` in
    ``args.output_dir``.  In the latter case the directory is first synced
    from ``args.s3_checkpoint_uri`` by local rank 0.  The restored state is
    the model weights, then the optimizer state, plus the AMP master params
    when ``args.fp16``.  With ``args.resume_step == -1`` (and no
    ``init_checkpoint``), the step is taken as the largest number found
    after the first ``_`` in the names of ``.pt`` files in
    ``args.output_dir``.

    Returns:
        ``(model, optimizer, lr_scheduler, checkpoint, global_step,
        criterion)``; ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: if resuming without ``init_checkpoint`` and without
            ``s3_checkpoint_uri``, or if the resume step must be inferred
            but ``args.output_dir`` contains no ``.pt`` files.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Pad the vocabulary so its size is divisible by 8.
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    config.use_sequential = args.use_sequential > 0

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            # One process per node syncs from S3; the rest wait at the barrier.
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            if not model_names:
                raise ValueError(
                    f"No '.pt' checkpoints found in {args.output_dir!r} to resume from"
                )
            # Names look like "ckpt_<step>.pt[...]"; take the largest step.
            args.resume_step = max(
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            )

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    # Parameters whose names contain any of these substrings get no weight decay.
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        # args.loss_scale == 0 selects dynamic loss scaling.
        loss_scale = "dynamic" if args.loss_scale == 0 else args.loss_scale
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level="O2",
            loss_scale=loss_scale,
            cast_model_outputs=torch.float16,
        )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        opt_state = checkpoint["optimizer"]
        if args.phase2 or args.init_checkpoint:
            # Override hyperparameters from the previous checkpoint.
            for param_state in opt_state["state"].values():
                param_state["step"] = global_step
            for group in opt_state["param_groups"]:
                group["step"] = global_step
                group["t_total"] = args.max_steps
                group["warmup"] = args.warmup_proportion
                group["lr"] = args.learning_rate
        optimizer.load_state_dict(opt_state)
        # Restore AMP master parameters.
        if args.fp16:
            # NOTE(review): the optimizer state is loaded a second time after
            # AMP's lazy init, presumably so it lands on the materialised
            # master weights -- confirm whether the first load is still needed.
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(opt_state)
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # NOTE: the model is not wrapped in DDP / DataParallel here.

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
Beispiel #5
0
def load_model_and_optimizer(
    output_dir,
    model,
    optimizer,
    lr_scheduler,
    partial,
    args,
    translate_from_hf=False,
    seq_length=1024,
    load_model=True,
    load_optimizer=True,
    num_params=0,
):
    """Restore model/optimizer state from the newest matching checkpoint.

    Files in *output_dir* are matched against
    ``trained_gpt_nparams-{num_params}_steps-<N>.pt``.  With *partial* the
    name must be followed by ``_<rank>``; otherwise the ``.pt`` must end the
    name.  The match with the largest ``<N>`` is used.

    Loading the checkpoint:

    * With ``args.gather_if_shard > 0``, the checkpoint is read with
      ``smp.load``.
    * Otherwise every rank reads ``<path>_<pp_rank>_<tp_rank>_0`` with
      ``torch.load``.
    * In that case, ranks with a non-zero ``rdp_rank`` also read
      ``<path>_<pp_rank>_<tp_rank>_<rdp_rank>`` and use it as their optimizer
      source.

    Optimizer restore by ``args.smp_version``:

    * Below 110, a post-step hook is registered that either reloads master
      params from model params (when skipping a full optimizer state) or
      calls ``load_fp16_optimizer``.
    * Otherwise the state is loaded eagerly via
      ``optimizer.load_optimizer_backcompat``.
    * When a full optimizer state is skipped under fp16, a hook that
      reloads model params is registered instead.

    Returns:
        ``(model, optimizer, total_steps, curr_train_path_index, batch_idx)``.
        The last three come from the model checkpoint; ``batch_idx`` is 0
        when absent.

    Raises:
        Exception: if no file in *output_dir* matches the pattern.
    """
    # Find the longest-trained checkpoint.  Raw f-strings keep "\d" and "\."
    # as regex escapes rather than (invalid) string escapes.
    re_pattern = rf"trained_gpt_nparams-{num_params}_steps-(?P<total_steps>\d+)\.pt"
    if partial:
        re_pattern += r"_(?P<rank>\d+)"
    else:
        re_pattern += "$"

    filenames = os.listdir(output_dir)
    candidates = []
    for fname in filenames:
        match = re.match(re_pattern, fname)
        if match:
            candidates.append(
                (int(match.group("total_steps")), os.path.join(output_dir, fname)))
    ckpt_paths = sorted(candidates, reverse=True)
    if not ckpt_paths:
        raise Exception(
            f'No checkpoints could be found in "{output_dir}".  Candidates: {filenames}'
        )

    local_ckpt_path = ckpt_paths[0][1]

    if partial:
        # need to pass prefix without ranks to smp
        local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"

    opt_checkpoint = None
    if args.gather_if_shard > 0:
        # Should expect v2 checkpoint here
        checkpoint = smp.load(local_ckpt_path, partial=partial)
    else:
        # Model state (and metadata) come from the "_0" file; ranks with a
        # non-zero rdp_rank read their optimizer state from their own file.
        checkpoint = torch.load(
            f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_0")
        if smp.rdp_rank() != 0:
            opt_checkpoint = torch.load(
                f"{local_ckpt_path}_{smp.pp_rank()}_{smp.tp_rank()}_{smp.rdp_rank()}"
            )

    if load_model:
        checkpointed_model = (translate_hf_state_dict_to_smdistributed(
            checkpoint["model"], seq_length)
                              if translate_from_hf else checkpoint["model"])
        model.load_state_dict(checkpointed_model,
                              same_partition_load=args.same_partition_load > 0)
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    if load_optimizer:
        # Separate name so the metadata read at return always comes from the
        # model checkpoint, even on ranks that load their own optimizer file.
        opt_source = opt_checkpoint if opt_checkpoint is not None else checkpoint

        def opt_load_hook(mod, opt):
            load_fn = load_fp16_optimizer
            if not partial and args.skip_full_optimizer:
                print(
                    "Skipping loading the final optimizer state, and reloading master_params from model_params"
                )
                opt.reload_model_params()
            else:
                load_fn(args, mod, opt, opt_source, partial=partial)

        if args.smp_version < 110:
            model.register_post_step_hook(opt_load_hook)
        elif not partial and args.skip_full_optimizer:
            print(
                "Skipping loading the final optimizer state, and reloading master_params from model_params for fp16"
            )
            if args.fp16:
                # Post-step hooks are called as hook(mod, opt).
                model.register_post_step_hook(
                    lambda mod, opt: opt.reload_model_params())
        else:
            optimizer.load_optimizer_backcompat(opt_source["optimizer"],
                                                args.gather_if_shard)

    print(f'Loaded model from "{local_ckpt_path}"')

    batch_idx = checkpoint["batch_idx"] if "batch_idx" in checkpoint else 0

    return (
        model,
        optimizer,
        checkpoint["total_steps"],
        checkpoint["curr_train_path_index"],
        batch_idx,
    )
def load_model_and_optimizer(
    output_dir,
    model,
    optimizer,
    lr_scheduler,
    partial,
    args,
    translate_from_hf=False,
    seq_length=1024,
    load_model=True,
    load_optimizer=True,
    num_params=0,
):
    """Restore model state (and schedule optimizer restore) from the newest checkpoint.

    Files in *output_dir* are matched against
    ``trained_gpt_nparams-{num_params}_steps-<N>.pt``.  With *partial* the
    name must be followed by ``_<rank>``; otherwise the ``.pt`` must end the
    name.  The match with the largest ``<N>`` is loaded with ``smp.load``.

    Optimizer restore when *load_optimizer* is true:

    * Unless ``args.megatron``, the loss-scaler related fields are copied
      onto *optimizer* immediately.
    * A post-step hook is registered that restores the rest of the state:
      ``load_fp16_optimizer[_megatron]`` for fp16, ``opt.load_state_dict``
      for fp32.
    * When a full (non-partial) checkpoint is loaded with
      ``args.skip_full_optimizer``, the hook skips restoring the saved state.
      Under fp16 it reloads master params from the model params instead.

    Returns:
        ``(model, optimizer, total_steps, curr_train_path_index, batch_idx)``
        from the checkpoint; ``batch_idx`` is 0 when absent.

    Raises:
        Exception: if no file in *output_dir* matches the pattern.
    """
    # Find the longest-trained checkpoint.  Raw f-strings keep "\d" and "\."
    # as regex escapes rather than (invalid) string escapes.
    re_pattern = rf"trained_gpt_nparams-{num_params}_steps-(?P<total_steps>\d+)\.pt"
    if partial:
        re_pattern += r"_(?P<rank>\d+)"
    else:
        re_pattern += "$"

    filenames = os.listdir(output_dir)
    candidates = []
    for fname in filenames:
        match = re.match(re_pattern, fname)
        if match:
            candidates.append((int(match.group("total_steps")), os.path.join(output_dir, fname)))
    ckpt_paths = sorted(candidates, reverse=True)
    if not ckpt_paths:
        raise Exception(
            f'No checkpoints could be found in "{output_dir}".  Candidates: {filenames}'
        )

    local_ckpt_path = ckpt_paths[0][1]

    if partial:
        # need to pass prefix without ranks to smp
        local_ckpt_path = local_ckpt_path.split(".pt")[0] + ".pt"

    checkpoint = smp.load(local_ckpt_path, partial=partial)

    if load_model:
        checkpointed_model = (
            translate_hf_state_dict_to_smdistributed(checkpoint["model"], seq_length)
            if translate_from_hf
            else checkpoint["model"]
        )
        model.load_state_dict(checkpointed_model, same_partition_load=args.same_partition_load > 0)
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    if load_optimizer:
        # Loading loss scale eagerly; the remaining optimizer state is restored
        # by the post-step hook below.
        if not args.megatron:
            opt_state_dict = checkpoint["optimizer"]
            optimizer.loss_scaler = opt_state_dict["loss_scaler"]
            optimizer.loss_scaler.model = model
            optimizer.dynamic_loss_scale = opt_state_dict["dynamic_loss_scale"]
            optimizer.overflow = opt_state_dict["overflow"]
            optimizer.first_closure_call_this_step = opt_state_dict["first_closure_call_this_step"]

        def opt_load_hook(mod, opt):
            load_fn = load_fp16_optimizer_megatron if args.megatron else load_fp16_optimizer
            skip_saved_state = not partial and args.skip_full_optimizer
            if args.fp16:
                if skip_saved_state:
                    print(
                        "Skipping loading the final optimizer state, and reloading master_params from model_params"
                    )
                    opt.reload_model_params()
                else:
                    load_fn(args, mod, opt, checkpoint, partial=partial)
            elif skip_saved_state:
                # fp32
                print("Skipping loading the final optimizer state")
            else:
                # fp32
                opt.load_state_dict(checkpoint["optimizer"])

        model.register_post_step_hook(opt_load_hook)

    print(f'Loaded model from "{local_ckpt_path}"')

    batch_idx = checkpoint["batch_idx"] if "batch_idx" in checkpoint else 0

    return model, optimizer, checkpoint["total_steps"], checkpoint["curr_train_path_index"], batch_idx
Beispiel #7
0
    
    # NOTE(review): excerpt from the middle of a function whose header is not
    # shown; `args`, `vae`, `opt`, `ds`, `logger`, `smp` and `deepspeed_utils`
    # are presumably defined earlier in that function/module -- confirm.
    logger.debug(f"args.local_rank : {args.local_rank}")
    # Select this process's CUDA device; fall back to device 0 when no local
    # rank was given.
    if args.local_rank is not None:
        torch.cuda.set_device(args.local_rank)
    else:
        torch.cuda.set_device(0)

    if args.multigpus_distributed:
        vae.cuda(args.local_rank)

        if args.model_parallel:
            # SageMaker model parallel: wrap model and optimizer, keep a grad
            # scaler on args, and optionally restore a checkpoint (a partial
            # checkpoint takes precedence over a full one).
            vae = smp.DistributedModel(vae)
            args.scaler = smp.amp.GradScaler()
            opt = smp.DistributedOptimizer(opt)
            if args.partial_checkpoint:
                args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
                vae.load_state_dict(args.checkpoint["model_state_dict"])
                opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
            elif args.full_checkpoint:
                args.checkpoint = smp.load(args.full_checkpoint, partial=False)
                vae.load_state_dict(args.checkpoint["model_state_dict"])
                opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        else:

            vae = vae.cuda()
    else:
        vae = vae.cuda()

    # Fail early on an empty image folder; only report the count when not
    # running model parallel, and only from the root worker.
    assert len(ds) > 0, 'folder does not contain any images'
    if (not args.model_parallel) and deepspeed_utils.is_root_worker():
        print(f'{len(ds)} images found for training')
Beispiel #8
0
def main():
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    try:
        if os.environ.get('SM_MODEL_DIR') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
            #             args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
            args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    except:
        logger.debug("not SageMaker")
        pass

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    #     transform = Compose(
    #         [
    #             RandomResizedCrop(args.image_size, args.image_size),
    #             OneOf(
    #                 [
    #                     IAAAdditiveGaussianNoise(),
    #                     GaussNoise(),
    #                 ],
    #                 p=0.2
    #             ),
    #             VerticalFlip(p=0.5),
    #             OneOf(
    #                 [
    #                     MotionBlur(p=.2),
    #                     MedianBlur(blur_limit=3, p=0.1),
    #                     Blur(blur_limit=3, p=0.1),
    #                 ],
    #                 p=0.2
    #             ),
    #             OneOf(
    #                 [
    #                     CLAHE(clip_limit=2),
    #                     IAASharpen(),
    #                     IAAEmboss(),
    #                     RandomBrightnessContrast(),
    #                 ],
    #                 p=0.3
    #             ),
    #             HueSaturationValue(p=0.3),
    # #             Normalize(
    # #                 mean=[0.485, 0.456, 0.406],
    # #                 std=[0.229, 0.224, 0.225],
    # #             )
    #         ],
    #         p=1.0
    #     )

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    #     ds = AlbumentationImageDataset(
    #         IMAGE_PATH,
    #         transform=transform,
    #         args=args
    #     )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)
    # optimizer

    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking

        #         import wandb

        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

    def save_model(path):
        if not args.rank == 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            #             logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            #             torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    except:
        pass

    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    #                 if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
    #                     wandb.save(f'{args.model_dir}/vae.pt')

    # temperature anneal

                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay

                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                #                 print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                #                 prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                #                 top1.update(util.to_python_float(prec1), images.size(0))
                #                 top5.update(util.to_python_float(prec5), images.size(0))

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                end = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        #     if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))


#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")