コード例 #1
0
ファイル: dalle.py プロジェクト: deepglugs/dalle
def get_dalle(vae, vocab, args):
    """Build a DALLE model around *vae*, optionally restore its weights, and
    move it (and *vae*) to ``args.device``.

    Args:
        vae: Image VAE handed to ``DALLE``. As a side effect it is also moved
            to ``args.device``.
        vocab: Text vocabulary. Only ``len(vocab)`` is used, to size
            ``num_text_tokens`` (``len(vocab) + 1``) and ``text_seq_len``
            (``len(vocab)``).
        args: Namespace-like object. Reads ``codebook_dims`` (model width),
            ``dalle`` (optional checkpoint path, may be ``None``) and
            ``device``.

    Returns:
        The constructed ``DALLE`` model, placed on ``args.device``.
    """
    dalle = DALLE(dim=args.codebook_dims,
                  vae=vae,
                  # NOTE(review): the +1 presumably reserves an extra id
                  # (e.g. padding) — confirm against the tokenizer.
                  num_text_tokens=len(vocab) + 1,
                  # NOTE(review): sequence length is tied to vocab size,
                  # presumably matching this project's text encoding — confirm.
                  text_seq_len=len(vocab),
                  depth=16,
                  heads=8,
                  dim_head=64,
                  attn_dropout=0.1,
                  ff_dropout=0.1,
                  reversible=True)

    # Restore weights only when a checkpoint path was given and it exists.
    if args.dalle is not None and os.path.isfile(args.dalle):
        print(f"loading state dict from {args.dalle}")
        # Load to CPU first. Without map_location, a checkpoint saved from a
        # CUDA run fails on a CPU-only host. load_state_dict copies the values
        # into the model's existing parameters, and the model is moved to
        # args.device below, so final placement is unchanged.
        # torch.load unpickles the file: only load trusted checkpoints.
        state_dict = torch.load(args.dalle, map_location="cpu")
        dalle.load_state_dict(state_dict)

    dalle.to(args.device)
    vae.to(args.device)

    return dalle
コード例 #2
0
    # NOTE(review): the line that opens this argument list is missing from this
    # listing. Given that `dalle` is used right after, it is presumably
    # `dalle = DALLE(`. Restore it before running.
    vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=10000,  # vocab size for text; presumably must cover every token id the captions vocabulary produces — confirm
    text_seq_len=256,  # text sequence length
    depth=6,  # original note says "should be 64"; 6 layers keeps this model much smaller
    heads=8,  # attention heads
    dim_head=64,  # attention head dimension
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1  # feedforward dropout
)

# Load pretrained DALLE weights so training can continue from a checkpoint.
# NOTE(review): the load is unconditional, so torch.load raises if `loadfn`
# does not exist. Confirm that this is intended for fresh training runs.

# map_location="cpu" lets a checkpoint saved from a CUDA run load on a
# CPU-only host. load_state_dict copies the values into the model's existing
# parameters, and the model is moved to `device` right after, so final
# placement is unchanged. torch.load unpickles: only load trusted files.
dalle_dict = torch.load(loadfn, map_location="cpu")
dalle.load_state_dict(dalle_dict)

dalle.to(device)

# Get text data. Only captions are read in the lines shown here; no image data
# is loaded in this section.

# NOTE(review): `lf` is not closed in the visible lines. Once the full loop and
# any later uses of `lf` are confirmed, consider
# `with open(..., encoding="utf-8") as lf:`.
lf = open("od-captionsonly.txt",
          "r")  # file contains captions only, one caption per line

# build vocabulary
# NOTE(review): mid-script import; by convention this belongs at the top of
# the file with the other imports.

from Vocabulary import Vocabulary

vocab = Vocabulary("captions")

# Append each raw line of the captions file to `captions`. Lines keep their
# trailing newline because they are not stripped here.
captions = []
for lin in lf:
    captions.append(lin)