Beispiel #1
0
def get_vae(args):
    """Build a DiscreteVAE from ``args`` and optionally load saved weights.

    Hyper-parameters are read from ``args.size``, ``args.vae_layers``,
    ``args.codebook_dims`` and ``args.temperature``. ``num_tokens``,
    ``num_resnet_blocks`` and ``hidden_dim`` are fixed at 8192, 9 and 128.

    If ``args.vae`` names an existing file, its state dict is loaded into the
    model. The checkpoint is deserialised onto the CPU first, so weights that
    were saved from a GPU can also be restored on a CPU-only machine. The
    model is then moved to ``args.device``.

    Returns:
        The constructed ``DiscreteVAE``, placed on ``args.device``.
    """
    vae = DiscreteVAE(image_size=args.size,
                      num_layers=args.vae_layers,
                      num_tokens=8192,
                      codebook_dim=args.codebook_dims,
                      num_resnet_blocks=9,
                      hidden_dim=128,
                      temperature=args.temperature)

    if args.vae is not None:
        if os.path.isfile(args.vae):
            print(f"loading state dict from {args.vae}")
            # map_location='cpu' avoids a CUDA deserialisation error when the
            # checkpoint was saved on GPU but no GPU is available now; the
            # model is moved to the requested device right below.
            vae.load_state_dict(torch.load(args.vae, map_location="cpu"))
        else:
            # Previously this case silently continued with random weights.
            print(f"warning: VAE checkpoint {args.vae} not found; "
                  "using randomly initialised weights")

    vae.to(args.device)

    return vae
Beispiel #2
0
IMAGE_SIZE = args.image_size
VAE_CLASS = args.vae_class
# Select the VAE requested on the command line. The named variants are
# constructed directly. 'DALLE_TRAIN' loads a DiscreteVAE checkpoint that
# stores its constructor kwargs under 'hparams' and its state dict under
# 'weights'.
if VAE_CLASS == 'VQGAN1024':
    vae = VQGanVAE1024()
elif VAE_CLASS == 'VQGAN16384':
    vae = VQGanVAE16384()
elif VAE_CLASS == 'VQGAN_CUSTOM':
    vae = VQGanVAECustom()
elif VAE_CLASS == 'DALLE':
    vae = OpenAIDiscreteVAE()
elif VAE_CLASS == 'DALLE_TRAIN':
    VAE_PATH = args.vae_path
    vae_path = Path(VAE_PATH)
    loaded_obj = torch.load(str(vae_path), map_location='cuda')
    vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
    vae = DiscreteVAE(**vae_params)
    vae.load_state_dict(weights)
    vae.to('cuda')
else:
    # Without this branch an unknown name left `vae` undefined and the script
    # only failed later with an unrelated-looking NameError.
    raise ValueError(
        f"unknown vae_class {VAE_CLASS!r}; expected one of "
        "'VQGAN1024', 'VQGAN16384', 'VQGAN_CUSTOM', 'DALLE', 'DALLE_TRAIN'")

filenames = os.listdir(args.target)

for filename in tqdm(filenames):
    TARGET_IMG_PATH = args.target + '/' + filename
    TARGET_SAVE_PATH = args.target + '/output/'
    filename = filename.split('.')[0]

    img = PIL.Image.open(TARGET_IMG_PATH)

    composed = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from dalle_pytorch import DiscreteVAE

# Input resolution of the DiscreteVAE and the checkpoint epoch to restore.
imgSize = 256
load_epoch = 280

vae = DiscreteVAE(
    image_size=imgSize,
    num_layers=3,
    channels=3,
    num_tokens=2048,
    codebook_dim=1024,
    hidden_dim=128,
)

# Checkpoints are stored as ./models/dvae-<epoch>.pth (plain state dicts).
vae_dict = torch.load(f"./models/dvae-{load_epoch}.pth")
vae.load_state_dict(vae_dict)
vae.cuda()

# Training-loop settings.
batchSize = 12
n_epochs = 500
log_interval = 20

t = transforms.Compose([
    transforms.Resize(imgSize),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
Beispiel #4
0
TEXTSEQLEN = 80
BATCH_SIZE = 1
TRAIN_BATCHES = 100

# https://github.com/lucidrains/DALLE-pytorch/issues/33
# Token id 0 is reserved for padding and id 1 for another special token (its
# name appears to have been lost from the quoted text - verify against the
# issue), so 2 should be added to encoded text ids.

vae = DiscreteVAE(
    image_size=IMAGE_SIZE,
    # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_layers=3,
    # number of visual tokens. in the paper, they used 8192, but could be
    # smaller for downsized projects
    num_tokens=8192,
    codebook_dim=512,  # codebook dimension
    hidden_dim=64,  # hidden dimension
    num_resnet_blocks=1,  # number of resnet blocks
    # gumbel softmax temperature, the lower this is, the harder the
    # discretization
    temperature=0.9,
    # straight-through for gumbel softmax. unclear if it is better one way or
    # the other
    straight_through=False,
).cuda()

vae.load_state_dict(torch.load("Vae-small.pth"))

dalle = DALLE(
    dim=1024,
    # automatically infer (1) image sequence length and (2) number of image
    # tokens from the VAE
    vae=vae,
    num_text_tokens=NUM_TOKENS,  # vocab size for text
    text_seq_len=TEXTSEQLEN,  # text sequence length
)

# Validate the dataset before building the loader: DataLoader(shuffle=True)
# constructs a RandomSampler, which raises its own, less helpful ValueError on
# an empty dataset - so this assert was unreachable when it ran afterwards.
assert len(ds) > 0, 'folder does not contain any images'
print(f'{len(ds)} images found for training')

dl = DataLoader(ds, BATCH_SIZE, shuffle = True)

# Constructor kwargs are kept separately so they can be saved alongside the
# weights and used to rebuild the VAE later.
vae_params = dict(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim = EMB_DIM,
    hidden_dim   = HID_DIM,
    num_resnet_blocks = NUM_RESNET_BLOCKS
)

vae = DiscreteVAE(
    **vae_params,
    smooth_l1_loss = SMOOTH_L1_LOSS,
    kl_div_loss_weight = KL_LOSS_WEIGHT
).cuda()

def save_model(path):
    """Write the VAE's hyper-parameters and weights to ``path``.

    The saved object is a dict with ``'hparams'`` (the module-level
    ``vae_params``) and ``'weights'`` (``vae.state_dict()``). It can be
    restored with ``DiscreteVAE(**obj['hparams'])`` followed by
    ``load_state_dict(obj['weights'])``.
    """
    torch.save(dict(hparams=vae_params, weights=vae.state_dict()), path)
Beispiel #6
0
import torch
from dalle_pytorch import DiscreteVAE
import json

# Build a randomly initialised DiscreteVAE and print the codebook indices it
# assigns to a batch of random-noise images, followed by their size.
vae = DiscreteVAE(
    image_size=256,
    num_layers=3,
    num_tokens=8192,
    codebook_dim=512,
    hidden_dim=64,
    num_resnet_blocks=1,
    temperature=0.9,
    straight_through=False,
)

batch, channels, height, width = 2, 3, 256, 256
images = torch.randn(batch, channels, height, width)
idx = vae.get_codebook_indices(images)
for value in (idx, idx.size()):
    print(value)
Beispiel #7
0
def main(args):
    """Train DALL-E with either DeepSpeed or SageMaker model parallelism.

    Reads hyper-parameters from ``args`` and builds the tokenizer and the VAE,
    resuming from ``args.dalle_path`` when it exists. It then creates the
    text/image dataset and loader, distributes the model with the selected
    backend and runs ``args.epochs`` epochs.

    Progress is logged to wandb on rank 0. Checkpoints are written with
    ``save_model`` to ``{args.model_dir}/dalle.pt`` every 100 iterations and
    to ``{args.model_dir}/dalle-final.pt`` at the end.

    Note: ``args`` is mutated to carry state into helpers such as
    ``save_model`` and ``train_step`` (``args.distr_dalle``,
    ``args.dalle_params``, ``args.scaler``, ...).
    """
    # constants

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    # SageMaker model parallelism disables DeepSpeed; otherwise DeepSpeed is
    # forced on.
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
        # Rank information is taken from the launcher's environment variables.
        args.rank = int(os.environ.get('RANK'))
        args.world_size = int(os.environ.get('WORLD_SIZE'))
        args.local_rank = int(os.environ.get('LOCAL_RANK'))
        args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")

    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")
    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae

    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            # DeepSpeed checkpoints live in a directory; hparams/vae_params are
            # read from the auxiliary file inside it.
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                'standard *.pt checkpoint. '
                'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s

    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader

    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        # The dataset is split per data-parallel rank by split_dataset.
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None

    print(f"data_sampler : {data_sampler}")

    # uncorrectable NVLink error was detected during the execution  --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs
                   )

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
    len(dl.sampler), len(dl.dataset),
    100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )
    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker
        # model_config and run exist only on global rank 0; every later use is
        # guarded by the same rank check.
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }
    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        # Wrap model/optimizer for SageMaker model parallelism; the grad
        # scaler is stored on args so the training loop (and train_step) can
        # reach it.
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training

    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            if args.sagemakermp:
                args.distr_opt.zero_grad()

                # NOTE(review): train_step is defined elsewhere; it is assumed
                # to run the forward and backward passes - confirm.
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()

            else:
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    # Use the GradScaler created above and stored on args; the
                    # bare name `scaler` is not defined in this function and
                    # raised a NameError here.
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save(f'./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                # save_model reads its inputs from these args attributes.
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
Beispiel #8
0
name = "vae-cdim256"
loadfn = "./models/dalle_" + name + "-" + str(dalle_epoch) + ".pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Only normalisation is applied (Resize/RandomHorizontalFlip/ToTensor are
# disabled), so inputs are presumably already tensors at the target size.
# mean=std=0.5 maps values in [0, 1] to [-1, 1].
# Alternative per-channel std previously tried: (0.267, 0.233, 0.234).
tf = transforms.Compose([
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

vae = DiscreteVAE(image_size=256,
                  num_layers=3,
                  num_tokens=2048,
                  codebook_dim=256,
                  hidden_dim=128,
                  temperature=0.9)

# load pretrained vae
# map_location=device: `device` may fall back to CPU, and without it a
# checkpoint saved from a GPU fails to deserialise on a CPU-only machine.
vae_dict = torch.load("./models/" + vaename + "-" + str(load_epoch) + ".pth",
                      map_location=device)
vae.load_state_dict(vae_dict)
vae.to(device)

dalle = DALLE(
    dim=256,  #512,
    vae=
    vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=10000,  # vocab size for text
    text_seq_len=256,  # text sequence length
Beispiel #9
0
def main(args):
    """Train DALL-E with DeepSpeed and sync checkpoints to S3.

    Reads hyper-parameters from ``args`` and builds the tokenizer and the VAE,
    resuming from ``args.dalle_path`` when it exists. It then creates the
    text/image dataset and loader, distributes the model and runs
    ``args.epochs`` epochs.

    Checkpoints are written with ``save_model`` to
    ``{args.model_dir}/{args.dalle_output_file_name}`` every
    ``args.save_every_n_steps`` iterations (disabled when that value is 0 or
    None) and at the end of every epoch. ``args.model_dir`` is synced to
    ``{args.output_s3}/{args.job_name}/temp``. Metrics and artifacts are
    logged to wandb on rank 0.

    Note: ``args`` is mutated to carry state into ``save_model``
    (``args.distr_dalle``, ``args.dalle_params``, ...).
    """
    import glob  # used for diagnostic listings at each epoch end

    # constants
    print(f"torch.cuda.nccl.version() : {torch.cuda.nccl.version()}")
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps
    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name

    # initialize distributed backend

    args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    using_deepspeed = \
        distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
    # Rank information is taken from the launcher's environment variables.
    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")

    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    print(
        f"********* torch.distributed.get_rank() : {torch.distributed.get_rank()}"
    )
    # tokenizer
    logger.debug(
        f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}"
    )

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae

    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            # DeepSpeed checkpoints live in a directory; hparams/vae_params are
            # read from the auxiliary file inside it.
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s

    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader

    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    else:
        data_sampler = None

    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E

    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer

    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        # experiment tracker
        # model_config and run exist only on rank 0; every later use is
        # guarded by the same rank check.
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute

    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training

    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
        for i, (text, images) in enumerate(distr_dl):

            # Start of a 10-iteration throughput window (closed at i % 10 == 9).
            if i % 10 == 0 and args.rank == 0:
                window_start = time.time()
            if args.fp16:
                images = images.half()
            text, images = map(lambda t: t.cuda(), (text, images))

            loss = distr_dalle(text, images, return_loss=True)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }
            # A falsy SAVE_EVERY_N_STEPS (0/None) disables intra-epoch saves
            # instead of raising ZeroDivisionError/TypeError in the modulo.
            if SAVE_EVERY_N_STEPS and i % SAVE_EVERY_N_STEPS == 0:
                # save_model reads its inputs from these args attributes.
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
                save_model(args,
                           f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)
            if i % 10 == 9 and args.rank == 0:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - window_start)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')
            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        args.distr_dalle = distr_dalle
        args.dalle_params = dalle_params
        args.vae_params = vae_params
        args.using_deepspeed = using_deepspeed
        args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

        save_model(args, f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
        # NOTE(review): this per-epoch sync runs on every rank, while the
        # final sync below is rank-0 only - confirm this is intended.
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))

        # if distr_backend.is_root_worker():
        if args.rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            print(f"************** file : {glob.glob(args.model_dir+'/*')}")
            try:
                print(f"wandb.run.dir : {wandb.run.dir}")
                print(
                    f"************** file wandb: {glob.glob(wandb.run.dir+'/*')}"
                )
            except Exception:
                # Best-effort diagnostics only; a bare `except:` here also
                # swallowed KeyboardInterrupt/SystemExit.
                pass

            model_artifact.add_file(
                f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
    resource_check(args)
    save_model(args, f"{args.model_dir +'/'+DALLE_OUTPUT_FILE_NAME}")
    if args.rank == 0:
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))
        resource_check(args)

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        wandb.save(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}",
                   base_path=args.model_dir)

        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
        run.log_artifact(model_artifact)

        wandb.finish()

    distributed_utils.backend.local_barrier()
    print("************************ Finished *************************")
Beispiel #10
0
# Command-line options. The trailing comments appear to record the values
# these settings were previously hard-coded to.
temperature_scheduling = opt.tempsched  #True

name = opt.name  #"v2vae256"

# for continuing training
# loadfn: path to a pretrained VAE state dict ("" or None to start fresh)
# start_epoch: start epoch numbering from this
loadfn = opt.loadVAE  #""
start_epoch = opt.start_epoch  #0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vae = DiscreteVAE(image_size=imgSize,
                  num_layers=3,
                  channels=3,
                  num_tokens=2048,
                  codebook_dim=256,
                  hidden_dim=128,
                  temperature=opt.temperature)

# `if loadfn` treats both "" and None as "no checkpoint" (the old `!= ""`
# check passed None on to torch.load). map_location lets a GPU-saved
# checkpoint be restored when `device` fell back to CPU.
if loadfn:
    vae_dict = torch.load(loadfn, map_location=device)
    vae.load_state_dict(vae_dict)

vae.to(device)

t = transforms.Compose([
    transforms.Resize(imgSize),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))  #(0.267, 0.233, 0.234))
Beispiel #11
0
def main():
    """Train a ``DiscreteVAE`` on an image folder.

    All settings come from ``get_parser()``. Two execution modes exist:

    * ``--model_parallel``: SageMaker distributed model parallel (``smp``). The
      model/optimizer are wrapped in ``smp.DistributedModel`` /
      ``smp.DistributedOptimizer`` and forward/backward run inside
      ``@smp.step`` functions.
    * otherwise: DeepSpeed (or plain PyTorch) through ``deepspeed_utils``.

    Every 10 iterations the model is checkpointed to ``<model_dir>/vae.pt``
    (by rank 0 via ``save_model`` or, under model parallel, by data-parallel
    rank 0 via ``smp.save``). In the non-model-parallel path the final weights
    are written to ``<model_dir>/vae-final.pt``.

    Raises:
        ValueError: if CUDA is unavailable or the image folder is empty.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    # -1 means "not the root worker"; the root rank overwrites it below.
    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        # SageMaker model parallel is used instead of DeepSpeed.
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        # The torch seed is offset by rank. NOTE: in the DeepSpeed path every
        # non-root worker has rank -1, so those workers share one seed.
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    # Deterministic cuDNN is forced on unconditionally, so this warning
    # is always emitted.
    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker: take model/data locations from the training-job environment.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    # Force RGB, resize the short side to IMAGE_SIZE, then center-crop square.
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        # Give each data-parallel replica an equal, disjoint shard.
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    # Validate before building the DataLoader: a shuffled DataLoader over an
    # empty dataset raises its own, less helpful error at construction time.
    if len(ds) == 0:
        raise ValueError('folder does not contain any images')

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    # Constructor hyperparameters; also stored in checkpoints by save_model.
    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        # Unwrapped copies of codebook/decoder, used only to render hard
        # reconstructions outside smp.step (see hard_recons_step).
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking (currently disabled)
        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

    def save_model(path):
        """Save ``{'hparams', 'weights'}`` for the VAE to *path* (rank 0 only)."""
        if not args.rank == 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            """Forward + backward for one batch under smp; returns (loss, recons)."""
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            """Return argmax codebook indices for the first *k* images."""
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            """Decode codebook indices with the unwrapped codebook/decoder copies."""
            from functools import partial
            # Restore each module's original (un-patched by smp) forward.
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    except Exception:
        # These helpers are only needed in model-parallel mode; there a failure
        # to define them is fatal, elsewhere (e.g. smp unavailable) it is not.
        if args.model_parallel:
            raise
        logger.debug("smp step functions not defined")

    global_step = 0
    temp = STARTING_TEMP  # starting temperature

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy current weights into the unwrapped
                            # decoder/codebook by stripping key prefixes.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Use separate names for the sample grids: rebinding
                    # `images` would corrupt the batch size the loss meter
                    # reads further down.
                    sample_images, sample_recons = map(lambda t: t[:k],
                                                       (images, recons))
                    sample_images, sample_recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (sample_images, sample_recons, hard_recons, codes))
                    sample_images, sample_recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (sample_images, sample_recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(sample_images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(sample_recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # Only the first data-parallel replica writes the
                    # checkpoint (dp_rank must be *called*; comparing the
                    # function object to 0 was always False).
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
    #                     wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal (every 10 iterations, on all ranks)
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                # NOTE(review): this only divides the local loss by world_size;
                # no cross-rank reduction happens here. Confirm the intent.
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))

                # Wait for queued GPU work so the timing covers the whole
                # step, then measure the wall time of this iteration only.
                torch.cuda.synchronize()
                now = time.time()
                batch_time.update(now - start)
                start = now

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")