Example #1
def _wrap_model(self, model, training=True):
    if self.is_model_parallel_enabled:
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
        return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
    else:
        return super()._wrap_model(model)
Example #2
def _wrap_model(self, model, training=True):
    if self.is_model_parallel_enabled:
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
        return smp.DistributedModel(model)
    else:
        return super()._wrap_model(model)
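For context, a minimal standalone sketch of the same guard pattern used by the two _wrap_model overrides above, outside a Trainer subclass; `smdistributed.modelparallel.torch` is the library imported as `smp` throughout these examples, and the `wrap_once` helper name is purely illustrative.

import smdistributed.modelparallel.torch as smp

def wrap_once(model, already_wrapped=None):
    # Wrapping a model that is already a DistributedModel raises an error,
    # so reuse the existing wrapper when one is passed in.
    if isinstance(already_wrapped, smp.model.DistributedModel):
        return already_wrapped
    return smp.DistributedModel(model)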
Example #3
def smp_init(model, optimizer, args):
    model = smp.DistributedModel(model)
    args.scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)
    if args.partial_checkpoint:
        args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
        model.load_state_dict(args.checkpoint["model_state_dict"])
        optimizer.load_state_dict(args.checkpoint["optimizer_state_dict"])
    elif args.full_checkpoint:
        args.checkpoint = smp.load(args.full_checkpoint, partial=False)
        model.load_state_dict(args.checkpoint["model_state_dict"])
        optimizer.load_state_dict(args.checkpoint["optimizer_state_dict"])

    return model, optimizer, args
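A hedged usage sketch for the smp_init helper above; the config values, the toy model, and the empty checkpoint arguments are illustrative assumptions, not taken from the original script (the config keys mirror those in Example #4).

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import smdistributed.modelparallel.torch as smp

# Illustrative only: initialize SMP, then wrap a toy model and optimizer
# with the smp_init helper above. No checkpoints are loaded here.
smp.init({"partitions": 2, "microbatches": 4, "ddp": True})
torch.cuda.set_device(smp.local_rank())

args = argparse.Namespace(partial_checkpoint="", full_checkpoint="")
model = nn.Linear(16, 1)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
model, optimizer, args = smp_init(model, optimizer, args)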
Example #4
def main():
    parser = get_parser()
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    cfg = {
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    }

    smp.init(cfg)

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": args.batch_size}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    if args.partial_checkpoint:
        checkpoint = smp.load(args.partial_checkpoint, partial=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    elif args.full_checkpoint:
        checkpoint = smp.load(args.full_checkpoint, partial=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    if args.save_partial_model:
        if smp.dp_rank() == 0:
            model_dict = model.local_state_dict()
            opt_dict = optimizer.local_state_dict()
            smp.save(
                {
                    "model_state_dict": model_dict,
                    "optimizer_state_dict": opt_dict
                },
                f"./pt_mnist_checkpoint.pt",
                partial=True,
            )

    if args.save_full_model:
        if smp.dp_rank() == 0:
            model_dict = model.state_dict()
            opt_dict = optimizer.state_dict()
            smp.save(
                {
                    "model_state_dict": model_dict,
                    "optimizer_state_dict": opt_dict
                },
                "./pt_mnist_checkpoint.pt",
                partial=False,
            )

    # Wait for the checkpoint save to finish before running another allgather_object
    smp.barrier()

    if args.assert_losses:
        if use_horovod or use_ddp:
            # SM Distributed: If using data parallelism, gather all losses across different model
            # replicas and check if losses match.

            losses = smp.allgather(test_loss, smp.DP_GROUP)
            for l in losses:
                assert math.isclose(l, losses[0])

            assert test_loss < 0.18
        else:
            assert test_loss < 0.08
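The train and test helpers called in the epoch loop above are not part of this snippet; below is a minimal sketch of the @smp.step forward/backward pattern such a train helper typically wraps, mirroring the smp.step usage in Example #12. The loss function and tensor names are assumptions.

import torch.nn.functional as F
import smdistributed.modelparallel.torch as smp

# Hedged sketch: the per-batch step a train() helper would run on the
# smp.DistributedModel created above. Inside an @smp.step function the
# backward pass goes through model.backward, not loss.backward.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss

The values returned from an smp.step function are per-microbatch aggregates; the calling loop would typically take loss.reduce_mean(), as shown in Example #12.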
Example #5
def main():
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        scheduler.step()

    if smp.rank() == 0:
        if os.path.exists("/opt/ml/local_checkpoints"):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs("/opt/ml/local_checkpoints")
            print("-INFO- PATH DO NOT EXIST")

    # Wait for the checkpoint save to finish before running another allgather_object
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            f"/opt/ml/local_checkpoints/pt_mnist_checkpoint.pt",
            partial=True,
        )
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path="/opt/ml/local_checkpoints",
                                     s3_path=full_s3_path)
        print("Finished syncing")
Example #6
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    if args.use_sequential > 0:
        config.use_sequential = True
    else:
        config.use_sequential = False

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            args.resume_step = max([
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                cast_model_outputs=torch.float16,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                cast_model_outputs=torch.float16,
            )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint["optimizer"]["state"].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            for iter, item in enumerate(
                    checkpoint["optimizer"]["param_groups"]):
                checkpoint["optimizer"]["param_groups"][iter][
                    "step"] = global_step
                checkpoint["optimizer"]["param_groups"][iter][
                    "t_total"] = args.max_steps
                checkpoint["optimizer"]["param_groups"][iter][
                    "warmup"] = args.warmup_proportion
                checkpoint["optimizer"]["param_groups"][iter][
                    "lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])  # , strict=False)
        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # if args.local_rank != -1:
    #    if not args.allreduce_post_accumulation:
    #        model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #    else:
    #        flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    # elif args.n_gpu > 1:
    #    model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
Example #7
def main(args):
    # constants

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True
    
    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend) 
        args.rank = int(os.environ.get('RANK'))
        args.world_size = int(os.environ.get('WORLD_SIZE'))
        args.local_rank = int(os.environ.get('LOCAL_RANK'))
        args.global_rank = args.rank
    
    logger.debug(f"using_deepspeed : {using_deepspeed}")
    
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")
    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae

    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                'standard *.pt checkpoint. '
                'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s

    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader

    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None
        
        
    print(f"data_sampler : {data_sampler}")

    # uncorrectable NVLink error was detected during the execution  --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs
                   )
    
    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
    len(dl.sampler), len(dl.dataset),
    100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
    

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )
    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker

        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)
        
        
        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")
        
        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }
    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)
    
    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))
        


    # training

    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
            
        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()
                
            text, images = map(lambda t: t.cuda(), (text, images))
            
            if args.sagemakermp:
                args.distr_opt.zero_grad()
                
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()

            else:  
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers, like Adadelta in PyTorch 1.8, don't like it when
                    # optimizer.step is called with no local parameters
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)
                    
                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save(f'./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)
                
#             text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
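The train_step function invoked in the sagemakermp branch above is defined elsewhere; below is a hedged sketch of that @smp.step function, mirroring the pattern in Example #12. The exact signature and autocast handling in the original may differ.

from torch.cuda.amp import autocast
import smdistributed.modelparallel.torch as smp

@smp.step
def train_step(args, text, images, return_loss=True):
    # args.distr_dalle and args.scaler are attached to args right before
    # the training loop above; backward goes through the DistributedModel.
    with autocast(enabled=(args.amp > 0)):
        loss = args.distr_dalle(text, images, return_loss=return_loss)
    scaled_loss = args.scaler.scale(loss) if args.amp else loss
    args.distr_dalle.backward(scaled_loss)
    return loss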
Example #8
def main():
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in the notebook when launching the SageMaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing is true, transformer layers are checkpointed below
        "checkpoint_attentions": False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config["activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])
    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full are set, will try to load from the full checkpoint. "
                "If the intention is to load from a partial checkpoint, please don't set --load_full"
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
Example #9
get_hard_recons = vae.decode

opt = Adam(vae.parameters(), lr=LEARNING_RATE)
sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

logger.debug(f"args.local_rank : {args.local_rank}")
if args.local_rank is not None:
    torch.cuda.set_device(args.local_rank)
else:
    torch.cuda.set_device(0)

if args.multigpus_distributed:
    vae.cuda(args.local_rank)

    if args.model_parallel:
        vae = smp.DistributedModel(vae)
        args.scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)
        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
    else:
        vae = vae.cuda()
else:
    vae = vae.cuda()
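For completeness, a sketch of how the partial checkpoint consumed above is typically produced, following the smp.save pattern from Example #5; the path and checkpoint keys here are illustrative.

import smdistributed.modelparallel.torch as smp

# Hedged sketch: save a partial (per-rank) checkpoint that the loading code
# above can later read back with smp.load(..., partial=True).
if smp.dp_rank() == 0:
    smp.save(
        {
            "model_state_dict": vae.local_state_dict(),
            "optimizer_state_dict": opt.local_state_dict(),
        },
        "/opt/ml/local_checkpoints/vae_checkpoint.pt",  # illustrative path
        partial=True,
    )
smp.barrier()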
Example #10
def init_train():
    """
    Train the PyTorch model
    """

    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False
    ]
    train_ds = CsvDatasetSimple(args.train)
    test_ds = CsvDatasetSimple(args.test)

    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate

    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend
    smp.init()

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    # smdistributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    #dataset = datasets.MNIST("../data", train=True, download=False)
    dataset = train_ds

    # smdistributed: Shard the dataset based on data-parallel ranks
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: Set drop_last=True to ensure that batch size is always divisible
    # by the number of microbatches
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=64,
                                               drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000, 50000,
                           50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)

    logger.debug(model)

    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SMP: Instantiate DistributedModel object using the model.
    # This handles distributing the model among multiple ranks
    # behind the scenes
    # If horovod is enabled this will do an overlapping_all_reduce by
    # default.

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)

    optimizer = smp.DistributedOptimizer(optimizer)

    train(model, device, train_loader, optimizer)

    torch.save(model.state_dict(), args.model_dir + "/model.pth")
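The train helper invoked above is not part of this snippet; below is a minimal sketch of the loop it typically runs, combining the @smp.step pattern and the loss.reduce_mean() call shown in Example #12. The MSE loss and batch unpacking are assumptions for this tabular regression model.

import torch.nn.functional as F
import smdistributed.modelparallel.torch as smp

@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.mse_loss(output, target)
    model.backward(loss)
    return loss

def train(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = train_step(model, data, target)
        # train_step returns per-microbatch losses; average them before use.
        loss = loss.reduce_mean()
        optimizer.step()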
Example #11
def main():

    model_args, data_args, training_args, smp_args = parse_args()
    model, tokenizer = initialize_model_and_tokenizer(model_args)

    # Get datasets
    train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args,
                                                      training_args)

    if is_sagemaker_mp_enabled():
        initialize_smp(smp_args, training_args)

        torch.set_default_dtype(torch.float32)

        num_params = print_num_parameters(model)

        # smdistributed: Set the device to the GPU ID used by the current process.
        # Input tensors should be transferred to this device.
        torch.cuda.set_device(smp.local_rank())
        device = torch.device("cuda")

        if not training_args.same_seed:
            # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
            set_seed(training_args.seed + smp.tp_rank())

        model = smp.DistributedModel(model,
                                     trace_device=smp_args.trace_device,
                                     gradient_as_bucket_view=True)

        torch.set_default_dtype(torch.float32)

        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module

        param_groups = get_param_groups_by_weight_decay(iter_model)

        if training_args.use_adamw > 0:
            optimizer = optim.AdamW(
                param_groups,
                betas=(training_args.beta1, training_args.beta2),
                lr=training_args.lr,
                weight_decay=training_args.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                param_groups,
                betas=(training_args.beta1, training_args.beta2),
                lr=training_args.lr,
                weight_decay=training_args.weight_decay,
            )

        optimizer = smp.DistributedOptimizer(optimizer)
        lr_scheduler = get_learning_rate_scheduler(optimizer, training_args)

        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

        # Initialize Trainer instance

        trainer = SMPTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=default_data_collator,
        )

        trainer.train_smp(
            model,
            optimizer,
            lr_scheduler,
            start_train_path_index,
            start_batch_index,
            num_params,
            total_steps,
            training_args,
            prescaled_batch=smp_args.prescaled_batch,
        )
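The initialize_smp helper called near the top of this example is not shown; the hedged sketch below illustrates what such a helper plausibly does, reusing configuration keys from Example #8. The attribute names on smp_args are assumptions for illustration.

import smdistributed.modelparallel.torch as smp

def initialize_smp(smp_args, training_args):
    # Build an SMP configuration from parsed arguments and initialize the backend.
    # training_args is accepted for parity with the call site above.
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": smp_args.tensor_parallel_degree,      # assumed attribute
        "pipeline_parallel_degree": smp_args.pipeline_parallel_degree,  # assumed attribute
        "microbatches": smp_args.microbatches,                          # assumed attribute
        "prescaled_batch": smp_args.prescaled_batch,
    }
    smp.init(smp_config)
    if smp.rank() == 0:
        print(f"smdistributed config: {smp_config}")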
Example #12
def main():
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    try:
        if os.environ.get('SM_MODEL_DIR') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
            #             args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
            args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    except:
        logger.debug("not SageMaker")
        pass

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    #     transform = Compose(
    #         [
    #             RandomResizedCrop(args.image_size, args.image_size),
    #             OneOf(
    #                 [
    #                     IAAAdditiveGaussianNoise(),
    #                     GaussNoise(),
    #                 ],
    #                 p=0.2
    #             ),
    #             VerticalFlip(p=0.5),
    #             OneOf(
    #                 [
    #                     MotionBlur(p=.2),
    #                     MedianBlur(blur_limit=3, p=0.1),
    #                     Blur(blur_limit=3, p=0.1),
    #                 ],
    #                 p=0.2
    #             ),
    #             OneOf(
    #                 [
    #                     CLAHE(clip_limit=2),
    #                     IAASharpen(),
    #                     IAAEmboss(),
    #                     RandomBrightnessContrast(),
    #                 ],
    #                 p=0.3
    #             ),
    #             HueSaturationValue(p=0.3),
    # #             Normalize(
    # #                 mean=[0.485, 0.456, 0.406],
    # #                 std=[0.229, 0.224, 0.225],
    # #             )
    #         ],
    #         p=1.0
    #     )

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    #     ds = AlbumentationImageDataset(
    #         IMAGE_PATH,
    #         transform=transform,
    #         args=args
    #     )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)
    # optimizer

    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking

        #         import wandb

        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

    def save_model(path):
        # Only the global rank 0 process writes checkpoints.
        if args.rank != 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            #             logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            #             torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons
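        # Functions decorated with @smp.step pipeline the batch over microbatches and
        # return StepOutput objects, which is why the caller reduces them with
        # reduce_mean() before use. The backward pass is issued inside the step via
        # vae.backward() rather than loss.backward() at the call site.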

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons
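        # hard_recons_step runs the dummy decoder/codebook outside the smp pipeline;
        # their original (unpatched) forward methods are restored from smp's patch
        # manager so the modules behave like plain nn.Modules here.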

    except Exception:
        # The @smp.step definitions above may fail when smp is unavailable or not
        # initialized (i.e. when running without model parallelism); in that case they
        # are simply skipped and the deepspeed / single-process paths below are used.
        pass

    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # Some optimizers (e.g. Adadelta in PyTorch 1.8) fail when
                    # optimizer.step() is called on a rank that holds no parameters.
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()
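            # Note: in the SMP branch above, backward() (and AMP loss scaling) already
            # ran inside train_step, so only the optimizer/scaler step is issued here;
            # ranks that hold no local parameters skip opt.step() entirely.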

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    #                 if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)
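                    # Under model parallelism the decoder/codebook weights were copied
                    # into the local dummy modules above so that sample reconstructions
                    # can be generated without another pipeline pass.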

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
    #                     wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay

                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                #                 print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                #                 prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                #                 top1.update(util.to_python_float(prec1), images.size(0))
                #                 top5.update(util.to_python_float(prec5), images.size(0))

                # Wait for outstanding GPU work (PyTorch ops are asynchronous by
                # default) so the timing below reflects the full iteration.
                torch.cuda.synchronize()
                batch_time.update(time.time() - start)
                start = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        #     if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))


#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")