Пример #1
0
    def __init__(self,cfg,path,voc):
        """Set up the COCO train/val caption datasets and the test id -> file name map.

        Parameters:
            cfg: configuration object, stored as-is on the instance.
            path: object exposing ``train_image_path``, ``train_annotation_file``,
                ``val_image_path``, ``val_annotation_file`` and ``test_info_path``.
            voc: vocabulary object, stored as-is on the instance.
        """
        self.cfg = cfg
        self.voc = voc
        self.path = path
        # Every image is resized to 224x224 and converted to a tensor.
        self.data_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        self.coco_train = dset.CocoCaptions(root=self.path.train_image_path,
                                            annFile=self.path.train_annotation_file,
                                            transform=self.data_transform)
        self.coco_val = dset.CocoCaptions(root=self.path.val_image_path,
                                          annFile=self.path.val_annotation_file,
                                          transform=self.data_transform)

        # Read the test image metadata with a context manager so the file
        # handle is closed deterministically instead of being leaked.
        with open(self.path.test_info_path) as info_file:
            test_info = json.load(info_file)
        self.id2fname = {img['id']: img['file_name'] for img in test_info['images']}
Пример #2
0
def get_coco_data(vocab,
                  train=True,
                  img_size=224,
                  scale_size=256,
                  normalize=__normalize):
    """Build the COCO captions dataset for training or validation.

    Parameters:
        vocab: vocabulary handed to ``create_target`` and returned unchanged.
        train: selects the training paths and random augmentation when true,
            otherwise the validation paths and a center crop.
        img_size: size of the crop taken after rescaling.
        scale_size: size passed to the initial rescale step.
        normalize: keyword arguments forwarded to ``transforms.Normalize``.

    Returns:
        A ``(dataset, vocab)`` tuple.
    """
    paths = __TRAIN_PATH if train else __VAL_PATH
    root, annFile = paths['root'], paths['annFile']

    steps = [transforms.Scale(scale_size)]
    if train:
        # Random crop + flip is used as augmentation for training only.
        steps.append(transforms.RandomCrop(img_size))
        steps.append(transforms.RandomHorizontalFlip())
    else:
        steps.append(transforms.CenterCrop(img_size))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(**normalize))
    img_transform = transforms.Compose(steps)

    dataset = dset.CocoCaptions(root=root,
                                annFile=annFile,
                                transform=img_transform,
                                target_transform=create_target(vocab, train))
    return (dataset, vocab)
Пример #3
0
 def prepare(self):
     """Build the vocabulary and its embedding table from the COCO captions.

     Counts every token of the train and val captions, keeps the tokens whose
     count reaches ``self.min_word_freq`` and stores ``self.word2idx``,
     ``self.idx2word`` and ``self.fast_text`` (word -> vector).
     """
     self.word2idx = counts = defaultdict(int)
     # Seed the special symbols at the threshold so they are always kept.
     threshold = self.min_word_freq
     for symbol in (self.START_SYMBOL, self.END_SYMBOL, self.UNK, self.PAD):
         counts[symbol] = threshold

     for dataset_type in ["train", "val"]:
         image_root = FilePathManager.resolve(f'data/{dataset_type}')
         ann_file = FilePathManager.resolve(
             f"data/annotations/captions_{dataset_type}2017.json")
         caps = dset.CocoCaptions(root=image_root,
                                  annFile=ann_file,
                                  transform=transforms.ToTensor())
         for _, captions in caps:
             for capt in captions:
                 for token in self.tokenize(capt):
                     counts[token] += 1

     kept = {}
     vectors = {}
     fast_text = FastText.load(
         FilePathManager.resolve("data/fasttext.model"), mmap="r")
     for word, freq in counts.items():
         if freq < self.min_word_freq:
             continue
         kept[word] = len(kept)
         # Words unknown to the FastText model fall back to the UNK vector.
         if word in fast_text:
             vectors[word] = fast_text[word]
         else:
             vectors[word] = fast_text[self.UNK]

     self.word2idx = kept
     # Inverse mapping: index -> word.
     self.idx2word = {index: word for word, index in kept.items()}
     self.fast_text = vectors
Пример #4
0
def extract_coco_feature(model, batch_size=64):
    """Run ``model`` over the COCO train2014 images and collect its outputs.

    Parameters:
        model: callable module applied to each image batch; it is switched to
            eval mode before the loop.
        batch_size: number of images per batch.

    Returns:
        A tuple ``(features, captions)`` where ``features`` is the
        concatenation (along dim 0) of the model outputs for every batch and
        ``captions`` collects the first entry of each batch's caption target.
    """
    # NOTE(review): RandomCrop makes the extracted features non-deterministic
    # between runs -- kept as in the original pipeline.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224, pad_if_needed=True),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    data = datasets.CocoCaptions('./data/train2014',
                                 './data/annotations/captions_train2014.json',
                                 transform=transform)

    loader = dataloader.DataLoader(data, batch_size=batch_size)
    model.eval()

    features = []
    all_captions = []

    length = len(loader)
    use_cuda = torch.cuda.is_available()
    # Feature extraction needs no gradients; without no_grad every stored
    # feature would keep its autograd graph alive.
    with torch.no_grad():
        # The batch captions get their own name: previously the loop variable
        # shadowed the accumulator, so only the last batch was returned.
        for i, (images, batch_captions) in enumerate(loader):
            if use_cuda:
                images = images.cuda()
            features.append(model(images).cpu())
            all_captions.extend(batch_captions[0])

            if i % 100 == 0:
                print('[%d/%d] finished' % (i, length))

    return torch.cat(features, 0), all_captions
Пример #5
0
    def __init__(self, images, annotations, augment, normalize=True):
        """
        Create the wrapped COCO captions dataset and its image pipeline.

        Parameters:
            images:
                Path passed to the dataset as the image ``root``
            annotations:
                Path passed to the dataset as ``annFile``
            augment:
                When true, ``_imageAugment`` runs before ``_imagePrep``
            normalize:
                When true, a mean/std normalization step is appended
        """
        if augment:
            pipeline = transforms.Compose([_imageAugment, _imagePrep])
        else:
            pipeline = _imagePrep

        if normalize:
            channel_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225])
            pipeline = transforms.Compose([pipeline, channel_norm])

        self.coco = datasets.CocoCaptions(root=images,
                                          annFile=annotations,
                                          transform=pipeline)
def dataloader(image_folder, captions_file, batch_size, stoi):
    """Yield batches of images with one randomly chosen caption per image.

    Parameters:
        image_folder: image root passed to ``CocoCaptions``.
        captions_file: annotation file passed to ``CocoCaptions``.
        batch_size: number of images per yielded batch; a trailing partial
            batch is dropped.
        stoi: container of known words; words not in it are replaced by the
            spell checker's correction.

    Yields:
        ``(tensor_images, captions)`` where ``tensor_images`` has shape
        ``(batch_size, 3, 224, 224)`` and ``captions`` is a list of token
        lists. The image tensor is a single buffer reused for every batch, so
        consumers must copy it if they need it after the next iteration.
    """
    spell = SpellChecker(distance=1)
    # Words (with apostrophes) or single punctuation marks.
    token_pattern = re.compile(r"[\w']+|[.,!?;]")

    # Tensor that holds the images of one batch
    tensor_images = torch.zeros((batch_size, 3, 224, 224)).to(device)
    # Transformations applied to every image
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224, 0.225])])

    # Load the dataset
    cap = datasets.CocoCaptions(root = image_folder,
                                annFile = captions_file,
                                transform = transform)

    for batch in range(len(cap) // batch_size):
        captions = []
        for i in range(batch_size):
            # Fetch an image with its captions and pick one caption at random.
            # random.choice works for any number of captions, whereas a fixed
            # randrange(5) fails on fewer than five and ignores extra ones.
            img, target = cap[batch*batch_size+i]
            sentence = random.choice(target).lower()

            # Lower-cased sentence split into words and punctuation; unknown
            # words are spell-corrected, keeping the word itself when the
            # checker has no suggestion.
            captions.append([
                word if word in stoi else (spell.correction(word) or word)
                for word in token_pattern.findall(sentence)
            ])

            # Update the image buffer
            tensor_images[i] = img

        yield tensor_images, captions
Пример #7
0
def get_dataset(data_set):
    """Return the COCO 2017 captions dataset for ``data_set``.

    ``data_set`` must be one of ``'train'``, ``'val'`` or ``'test'``; images
    and annotations are looked up under the local ``cocoapi`` directory.
    """
    assert data_set in ('train', 'val', 'test')
    return dset.CocoCaptions(
        root=f'cocoapi/images/{data_set}2017/',
        annFile=f'cocoapi/annotations/captions_{data_set}2017.json',
        transform=picture_tranform_func,
    )
Пример #8
0
 def __init__(self, corpus: Corpus, transform=None, captions_per_image=2):
     """Wrap the COCO train2017 captions dataset.

     ``corpus`` and ``captions_per_image`` are stored on the instance;
     ``transform`` is forwarded to the underlying dataset as image transform.
     """
     self.corpus = corpus
     image_root = FilePathManager.resolve('data/train')
     ann_file = FilePathManager.resolve(
         "data/annotations/captions_train2017.json")
     self.captions = dset.CocoCaptions(root=image_root,
                                       annFile=ann_file,
                                       transform=transform)
     self.captions_per_image = captions_per_image
Пример #9
0
 def __init__(self, corpus: Corpus, evaluator: bool = True, transform=None):
     """Wrap the COCO train2017 captions dataset.

     ``corpus`` and the ``evaluator`` flag are stored on the instance;
     ``transform`` is forwarded to the underlying dataset as image transform.
     """
     self.corpus = corpus
     self.evaluator = evaluator
     image_root = FilePathManager.resolve('data/train')
     ann_file = FilePathManager.resolve(
         "data/annotations/captions_train2017.json")
     self.captions = dset.CocoCaptions(root=image_root,
                                       annFile=ann_file,
                                       transform=transform)
Пример #10
0
    def __init__(self, tranform=None):
        """Load the COCO train2017 captions and record the set of sample indices.

        ``tranform`` (spelling kept for caller compatibility) is the image
        transform forwarded to the underlying dataset.
        """
        image_root = FilePathManager.resolve('data/train')
        ann_file = FilePathManager.resolve(
            "data/annotations/captions_train2017.json")
        self.captions = dset.CocoCaptions(root=image_root,
                                          annFile=ann_file,
                                          transform=tranform)

        sample_count = len(self.captions)
        self.length = sample_count
        # All sample indices 0 .. length - 1.
        self.s = set(range(sample_count))
 def __init__(self, corpus: Corpus):
     """Load the COCO train2017 captions together with pre-computed image data.

     The pickled object from ``data/embedded_images.pkl`` is stored as
     ``self.images``; ``self.length`` is five times its length.
     """
     self.corpus = corpus
     image_root = FilePathManager.resolve('data/train')
     ann_file = FilePathManager.resolve(
         "data/annotations/captions_train2017.json")
     self.captions = dset.CocoCaptions(root=image_root,
                                       annFile=ann_file,
                                       transform=transforms.ToTensor())

     embedded_path = FilePathManager.resolve("data/embedded_images.pkl")
     # NOTE(review): pickle.load executes arbitrary code from the file; only
     # use it on a trusted, locally generated pickle.
     with open(embedded_path, "rb") as f:
         self.images = pickle.load(f)
     # presumably five captions per image -- TODO confirm
     self.length = 5 * len(self.images)
Пример #12
0
def get_val_loader(root, annFile, transform, shuffle=False, num_workers=8):
    """Return a ``DataLoader`` over a COCO captions dataset with batch size 1.

    Parameters:
        root: image directory of the dataset.
        annFile: caption annotation file.
        transform: image transform applied by the dataset.
        shuffle: whether the loader shuffles the samples.
        num_workers: number of loader worker processes.
    """
    captions_dataset = datasets.CocoCaptions(root,
                                             annFile,
                                             transform=transform,
                                             target_transform=None)
    return torch.utils.data.DataLoader(captions_dataset,
                                       batch_size=1,
                                       shuffle=shuffle,
                                       num_workers=num_workers)
Пример #13
0
    def __init__(self, root, img_transform=imagenet_transform,
                 split='train',
                 tokenization='bpe',
                 num_symbols=32000,
                 shared_vocab=True,
                 code_file=None,
                 vocab_file=None,
                 insert_start=[BOS], insert_end=[EOS],
                 mark_language=False,
                 tokenizer=None,
                 sample_caption=True):
        """Set up the COCO 2014 captions data and, if needed, its tokenizer.

        Parameters:
            root: directory containing ``train2014``, ``val2014`` and
                ``annotations``; also the prefix location of tokenizer files.
            img_transform: factory called with ``train=<bool>`` to obtain the
                image transform.
            split: ``'train'`` selects the training data, any other value the
                validation data.
            tokenization: one of ``'bpe'``, ``'char'`` or ``'word'``; only
                checked when no ``tokenizer`` is given.
            num_symbols: symbol count embedded in the default code/vocab file
                names for BPE.
            shared_vocab, mark_language: stored on the instance.
            code_file, vocab_file: explicit file paths; defaults are derived
                from ``root`` when no ``tokenizer`` is given.
            insert_start, insert_end: stored on the instance. The default
                lists are shared between instances, so they must not be
                mutated in place.
            tokenizer: ready-made tokenizer; when ``None`` one is generated.
            sample_caption: when true, training uses ``randrange`` to pick a
                caption and validation always picks index 0.

        Raises:
            ValueError: if a tokenizer must be generated and ``tokenization``
                is not a supported option.
        """
        super(CocoCaptions, self).__init__()
        is_train = split == 'train'

        self.shared_vocab = shared_vocab
        self.num_symbols = num_symbols
        self.tokenizer = tokenizer
        self.tokenization = tokenization
        self.insert_start = insert_start
        self.insert_end = insert_end
        self.mark_language = mark_language
        self.code_file = code_file
        self.vocab_file = vocab_file
        self.img_transform = img_transform

        # Caption selection strategy: disabled, random (train) or first (val).
        if not sample_caption:
            self.sample_caption = None
        elif is_train:
            self.sample_caption = randrange
        else:
            self.sample_caption = lambda l: 0

        if is_train:
            image_root = os.path.join(root, 'train2014')
            ann_path = os.path.join(root, 'annotations/captions_train2014.json')
        else:
            image_root = os.path.join(root, 'val2014')
            ann_path = os.path.join(root, 'annotations/captions_val2014.json')

        self.data = dset.CocoCaptions(root=image_root,
                                      annFile=ann_path,
                                      transform=img_transform(train=is_train))

        if self.tokenizer is not None:
            return

        # No tokenizer supplied: derive the file names and build one.
        prefix = os.path.join(root, 'coco')
        supported = ['bpe', 'char', 'word']
        if tokenization not in supported:
            raise ValueError("An invalid option for tokenization was used, options are {0}".format(
                ','.join(supported)))

        if tokenization == 'bpe':
            self.code_file = code_file or '{prefix}.{lang}.{tok}.codes_{num_symbols}'.format(
                prefix=prefix, lang='en', tok=tokenization, num_symbols=num_symbols)
            vocab_suffix = num_symbols
        else:
            # Only BPE vocab file names carry the symbol count.
            vocab_suffix = ''

        self.vocab_file = vocab_file or '{prefix}.{lang}.{tok}.vocab{num_symbols}'.format(
            prefix=prefix, lang='en', tok=tokenization, num_symbols=vocab_suffix)
        self.generate_tokenizer()
Пример #14
0
 def __init__(self, coco_path, annFile, batch_size):
     """Create the COCO captions dataset and the state used for batching.

     Parameters:
         coco_path: image root of the dataset.
         annFile: caption annotation file.
         batch_size: stored as ``self.batch_size``.
     """
     self.coco_data = datasets.CocoCaptions(root=coco_path, annFile=annFile)
     sample_count = len(self.coco_data)
     # One integer id per sample: 0 .. sample_count - 1.
     self.paths = np.arange(sample_count)
     self.index = 0
     self.batch_size = batch_size
     self.init_count = 0
     # mutex for input path
     self.lock = threading.Lock()
     # mutex for generator yielding of batch
     self.yield_lock = threading.Lock()
     self.path_id_generator = threadsafe_iter(get_path_i(len(self.paths)))
     self.cumulative_batch = []
Пример #15
0
def get_data_loader(image_folder, annotation_file, batch_size=32):
    """
    Create a torch data loader over the COCO captions dataset.

    :param image_folder: The folder containing the COCO images
    :param annotation_file: The .json file with the image annotations
    :param batch_size: Number of samples per batch
    :return: A ``DataLoader`` wrapping the COCO captions dataset
    """
    return DataLoader(
        torchdata.CocoCaptions(image_folder, annotation_file),
        batch_size=batch_size,
    )
Пример #16
0
def evaluate(image: str = None):
    """Caption images with the checkpointed model, plot them and print BLEU.

    Parameters:
        image: path of a single image file to caption. When ``None``, the
            first batch of 16 images from the validation indices of the COCO
            captions data is used together with its reference captions.

    For every image the generated caption is shown as the subplot label and
    printed along with the references. BLEU is only computed when reference
    captions exist, so it is printed as ``None`` for a single image file.
    """
    model = VisualBERT.load_from_checkpoint(
        "models/electra/final-year-project/3osm0cr3/checkpoints/epoch=17-step=231535.ckpt",
        manual_lm_head=True)
    model.training_objective = TrainingObjective.Captioning
    model.cuda()
    model.eval()

    feature_model = FeatureExtractor()
    feature_model.cuda()
    feature_model.eval()

    if image is None:
        data = CocoCaptions()
        data.prepare_data()
        data.setup()
        image_root = "./data/raw/coco/train2017"
        annotations = "./data/raw/coco/annotations/captions_train2017.json"
        raw_dataset = tv.CocoCaptions(root=image_root,
                                      annFile=annotations,
                                      transform=load_image)
        # Restrict the raw dataset to the validation split's indices.
        raw_subset = Subset(raw_dataset, data.val.indices)
        dataloader = DataLoader(raw_subset, batch_size=16, num_workers=0)

        images, caption_sets = next(iter(dataloader))
    else:
        pil_image = PIL.Image.open(image)
        images = load_image(pil_image).unsqueeze(0).cuda()
        caption_sets = []

    fig = plt.figure()
    # Square grid large enough to hold one subplot per image.
    num_rows = math.ceil(math.sqrt(images.shape[0]))
    num_cols = num_rows
    for i, single_image in enumerate(images.split(1)):
        features = feature_model(single_image)[0]
        features, mask = CocoCaptionsDataset.postprocess_features(features, 8)
        features = features.unsqueeze(0)
        mask = mask.unsqueeze(0)
        generated = model.inference(features.cuda(), mask.cuda(), 20)
        # i-th element of every caption set, i.e. this image's references.
        targets = [x[i] for x in caption_sets]
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        # presumably converts BGR 0-255 data to RGB 0-1 for display -- TODO confirm
        plt.imshow(single_image.cpu().squeeze().flip(2) / 255)

        ax.set_xlabel(generated)

        if targets:
            bleu = sentence_bleu([word_tokenize(x) for x in targets],
                                 word_tokenize(generated))
        else:
            # sentence_bleu cannot score against an empty reference list, so
            # the single-image path skips BLEU instead of crashing.
            bleu = None

        print(f"{generated=}, {targets=}, {bleu}")
    plt.show()
Пример #17
0
def createcaptiontxt():
    """Write the captions of COCO train2017 to ``cocotxt/caption.txt``.

    One line is written per image: all of its captions joined without a
    separator, with embedded newlines removed.
    """
    coco_train = dset.CocoCaptions(
        root='E:/datasets/coco2017/train2017',
        annFile='E:/datasets/coco2017/annotations/captions_train2017.json',
        transform=transforms.ToTensor())

    print('Number of samples: ', len(coco_train))
    # The context manager guarantees the file is flushed and closed, even if
    # loading a sample fails part-way through.
    with open('cocotxt/caption.txt', 'w') as f:
        for i in tqdm(range(len(coco_train))):
            _, target = coco_train[i]
            f.write(''.join(target).replace("\n", ""))
            f.write('\n')
    print('create caption txt finish!')
Пример #18
0
def load_dataset(root,
                 transform,
                 *args,
                 batch_size=32,
                 shuffle=True,
                 dataset_type='folder',
                 **kwargs):
    """
    Build a torchvision dataset and a DataLoader over it.

    Parameters
    -----------
    root: str
        dataset root directory
    transform: callable
        image transform passed to the dataset
    batch_size: int
        batch size of the returned loader
    shuffle: bool
        whether the loader shuffles the samples
    dataset_type: str
        should be 'folder', 'voc' or 'coco'
        for 'voc' you may pass year (default 2007) and image_set (default 'train')
        for 'coco' you have to pass annfile and type = 'detection' or 'caption'

    Return
    ----------
    data: Dataloader

    dataset: torchvision.dataset

    Raises
    ----------
    ValueError
        if dataset_type or the coco type is not supported
    """
    if dataset_type == 'folder':
        dataset = datasets.ImageFolder(root, transform=transform)

    elif dataset_type == 'voc':
        # torchvision expects the year as a string, so an integer such as the
        # default 2007 is converted.
        year = str(kwargs.get('year', 2007))
        image_set = kwargs.get('image_set', 'train')
        dataset = datasets.VOCDetection(root,
                                        year=year,
                                        image_set=image_set,
                                        transform=transform)

    elif dataset_type == 'coco':
        assert 'type' in kwargs and 'annfile' in kwargs
        annfile = kwargs['annfile']
        coco_type = kwargs['type']
        if coco_type == 'detection':
            dataset = datasets.CocoDetection(root,
                                             annFile=annfile,
                                             transform=transform)
        elif coco_type == 'caption':
            dataset = datasets.CocoCaptions(root,
                                            annFile=annfile,
                                            transform=transform)
        else:
            # Previously this fell through with `dataset` unbound.
            raise ValueError(
                "unsupported coco type {!r}; expected 'detection' or 'caption'"
                .format(coco_type))

    else:
        # Previously this fell through with `dataset` unbound.
        raise ValueError(
            "unsupported dataset_type {!r}; expected 'folder', 'voc' or 'coco'"
            .format(dataset_type))

    data = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return data, dataset
Пример #19
0
    def __init__(self, root, split, max_detections=50, sort_by_prob=False):
        """Load the COCO 2014 captions for ``split`` and build the samples.

        Parameters:
            root: stored on the instance; the caption data itself is read
                from a fixed location, not from ``root``.
            split: split name inserted into the image directory and
                annotation file names.
            max_detections: stored on the instance.
            sort_by_prob: stored on the instance.
        """
        self.split = split
        self.root = root
        self.max_detections = max_detections
        self.sort_by_prob = sort_by_prob

        image_root = '/netscratch/karayil/mscoco/data/' + split + '2014'
        ann_file = ("/netscratch/karayil/mscoco/data/annotations/captions_" +
                    split + "2014.json")
        self.dset_captions = dset.CocoCaptions(root=image_root,
                                               annFile=ann_file,
                                               transform=transforms.ToTensor())

        # Keep a direct handle on the dataset's `coco` attribute.
        self.coco = self.dset_captions.coco

        self.build_samples()
Пример #20
0
def extract_feats(args):
    """Run the VGG16 network over every image of the COCO train2014 captions set.

    ``args.vgg_model_dir`` is passed to ``utils.init_vgg16``. The features of
    each image are computed but, in the code shown here, not stored.
    """
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)

    cap = dset.CocoCaptions(
        root='/Pulsar1/Datasets/coco/train2014/train2014',
        annFile='/Neutron9/sahil.c/datasets/annotations/captions_train2014.json',
        transform=transforms.ToTensor())

    print('Number of samples: ', len(cap))
    for img, _captions in cap:
        # Add a leading batch dimension before preprocessing.
        batch = utils.preprocess_batch(img.unsqueeze(0))
        batch = Variable(batch, requires_grad=False)
        batch = utils.subtract_imagenet_mean_batch(batch)
        features_content = vgg(batch)
Пример #21
0
    def __init__(self,
                 data_dir,
                 ann_file,
                 embedding_type='cnn-rnn',
                 emb_model=None,
                 imsize=64,
                 transform=None,
                 target_transform=None):
        """Set up the COCO captions dataset and the word embedding.

        Parameters:
            data_dir: image root of the captions dataset.
            ann_file: caption annotation file.
            embedding_type: accepted but not used in this constructor.
            emb_model: passed to ``self.load_embedding``.
            imsize: stored as ``self.imsize``.
            transform: stored as ``self.transform``.
            target_transform: stored as ``self.target_transform``.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.imsize = imsize

        # Result of load_embedding is kept under the name `glove`.
        self.glove = self.load_embedding(emb_model)

        # The dataset is created without transforms of its own.
        self.cap = dset.CocoCaptions(root=data_dir, annFile=ann_file)
        self.idx2word = None
Пример #22
0
def main(args):
    """Evaluate a captioning checkpoint and print the running/average BLEU score.

    Reads ``root_dir``, ``anno_path``, ``im_size``, ``lr``, ``weight_decay``,
    ``lr_decay_rate``, ``embed_size``, ``ckpt_path`` and ``bleu_n`` from
    ``args``.
    """
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(random.randint(1, 10000))
    device = torch.device("cuda" if use_cuda else "cpu")
    # Loader options; not used further in this function.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    words = load_vocab()
    # Inverse vocabulary: index -> word.
    vocab = dict((idx, word) for word, idx in words.items())

    # Raw dataset supplies the reference captions; MyCoco supplies the
    # transformed images.
    coco = datasets.CocoCaptions(args.root_dir, args.anno_path)
    image_pipeline = transforms.Compose([
        transforms.Resize([args.im_size] * 2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.407, 0.457, 0.485],  # subtract imagenet mean
            std=[1, 1, 1]),
    ])
    mycoco = MyCoco(words,
                    args.root_dir,
                    args.anno_path,
                    transform=image_pipeline)

    model = Captor(args.lr, args.weight_decay, args.lr_decay_rate, len(words),
                   args.embed_size)
    model.to(device)

    model.load_checkpoint(args.ckpt_path)

    print('dataset length {}'.format(len(mycoco)))
    score = 0
    for i in range(len(mycoco)):
        im, cap_enc = mycoco[i]
        im_, caps = coco[i]
        pred = beamsearch(model, device, im, vocab, return_sentence=False)
        s = utils.bleu_score(utils.to_word_bags(caps), pred, n=args.bleu_n)
        # Incremental mean over the i + 1 images seen so far.
        score = (score * i + s) / (i + 1)
        progress = 'processing {}th image... score: {:.2f}'.format(i, score)
        print(progress, flush=True, end='\r')

    print('\navg bleu score: {}'.format(score))
Пример #23
0
def get_coco_data_raw(train=True, img_size=224, scale_size=256):
    """Return the COCO captions dataset with raw (untransformed) caption targets.

    Parameters:
        train: selects the training paths and random augmentation when true,
            otherwise the validation paths and a center crop.
        img_size: size of the crop taken after rescaling.
        scale_size: size passed to the initial rescale step.
    """
    paths = __TRAIN_PATH if train else __VAL_PATH
    root, annFile = paths['root'], paths['annFile']

    steps = [transforms.Scale(scale_size)]
    if train:
        # Random crop + flip is used as augmentation for training only.
        steps.append(transforms.RandomCrop(img_size))
        steps.append(transforms.RandomHorizontalFlip())
    else:
        steps.append(transforms.CenterCrop(img_size))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    return dset.CocoCaptions(root=root,
                             annFile=annFile,
                             transform=transforms.Compose(steps))
Пример #24
0
def load_dataset(root,
                 transform,
                 batch_size=32,
                 shuffle=True,
                 dataset_type='folder',
                 *args,
                 **kwargs):
    """
    Build a torchvision dataset and a DataLoader over it.

    param
    root: str
        dataset root directory
    transform: callable
        image transform passed to the dataset
    batch_size: int
        batch size of the returned loader
    shuffle: bool
        whether the loader shuffles the samples
    dataset_type: str
        should be 'folder', 'voc' or 'coco'
        'voc' requires the keyword arguments year and image_set
        'coco' requires annfile and type ('detect'/'detection' or 'caption')

    return
    (data, classes, class_to_idx): the DataLoader plus the dataset's
    ``classes`` and ``class_to_idx`` attributes, or ``None`` for each
    attribute the dataset does not have.

    raises
    ValueError if dataset_type or the coco type is not supported
    """
    if dataset_type == 'folder':
        dataset = datasets.ImageFolder(root, transform=transform)

    elif dataset_type == 'voc':
        # torchvision expects the year as a string.
        year = str(kwargs['year'])
        image_set = kwargs['image_set']
        dataset = datasets.VOCDetection(root,
                                        year=year,
                                        image_set=image_set,
                                        transform=transform)
    elif dataset_type == 'coco':
        annfile = kwargs['annfile']
        coco_type = kwargs['type']
        if coco_type in ('detect', 'detection'):
            dataset = datasets.CocoDetection(root,
                                             annFile=annfile,
                                             transform=transform)
        elif coco_type == 'caption':
            dataset = datasets.CocoCaptions(root,
                                            annFile=annfile,
                                            transform=transform)
        else:
            # Previously this fell through with `dataset` unbound.
            raise ValueError(
                "unsupported coco type {!r}; expected 'detect' or 'caption'"
                .format(coco_type))
    else:
        # Previously this fell through with `dataset` unbound.
        raise ValueError(
            "unsupported dataset_type {!r}; expected 'folder', 'voc' or 'coco'"
            .format(dataset_type))

    data = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # The VOC and COCO datasets do not define `classes` / `class_to_idx`, so
    # look them up defensively instead of raising AttributeError.
    classes = getattr(dataset, 'classes', None)
    class_to_idx = getattr(dataset, 'class_to_idx', None)
    return data, classes, class_to_idx
Пример #25
0
def getData(dset_name, batch_size, data_transform):
    """Build shuffled train and test loaders for a torchvision dataset.

    Parameters:
        dset_name: one of "CIFAR10", "CIFAR100", "MNIST", "SVHN", "LSUN",
            "Cityscapes" or "FakeData".
        batch_size: batch size of both loaders.
        data_transform: image transform passed to the dataset.

    Returns:
        ``(train_loader, len(trainset), test_loader, len(testset))``.

    Raises:
        ValueError: for an unknown name, and for "CocoCaptions" / "Flickr8k",
            which need an annotation file that this function has no
            parameter for.
    """
    dataPath = "../data"

    # Datasets whose constructor takes train=<bool> and download=<bool>.
    train_flag_datasets = ("CIFAR10", "CIFAR100", "MNIST")
    supported = train_flag_datasets + ("SVHN", "LSUN", "Cityscapes", "FakeData")
    needs_annotations = ("CocoCaptions", "Flickr8k")

    if dset_name in needs_annotations:
        raise ValueError(
            "%s requires an annotation file and cannot be built from a "
            "root directory alone" % dset_name)
    if dset_name not in supported:
        # Previously an unknown name failed later with UnboundLocalError.
        raise ValueError("unsupported dataset %r; expected one of %s" %
                         (dset_name, ", ".join(supported)))

    os.makedirs(dataPath, exist_ok=True)

    def build(train):
        """Construct the train or test split with the dataset's own keywords."""
        split = "train" if train else "test"
        if dset_name in train_flag_datasets:
            return getattr(dset, dset_name)(dataPath,
                                            train=train,
                                            download=True,
                                            transform=data_transform)
        if dset_name == "SVHN":
            # SVHN selects the subset with `split`, not `train`.
            return dset.SVHN(dataPath,
                             split=split,
                             download=True,
                             transform=data_transform)
        if dset_name == "LSUN":
            # LSUN selects the subset with `classes` and has no download option.
            return dset.LSUN(dataPath, classes=split, transform=data_transform)
        if dset_name == "Cityscapes":
            # Cityscapes selects the subset with `split` and has no download option.
            return dset.Cityscapes(dataPath,
                                   split=split,
                                   transform=data_transform)
        # FakeData is synthetic: it takes no root, train or download argument.
        # The offset makes the test images differ from the training ones.
        return dset.FakeData(transform=data_transform,
                             random_offset=0 if train else 1000)

    trainset = build(True)
    testset = build(False)
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True),len(trainset),\
           torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True), len(testset),
Пример #26
0
    # Merge image tensors (stack)
    images = torch.stack(images, 0)

    # Merge captions
    caption_lengths = [len(caption) for caption in captions]

    # zero-matrix num_captions x caption_max_length
    padded_captions = pad_sequence(captions, padding_value=0, batch_first=True)
    return images, padded_captions, caption_lengths#, caption_lengths


if __name__ == '__main__':
    # Manual smoke test: build the vocabulary file from the train2014
    # captions, show one image and touch every sample once.
    root = './data/train2014'
    annot = './data/annotations/captions_train2014.json'

    coco = datasets.CocoCaptions(root, annot, transform=transforms.ToTensor())

    # One vocabulary word per line.
    words = word_dict(coco)
    with open('./data/vocab.txt', 'w') as f:
        for w in words:
            f.write('{}\n'.format(w))

    vocab = load_vocab('data/vocab.txt')
    print(len(vocab))
    print(len(coco))

    # Display the sixth sample's image.
    im, enc = coco[5]
    to_pil = transforms.ToPILImage()
    im = to_pil(im)
    im.show()

    # Load every sample once to check that the whole dataset is readable.
    for i in range(len(coco)):
        _ = coco[i]
sys.path.append('/home/xuxin/work/python/data/coco-master/PythonAPI')

import torchvision.datasets as dset
import torchvision.transforms as transforms

# Locations of the images and the caption annotations for the split selected
# by `args.data`.
dataDir = '/home/xuxin/work/python/data'
imageDir = '%s/%s' % (dataDir, args.data)
capDir = '%s/Micro-coco/annotations/captions_%s.json' % (dataDir, args.data)
# NOTE: this rebinds the name `transforms` from the torchvision module to the
# composed pipeline; the module is no longer reachable under that name below.
transforms = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

cap = dset.CocoCaptions(root=imageDir, annFile=capDir, transform=transforms)
# Function-call form of print: valid on Python 3 and prints the same single
# value on Python 2, unlike the statement form `print type(cap)`.
print(type(cap))
print('Number of samples: ', len(cap))

########################################################################################################
'''
check one, give input
'''
########################################################################################################
import time

isLSTM = (args.mode == 'LSTM')
check_flag1 = False
if check_flag1:

    criterion = nn.CrossEntropyLoss()
Пример #28
0
    straight_through=
    False  # straight-through for gumbel softmax. unclear if it is better one way or the other
).cuda()

# Adam optimizer over the parameters of `vae` (defined earlier in the script).
optimizerVAE = torch.optim.Adam(vae.parameters(), lr=learning_rate)
"""
text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, TEXTSEQLEN))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
mask = torch.ones_like(text).bool()
"""

# COCO val2014 captions; each image is resized to IMAGE_SIZE x IMAGE_SIZE,
# converted to a tensor and normalized with mean 0.5 / std 0.5 per channel.
cap = dset.CocoCaptions(
    root='./coco/images',
    annFile='./coco/annotations/captions_val2014.json',
    transform=transforms.Compose([
        # Disabled alternatives kept for reference:
        #transforms.RandomCrop((IMAGE_SIZE,IMAGE_SIZE),pad_if_needed=True),
        #transforms.Grayscale(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))

# Token dataset built from the merged caption text file.
tokenDset = token_dataset('./coco/merged-1000.txt')

# presumably collects the VAE loss values during training -- TODO confirm
VAEloss = []

for epoch in range(EPOCHS):
    for i in range(DATASET_SIZE):
        #print(i,":",tokenDset.getRand(i),img.size())
        optimizerVAE.zero_grad()
        img, _ = cap[i]
        img = img.unsqueeze(0).cuda()
Пример #29
0
import torchvision.datasets as dset
import torchvision.transforms as transforms
from JPtokenize import token_dataset
import torch
from dalle_pytorch import DiscreteVAE, DALLE

# Side length every image is resized to.
IMAGE_SIZE = 256

# COCO val2014 captions with images resized and converted to tensors.
cap = dset.CocoCaptions(
    root='./coco/images',
    annFile='./coco/annotations/captions_val2014.json',
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]),
)

tokenDset = token_dataset('./coco/merged.txt')

print('Number of samples: ', len(cap))
L = []

print('Max len', tokenDset.maxLen())

# Print a token sample and the image size for the first twelve entries.
for i, (img, target) in enumerate(cap):
    token_sample = tokenDset.getRand(i)
    print(i, ":", token_sample, img.size())
    if i > 10:
        break
Пример #30
0
import torchvision.datasets as dset
import torchvision.transforms as transforms

# COCO val2014 captions; no image transform is applied.
cap = dset.CocoCaptions(root='./coco/images',
                        annFile='./coco/annotations/captions_val2014.json')

print('Number of samples: ', len(cap))

# Export buffer: one "<image index>|<caption>" line per caption.
L = []

for i, (img, target) in enumerate(cap):
    for s in target:
        line = "{}|{}".format(i, s)
        L.append(line + "\n")
        print(line)

with open("ann_export.txt", "w", encoding='utf-8') as Fp:
    Fp.writelines(L)