Code example #1
    # Build the augmentation pipeline: tensor -> PIL -> imgaug (numpy) -> PIL -> tensor
    if data_type == 'cartoon':
        t = [
            transforms.ToPILImage(),
            ImgAugTransformCartoon(), lambda x: PIL.Image.fromarray(x),
            transforms.ToTensor(), data_normalization
        ]
    elif data_type == 'snow':
        t = [
            transforms.ToPILImage(),
            ImgAugTransformSnow(), lambda x: PIL.Image.fromarray(x),
            transforms.ToTensor(), data_normalization
        ]

    test_loader = torch.utils.data.DataLoader(CaptionDataset(
        '../../output_folder',
        data_name,
        'TEST',
        transform=transforms.Compose(t)),
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True)

    if data_type == 'custom':
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.ToTensor(),
            # data_normalization  # notice: normalization happens inside Custom_Image_Dataset
        ])
        path = '/Users/gallevshalev/Desktop/datasets/custom'
        data = Custom_Image_Dataset(path, transform, data_normalization)
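
The ImgAugTransform* wrappers used above are repo classes not shown in this excerpt. A minimal sketch of how such a wrapper is commonly written around imgaug (the class name matches the excerpt; choosing iaa.Cartoon as the augmenter is an assumption about the intended effect):

import numpy as np
import imgaug.augmenters as iaa

class ImgAugTransformCartoon:
    def __init__(self):
        # imgaug's stylization augmenter; assumed to produce the intended effect
        self.aug = iaa.Cartoon()

    def __call__(self, img):
        # imgaug operates on numpy arrays; the pipeline above converts the
        # returned array back to PIL via PIL.Image.fromarray
        return self.aug(image=np.array(img))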
Code example #2
# Repeat each word's vocabulary index once per occurrence so that the
# histogram below reflects token frequencies.
for word, count in cB.items():
    word_index = word_map[word]
    v = np.zeros(count)
    v.fill(word_index)
    hisB.append(v)
hisB = np.concatenate(hisB)
plt.hist(hisB, density=True, bins=9490)
plt.show()
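
# Note: assuming cB maps token -> count (e.g. a collections.Counter), the
# same array can be built more directly with np.repeat:
#   hisB = np.repeat([word_map[w] for w in cB], list(cB.values()))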

if write:
    f.close()

# sec: debug
if debug:
    d = 0  # no-op; handy anchor for setting a breakpoint
    print(' '.join(
        list(filter(lambda x: x['image_id'] == 315,
                    generated_sentencesA))[0]['caption']))
    print(' '.join(
        list(filter(lambda x: x['image_id'] == 315,
                    generated_sentencesB))[0]['caption']))

    f = CaptionDataset(coco_data_path,
                       data_name,
                       'TEST',
                       transform=transforms.Compose([data_normalization]))
    # test_loader = torch.utils.data.DataLoader(
    #     CaptionDataset(coco_data_path, data_name, 'TEST', transform=transforms.Compose([data_normalization])),
    #     batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
    d = 0
Code example #3
    save_dir_name = '{}_{}'.format(args.beam_size, args.save_dir_name)
    model_path, save_dir = get_model_path_and_save_path(args, save_dir_name)

    # Load model
    encoder, decoder = get_models(model_path)

    # Create rev word map
    word_map, rev_word_map = get_word_map()
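    # get_models and get_word_map are repo helpers not shown in this excerpt.
    # A hedged sketch of what they are assumed to do here (the checkpoint keys
    # match those used in code example #5):
    #
    #     def get_models(model_path):
    #         checkpoint = torch.load(model_path)
    #         return checkpoint['encoder'].eval(), checkpoint['decoder'].eval()
    #
    #     def get_word_map():
    #         with open(word_map_file, 'r') as j:
    #             word_map = json.load(j)
    #         return word_map, {v: k for k, v in word_map.items()}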

    test_loader_perturbed = torch.utils.data.DataLoader(
        CaptionDataset(
            '../../../output_folder',
            data_name,
            'TEST',
            transform=transforms.Compose([
                transforms.ToPILImage(),
                ImgAugTransformJpegCompression(),
                # ImgAugTransformSaltAndPepper(),
                lambda x: PIL.Image.fromarray(x),
                transforms.ToTensor(),
                data_normalization
            ])),
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True)

    test_loader = torch.utils.data.DataLoader(CaptionDataset(
        '../../../output_folder',
        data_name,
        'TEST',
        transform=transforms.Compose([
            # (assumed completion: the unperturbed loader applies
            # normalization only, as in code examples #4, #6 and #7)
            data_normalization
        ])),
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)
Code example #4
    encoder, decoder = get_models(model_path, device)
    word_map, rev_word_map = get_word_map(
        args.run_local, '../../output_folder/WORDMAP_' + data_name + '.json')

    print('create pos dic for {} data'.format(args.data))

    if args.data == 'test':
        print('using cuda: {}'.format(device))
        print('args.data = {}'.format(args.data))

        p = '/yoav_stg/gshalev/image_captioning/output_folder'
        coco_data = '../../output_folder' if args.run_local else p

        test_loader = torch.utils.data.DataLoader(CaptionDataset(
            coco_data,
            data_name,
            'TEST',
            transform=transforms.Compose([data_normalization])),
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=1,
                                                  pin_memory=True)
        print('len test_loader: {}'.format(len(test_loader)))

        gt_metric_dic = {'annotations': list()}
        hp_metric_dic = {'annotations': list()}
        for i, (image, caps, caplens, allcaps) in tqdm(enumerate(test_loader)):

            for ci in range(allcaps.shape[1]):
                # strip the <start> token and everything from <end> onward
                gt = [rev_word_map[ind.item()]
                      for ind in allcaps[0][ci]][1:caplens[0][ci].item() - 1]
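
For reference, the loader tuples unpacked across these excerpts imply a split-dependent return from CaptionDataset; this contract is inferred from usage, not from the class definition itself:

# Inferred CaptionDataset.__getitem__ contract (hedged; see examples #4, #6, #7):
#   TRAIN      -> (img, caption, caplen)
#   VAL / TEST -> (img, caption, caplen, allcaps)  # allcaps: every reference caption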
Code example #5
def main():

    # section: settings
    global best_bleu4, epochs_since_improvement, start_epoch, data_name, word_map

    # section: fine tune
    if args.fine_tune_encoder and args.fine_tune_epochs == -1:
        raise Exception(
            'if "fine_tune_encoder" is true you must also set "fine_tune_epochs" != -1'
        )

    # section: word map
    if not args.run_local:
        data_f = '/yoav_stg/gshalev/image_captioning/output_folder'
    else:
        data_f = data_folder

    word_map_file = os.path.join(data_f, 'WORDMAP_' + data_name + '.json')
    print('word_map_file: {}'.format(word_map_file))

    print('loading word map from path: {}'.format(word_map_file))
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)
    print('load word map COMPLETED')

    rev_word_map = {v: k for k, v in word_map.items()}

    # section: representation
    representations = get_embeddings(decoder_dim, len(word_map)).to(device)

    # section: cosine
    if not args.fixed:
        representations.requires_grad = True

    # section: Initialization
    if args.checkpoint is None:
        print('run a new model (No args.checkpoint)')
        decoder = DecoderWithoutAttention(attention_dim=attention_dim,
                                          embed_dim=emb_dim,
                                          decoder_dim=decoder_dim,
                                          vocab_size=len(word_map),
                                          device=device,
                                          dropout=dropout)

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        if not args.fixed:
            decoder_optimizer.add_param_group({'params': representations})

        encoder = Encoder()
        # notice: fine-tune the encoder from the start only when fine_tune_epochs == 0
        encoder.fine_tune(args.fine_tune_encoder and args.fine_tune_epochs == 0)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr
        ) if args.fine_tune_encoder and args.fine_tune_epochs == 0 else None

    # section: load checkpoint
    else:
        print('run a model loaded from args.checkpoint')
        checkpoint = torch.load(args.checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = 0
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']

        if args.fine_tune_encoder and encoder_optimizer is None:
            print('----------loading model without encoder optimizer')
            encoder.fine_tune(args.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)
        elif args.fine_tune_encoder and encoder_optimizer is not None:
            raise Exception('you are loading a model with encoder optimizer')

    # section: Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # section: wandb
    if not args.run_local:
        wandb.watch(decoder)

    # section: Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # section: dataloaders
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_f,
        data_name,
        'TRAIN',
        transform=transforms.Compose([data_normalization])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_f,
        data_name,
        'VAL',
        transform=transforms.Compose([data_normalization])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    val_loader_for_val = torch.utils.data.DataLoader(CaptionDataset(
        data_f,
        data_name,
        'VAL',
        transform=transforms.Compose([data_normalization])),
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=workers,
                                                     pin_memory=True)

    # section: Epochs
    print('starting epochs')
    for epoch in range(start_epoch, epochs):

        # section: terminate training after 20 epochs without improvement
        if epochs_since_improvement == 20:
            print('break after : epochs_since_improvement == 20')
            break

        # section: fine tune encoder
        if epoch == args.fine_tune_epochs:
            print('fine tuning after epoch({}) == args.fine_tune_epochs({})'.
                  format(epoch, args.fine_tune_epochs))
            encoder.fine_tune(args.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

        # section: adjust LR after 8 epochs without improvement
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            print('!!!  ADJUST LR AFTER : epochs_since_improvement: {}'.format(
                epochs_since_improvement))
            adjust_learning_rate(decoder_optimizer, 0.8)

            # if args.checkpoint is not None:
            #     adjust_learning_rate(encoder_optimizer, 0.8)
            # elif args.fine_tune_encoder and epoch > args.fine_tune_epochs:
            #     print('------------------------------------epoch: {} fine tune lr encoder'.format(epoch))
            #     adjust_learning_rate(encoder_optimizer, 0.8)

        # section: train
        print('----- (1) Start train ----- epoch {}'.format(epoch))
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              representations=representations)

        # section: eval
        print('----- (2) Start validation ----- epoch {}'.format(epoch))
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                rev_word_map=rev_word_map,
                                representations=representations)

        print('recent bleu: {}'.format(recent_bleu4))
        print('----- (3) Start val without teacher forcing ----- epoch {}'.format(
            epoch))
        caption_image_beam_search(encoder, decoder, val_loader_for_val,
                                  word_map, rev_word_map, representations)
        print('done with train, val, and val without teacher forcing for epoch {}'.
              format(epoch))

        # section: save model if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best, representations, args.runname)
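
adjust_learning_rate and save_checkpoint come from the repo's training utils and are not shown in this excerpt. A minimal sketch of what the LR helper is assumed to do, namely scale every parameter group by a shrink factor:

def adjust_learning_rate(optimizer, shrink_factor):
    # Multiply each parameter group's learning rate by shrink_factor
    # (assumed behaviour of the helper called above).
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor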
Code example #6
    print(len(word_map))
    print('create pos dic for {} data'.format(args.data))
    metrics_data_type_to_save = ['test', 'perturbed_jpeg', 'perturbed_salt']

    # section: build dic
    if args.data == 'test':
        print('using cuda: {}'.format(device))
        print('args.data = {}'.format(args.data))

        data_path = '/yoav_stg/gshalev/image_captioning/output_folder'
        coco_data = '../../output_folder' if args.run_local else data_path

        test_loader = torch.utils.data.DataLoader(CaptionDataset(
            coco_data,
            data_name,
            'TEST',
            transform=transforms.Compose([data_normalization])),
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=1,
                                                  pin_memory=True)
        print('len test_loader: {}'.format(len(test_loader)))

        gt_metric_dic = {'annotations': list()}
        hp_metric_dic = {'annotations': list()}
        for i, (image, caps, caplens, allcaps) in tqdm(enumerate(test_loader)):
            for ci in range(allcaps.shape[1]):
                gt = [rev_word_map[ind.item()]
                      for ind in allcaps[0][ci]][1:caplens[0][ci].item() - 1]
                gt_metric_dic['annotations'].append({
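                    # (assumed continuation: fields follow the COCO-style
                    # caption records that code example #2 filters on)
                    'image_id': i,
                    'caption': ' '.join(gt)
                })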
Code example #7
from dataset_loader.datasets import CaptionDataset
from standart_training.utils import *
from utils import *

data_folder = '/yoav_stg/gshalev/image_captioning/output_folder'  # folder with data files saved by create_input_files.py
# data_folder = 'output_folder'  # folder with data files saved by create_input_files.py
data_name = 'coco_5_cap_per_img_5_min_word_freq'
data_normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(CaptionDataset(
    data_folder,
    data_name,
    'TRAIN',
    transform=transforms.Compose([data_normalization])),
                                           batch_size=1,
                                           shuffle=True,
                                           num_workers=1,
                                           pin_memory=True)

word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')

with open(word_map_file, 'r') as j:
    word_map = json.load(j)

rev_word_map = {v: k for k, v in word_map.items()}

word_dic = dict()
for i, (imgs, caps, caplens) in tqdm(enumerate(train_loader)):
    if i % 1000 == 0:
        print('{}/{}'.format(i, len(train_loader)))
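    # (assumed continuation: accumulate per-token counts into word_dic,
    # skipping the <start>/<end> positions as in the other examples)
    for ind in caps[0][1:caplens[0].item() - 1]:
        token = rev_word_map[ind.item()]
        word_dic[token] = word_dic.get(token, 0) + 1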