Пример #1
0
def get_embs(opt,
             model,
             run,
             version,
             data_path,
             plot_folder,
             split='test',
             fold5=False,
             vocab_path="../vocab/"):
    """Encode one dataset split with *model* and persist the embeddings.

    Loads the vocabulary the model was trained with, builds the test
    loader, encodes every image/caption pair and saves the tensors to
    ``{plot_folder}/embs/embs_{run}_{version}.pth.tar``.

    Args:
        opt: training options; ``opt.vocab_size`` is updated in place.
        model: trained model passed straight to ``encode_data``.
        run: run identifier used in the output filename.
        version: vocabulary/version tag used in filenames.
        data_path: accepted for API compatibility; not read here.
        plot_folder: root folder that receives the ``embs`` subfolder.
        split: dataset split to encode (default ``'test'``).
        fold5: accepted for API compatibility; not read here.
        vocab_path: folder holding the serialized vocabulary json.

    Returns:
        Tuple ``(img_embs, cap_embs, cap_lens, freqs)`` from ``encode_data``.
    """
    # load vocabulary used by the model
    vocab = deserialize_vocab("{}/{}_vocab_{}.json".format(
        vocab_path, opt.data_name, version))
    opt.vocab_size = len(vocab)

    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    img_embs, cap_embs, cap_lens, freqs = encode_data(model, data_loader)

    # exist_ok avoids the check-then-create race of the former
    # os.path.exists() + os.makedirs() pair
    os.makedirs('{}/embs'.format(plot_folder), exist_ok=True)

    torch.save(
        {
            'img_embs': img_embs,
            'cap_embs': cap_embs,
            "cap_lens": cap_lens,
            "freqs": freqs
        }, '{}/embs/embs_{}_{}.pth.tar'.format(plot_folder, run, version))
    print("Saved embeddings")
    return img_embs, cap_embs, cap_lens, freqs
Пример #2
0
def main(args):
    """Compute and plot per-word attention for a trained SCAN model.

    Restores the best checkpoint for ``args.run``, rebuilds its
    vocabulary, then for every word in ``args.list_words`` computes the
    average attention via ``calculate_attn`` and writes one plot per
    word into ``{args.out_folder}/{args.run}``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_path = "{}/{}/seed1/checkpoint/model_best.pth.tar".format(
        args.run_folder, args.run)
    out_path = "{}/{}".format(args.out_folder, args.run)

    if not os.path.isdir(out_path):
        os.makedirs(out_path)

    print("LOADING MODEL")
    # load trained SCAN model
    model, opt = load_model(model_path, device)
    model.val_start()

    print("RETRIEVE VOCAB")
    # load vocabulary used by the model
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)

    print("FILTER DATASETS")
    # NOTE: removed an unused `word_attn = {}` local and a dead
    # commented-out try/except that silently skipped missing words.
    for word in args.list_words:
        dpath = os.path.join(opt.data_path, opt.data_name, opt.clothing)
        average_attn = calculate_attn(dpath, vocab, opt, word, model)
        plot_one(out_path, average_attn, word)
Пример #3
0
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: checkpoint file containing 'opt' and 'model' entries.
        data_path: optional override for the dataset root stored in the
            checkpoint options.
        split: dataset split to evaluate ('dev' or 'test').
        fold5: NOTE(review): despite the docstring, this function never
            reads `fold5`; the full split is always evaluated.

    Side effects: saves similarities to '<data_name>_relation.mat' and
    rank tensors to 'ranks.pth.tar' in the working directory.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path
    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = Local_Alignment(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # each image is repeated once per caption (5x), so keep every 5th row
    # to obtain the unique image embeddings
    img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
    print('Images: ', img_embs.shape)
    print('Captions: ', cap_embs.shape)

    start = time.time()
    sims = compute_sims(img_embs, cap_embs, cap_lens, opt, shard_size=128)
    print(sims[:20, :4])
    end = time.time()
    print("calculate similarity time:", end - start)

    print('Saving results...')
    sio.savemat('%s_relation.mat' % opt.data_name, {'similarity': sims})
    print('Saving success...')

    # r/ri are (R@1, R@5, R@10, median rank, mean rank) tuples
    r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    ari = (ri[0] + ri[1] + ri[2]) / 3
    rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
    print("rsum: %.1f" % rsum)
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
    print("Average t2i Recall: %.1f" % ari)
    print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def main():
    """Evaluate a trained SCAN checkpoint on the chosen split.

    Parses CLI arguments, restores the checkpoint together with its
    training options, rebuilds the vocabulary and test loader, and
    delegates the ranking evaluation to ``evaluation.evalrank``.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--model_path',
                        default='./data/',
                        help='path to model')
    parser.add_argument('--split', default='test', help='val/test')
    # BUG FIX: default was the float 0. while type=str. argparse does not
    # apply `type` to defaults, so CUDA_VISIBLE_DEVICES was set to "0.0",
    # which hides every GPU. The default must be the string '0'.
    parser.add_argument('--gpuid', default='0', type=str, help='gpuid')
    parser.add_argument('--fold5', action='store_true', help='fold5')
    opts = parser.parse_args()

    device_id = opts.gpuid
    print("use GPU:", device_id)
    # restrict visible devices, then address the remaining one as cuda:0
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    device_id = 0
    torch.cuda.set_device(0)
    # load model and options
    checkpoint = torch.load(opts.model_path)
    opt = checkpoint['opt']
    opt.loss_verbose = False
    opt.split = opts.split
    opt.data_path = opts.data_path
    opt.fold5 = opts.fold5

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = data.get_test_loader(opt.split, opt.data_name, vocab,
                                       opt.batch_size, opt.workers, opt)

    print(opt)
    print('Computing results...')

    # unwrap DataParallel: evalrank expects the bare module
    evaluation.evalrank(model.module,
                        data_loader,
                        opt,
                        split=opt.split,
                        fold5=opt.fold5)
Пример #5
0
def main(args):
    """Build a word-vs-image cosine-similarity table for a trained model.

    For every word in ``args.list_words``: collect the images whose
    captions contain that word, average their (attention-weighted)
    features, then score those images against the averaged embedding of
    every probe word. Results are written out as a table, a figure and
    raw attention dumps under ``{args.out_folder}/{args.run}``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_path = "{}/{}/seed1/checkpoint/model_best.pth.tar".format(
        args.run_folder, args.run)
    out_path = "{}/{}".format(args.out_folder, args.run)

    if not os.path.isdir(out_path):
        os.makedirs(out_path)

    print("LOADING MODEL")
    # load trained SCAN model
    model, opt = load_model(model_path, device)
    model.val_start()

    print("RETRIEVE VOCAB")
    # load vocabulary used by the model
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)

    # average attention weights per probe word: {word: attn}
    word_attn = attn_per_word(args.list_words, opt, vocab, model)

    # nested mapping {row_word: {col_word: mean cosine score}}
    word_cos = {}
    for word_row in args.list_words:
        dpath = os.path.join(opt.data_path, opt.data_name, opt.clothing)

        # loaders restricted to items whose captions contain word_row
        loader_test, pos_test = retrieve_loader("test", opt, dpath, word_row,
                                                vocab)
        loader_train, pos_train = retrieve_loader("train", opt, dpath,
                                                  word_row, vocab)

        average_attn = word_attn[word_row]
        img_features = avg_features_img(average_attn, model, loader_train,
                                        loader_test)
        n_image = img_features.shape[0]

        temp_cos = {}
        for word_col in word_attn.keys():
            word_feature = avg_features_word(word_col, model, vocab)
            # broadcast the single word embedding across all n_image rows
            word_features = word_feature.expand(n_image, -1)
            cosine_scores = cosine_similarity(word_features, img_features)
            temp_cos[word_col] = torch.mean(cosine_scores).item()

        word_cos[word_row] = temp_cos

    print("PLOT ATTENTION")
    write_out(out_path, word_attn, "attention")
    write_table(out_path, word_cos)
    write_fig(out_path, word_cos, args.run)
Пример #6
0
def main():
    """Train an SGRAF model and checkpoint the best validation epoch.

    Parses options, loads the vocabulary and train/val loaders, then for
    each epoch: adjusts the learning rate, trains, validates, and saves
    a checkpoint (flagging the best R@sum seen so far).
    """
    opt = opts.parse_opt()
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SGRAF(opt)

    # Train the Model
    best_rsum = 0

    for epoch in range(opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        r_sum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = r_sum > best_rsum
        best_rsum = max(r_sum, best_rsum)

        # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair
        # and, unlike os.mkdir, also creates missing parent directories
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Пример #7
0
def start_experiment(opt, seed):
    """Run one seeded training experiment end to end.

    Seeds all RNGs, builds the vocabulary, loaders and SCAN model,
    optionally resumes from ``opt.resume``, then trains for
    ``opt.num_epochs`` epochs, checkpointing only the best and the final
    epoch.

    Args:
        opt: experiment options (paths, hyper-parameters, resume flag).
        seed: integer seed applied to torch, numpy and random.

    Returns:
        The best validation R@sum observed across all epochs.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    print("Number threads:", torch.get_num_threads())

    # Load Vocabulary Wrapper, create dictionary that can switch between ids and words
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))

    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data_ken.get_loaders(opt.data_name, vocab,
                                                    opt.batch_size,
                                                    opt.workers, opt)

    # Construct the model
    model = SCAN(opt)

    # save hyperparameters in file
    save_hyperparameters(opt.logger_name, opt)

    best_rsum = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # resume from the epoch *after* the one stored in the checkpoint
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)
        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)

        last_epoch = False
        if epoch == (opt.num_epochs - 1):
            last_epoch = True

        # only save when best epoch, or last epoch for further training
        if is_best or last_epoch:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                last_epoch,
                filename='checkpoint_{}.pth.tar'.format(epoch),
                prefix=opt.model_name + '/')
    return best_rsum
Пример #8
0
def evalstack(model_path, data_path=None, split='dev', fold5=False, is_sparse=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: checkpoint file containing 'opt' and 'model' entries.
        data_path: optional override for the dataset root; also hard-codes
            a machine-specific vocab path when set (NOTE(review): the
            absolute "/media/ubuntu/..." path should be parameterized).
        split: dataset split to evaluate.
        fold5: compute similarities per 5000-caption MSCOCO fold.
        is_sparse: forwarded into the restored options as opt.is_sparse.

    Returns:
        A single similarity matrix when `fold5` is False, otherwise a
        list of five per-fold similarity matrices.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.is_sparse = is_sparse
    print(opt)
    if data_path is not None:
        opt.data_path = data_path
        opt.vocab_path = "/media/ubuntu/data/chunxiao/vocab"

    # load vocabulary used by the model
    vocab = deserialize_vocab(os.path.join(
        opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = GSMN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, bbox, depends, cap_lens = encode_data(
        model, data_loader)
    # each image appears 5 times (once per caption), hence the /5
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation
        # keep every 5th row to deduplicate the repeated image embeddings
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()
        sims = shard_xattn(model, img_embs, cap_embs, bbox,
                           depends, cap_lens, opt, shard_size=80)
        end = time.time()
        print("calculate similarity time:", end - start)

        return sims

    else:
        # 5fold cross-validation, only for MSCOCO
        sims_a = []
        for i in range(5):
            # caption-side arrays take the full 5000-slice; image-side
            # arrays additionally stride by 5 to deduplicate images
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            bbox_shard = bbox[i * 5000:(i + 1) * 5000:5]
            depend_shard = depends[i * 5000:(i + 1) * 5000]
            start = time.time()
            sims = shard_xattn(model, img_embs_shard, cap_embs_shard,
                               bbox_shard, depend_shard, cap_lens_shard, opt, shard_size=80)
            end = time.time()
            print("calculate similarity time:", end - start)

            sims_a.append(sims)

        return sims_a
Пример #9
0
def main(args):
    """Probe a trained SCAN model with contrastive attribute-word pairs.

    For each predefined word pair, encodes the training items that
    mention each word, runs ``perform_exp`` five times, and aggregates
    best/worst F1 statistics plus the most frequent best/worst segment
    indices. A summary is printed at the end via ``print_scores``.
    """
    random.seed(17)

    min_l = args.min_l
    test_percentage = 0.1

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_path = "{}/{}/seed1/checkpoint/model_best.pth.tar".format(
        args.run_folder, args.run)

    print("LOADING MODEL")
    # load trained SCAN model
    model, opt = load_model(model_path, device)
    model.val_start()

    print("RETRIEVE VOCAB")
    # load vocabulary used by the model
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)

    args.list_words = [("black", "white"), ("black", "blue"),
                       ("multicolor", "floral"), ("lace", "jersey"),
                       ("silk", "crepe"), ("maxi", "midi"),
                       ("sheath", "shift"), ("sleeve", "sleeveless"),
                       ("long", "knee-length"), ("embroidered", "beaded")]
    scores = []
    for first_word, second_word in args.list_words:
        dpath = os.path.join(opt.data_path, opt.data_name, opt.clothing)

        # loaders restricted to training items mentioning each word
        loader_a, _ = retrieve_loader("train", opt, dpath, first_word, vocab)
        loader_b, _ = retrieve_loader("train", opt, dpath, second_word, vocab)

        feats_a = create_embs(loader_a, model)
        feats_b = create_embs(loader_b, model)

        f1_best, f1_worst = [], []
        best_segs, worst_segs = [], []

        # five repetitions with different random test splits
        for _ in range(5):
            result = perform_exp(model, feats_a, feats_b, test_percentage,
                                 min_l)
            f1_best.append(result[0])
            best_segs.append(result[1])
            f1_worst.append(result[2])
            worst_segs.append(result[3])
            length = result[4]

        # most frequently selected best/worst segment across repetitions
        top_seg = np.argmax(np.bincount(best_segs))
        bottom_seg = np.argmax(np.bincount(worst_segs))

        scores.append((np.mean(f1_best), np.std(f1_best), top_seg, best_segs,
                       np.mean(f1_worst), np.std(f1_worst), bottom_seg,
                       worst_segs, length))

    print(">>>>>>>RUN: {}".format(args.run))
    print_scores(scores, args.list_words)
Пример #10
0
def main():
    """Train a cross-modal retrieval model with a mean-teacher (EMA) twin.

    Parses hyper-parameters, builds the train/val/adapt loaders,
    constructs the student model and its EMA copy, optionally resumes
    from a checkpoint, then runs the train/validate/checkpoint loop.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name', default='precomp',
                        help='{coco,f30k}_precomp')

    parser.add_argument('--val_data', default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--val_split', default='val',
                        help='split to validate results during training')
    parser.add_argument('--val_batch_size', type=int, default=128,
                        help='batch size for validation')

    parser.add_argument('--adapt_data', default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--adapt_split', default='train',
                        help='split to perform adaptation to')
    parser.add_argument('--adapt_batch_size', type=int, default=128,
                        help='batch size for adaptation')

    parser.add_argument('--vocab_path', default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=30, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=15, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=10, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=10, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--raw_feature_norm', default="clipped_l2norm",
                        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func', default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i",
                        help='t2i|i2t')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru', action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse', default=6., type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax', default=9., type=float,
                        help='Attention softmax temperature.')

    ### Jonatas additional hyper ###
    parser.add_argument('--text_encoder', default='GRU',
                        help='[GRU|Conv].')
    parser.add_argument('--test_measure', default=None,
                        help='Similarity used for retrieval (None<same used for training>|cosine|order)')
    parser.add_argument('--add_data', action='store_true',
                        help='Whether to use additional unlabeled data.')
    parser.add_argument('--log_images', action='store_true',
                        help='Whether to use log images in tensorboard.')
    parser.add_argument('--noise', type=float, default=0.,
                        help='Amount of noise for augmenting embeddings.')
    parser.add_argument('--kwargs', type=str, nargs='+', default=None,
                        help='Additional args for the model. Usage: argument:type:value ')

    ### Mean-teacher hyperparameters ###
    parser.add_argument('--ramp_lr', action='store_true',
                        help='Use the learning rate schedule from mean-teacher')
    parser.add_argument('--initial_lr', type=float, default=0.0006,
                        help='Initial learning_rate for rampup')
    parser.add_argument('--initial_lr_rampup', type=int, default=50,
                        help='Epoch for lr rampup')
    parser.add_argument('--consistency_weight', type=float, default=20.,
                        help='consistency weight (default: 20.).')
    parser.add_argument('--consistency_alpha', type=float, default=0.99,
                        help='Consistency alpha before ema_late_epoch')
    parser.add_argument('--consistency_alpha_late', type=float, default=0.999,
                        help='Consistency alpha after ema_late_epoch')
    parser.add_argument('--consistency_rampup', type=int, default=100,
                        help='Consistency rampup epoch')
    parser.add_argument('--ema_late_epoch', type=int, default=50,
                        help='When to change alpha variable for consistency weight')

    opt = parser.parse_args()

    print('\n\n')
    print(opt)

    # fall back to an auto-generated log directory when none was given
    if opt.logger_name == '':
        writer = SummaryWriter()
        opt.logger_name = writer.file_writer.get_logdir()
    else:
        writer = SummaryWriter(opt.logger_name)

    print('')
    print('')
    print('Outpath: ', opt.logger_name)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader = data.get_loader(
        split='train',
        data_name=opt.data_name,
        batch_size=opt.batch_size,
        vocab=vocab,
        workers=opt.workers,
        opt=opt,
        adapt_set=False,
    )
    print('[OK] Train loader')

    val_loader = data.get_loader(
        data_name=opt.val_data,
        split=opt.val_split,
        batch_size=opt.val_batch_size,
        vocab=vocab,
        workers=opt.workers,
        opt=opt,
        adapt_set=False,
    )
    print('[OK] Val loader')

    adapt_loader = data.get_loader(
        split=opt.adapt_split,
        data_name=opt.adapt_data,
        batch_size=opt.adapt_batch_size,
        vocab=vocab,
        workers=opt.workers,
        opt=opt,
        adapt_set=True,
    )
    print('[OK] Adapt loader')

    print('Train loader/dataset')
    print(train_loader.dataset.data_path, train_loader.dataset.split)
    print('Valid loader/dataset')
    print(val_loader.dataset.data_path, val_loader.dataset.split)
    print('Adapt loader/dataset')
    print(adapt_loader.dataset.data_path, adapt_loader.dataset.split)

    print('[OK] Loaders.')

    # Construct the student model and its EMA (mean-teacher) copy
    model = create_model(opt)
    model_ema = create_model(opt, ema=True)
    print(model.txt_enc)

    # BUG FIX: best_rsum/start_epoch are initialised *before* the resume
    # block, so a resumed best score is no longer clobbered by the old
    # `best_rsum = 0` that followed the resume, and training continues
    # from the checkpoint epoch instead of restarting at epoch 0.
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model, writer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        # train for one epoch
        train(opt, train_loader, adapt_loader, model, model_ema, epoch, val_loader, tb_writer=writer)

        # evaluate on validation set
        print('Validate EMA')
        rsum = validate(opt, val_loader, model, writer)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'model_ema': model_ema.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/')
Пример #11
0
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: checkpoint file containing 'opt' and 'model' entries.
        data_path: optional override for the dataset root stored in the
            checkpoint options.
        split: dataset split to evaluate ('dev' or 'test').
        fold5: evaluate in five 5000-caption MSCOCO folds and report the
            per-fold and mean metrics.

    Side effects: prints recall metrics and saves rank tensors to
    'ranks.pth.tar' in the working directory.
    """
    # load model and options
    # no_grad: pure inference, no autograd bookkeeping needed
    with torch.no_grad():
        checkpoint = torch.load(model_path)
        opt = checkpoint['opt']
        print(opt)
        if data_path is not None:
            opt.data_path = data_path
        #=========================================
        # older checkpoints predate the 'pos' option; default it off
        if 'pos' not in opt:
            opt.pos = False
        #=========================================

        # load vocabulary used by the model
        vocab = deserialize_vocab(
            os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
        opt.vocab_size = len(vocab)

        # construct model
        model = SCAN(opt)

        # load model state
        model.load_state_dict(checkpoint['model'])

        print('Loading dataset')
        data_loader = get_test_loader(split, opt.data_name, vocab,
                                      opt.batch_size, opt.workers, opt)

        print('Computing results...')
        img_embs, cap_embs, cap_lens, cap_inds, img_inds, tag_masks = encode_data(
            model, data_loader)
        # each image appears 5 times (once per caption), hence the /5
        print('Images: %d, Captions: %d' %
              (img_embs.shape[0] / 5, cap_embs.shape[0]))

        if not fold5:
            # no cross-validation, full evaluation
            # keep every 5th row to deduplicate the repeated images
            img_embs = np.array(
                [img_embs[i] for i in range(0, len(img_embs), 5)])
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs,
                                       cap_embs,
                                       cap_lens,
                                       opt,
                                       tag_masks,
                                       shard_size=128)
                # np.savez_compressed('test_sim_mat_f30k_t2i_AVG_glove.npz',sim=sims)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs,
                                       cap_embs,
                                       cap_lens,
                                       opt,
                                       tag_masks,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            # r/ri are (R@1, R@5, R@10, median rank, mean rank) tuples
            r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
            ri, rti = t2i(img_embs,
                          cap_embs,
                          cap_lens,
                          sims,
                          return_ranks=True)
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f" % rsum)
            print("Average i2t Recall: %.1f" % ar)
            print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
            print("Average t2i Recall: %.1f" % ari)
            print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
        else:
            # 5fold cross-validation, only for MSCOCO
            results = []
            for i in range(5):
                # caption arrays take the full 5000-slice; the image
                # array additionally strides by 5 to deduplicate images
                img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
                cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
                cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
                start = time.time()
                if opt.cross_attn == 't2i':
                    sims = shard_xattn_t2i(img_embs_shard,
                                           cap_embs_shard,
                                           cap_lens_shard,
                                           opt,
                                           tag_masks,
                                           shard_size=128)
                elif opt.cross_attn == 'i2t':
                    sims = shard_xattn_i2t(img_embs_shard,
                                           cap_embs_shard,
                                           cap_lens_shard,
                                           opt,
                                           tag_masks,
                                           shard_size=128)
                else:
                    raise NotImplementedError
                end = time.time()
                print("calculate similarity time:", end - start)

                r, rt0 = i2t(img_embs_shard,
                             cap_embs_shard,
                             cap_lens_shard,
                             sims,
                             return_ranks=True)
                print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
                ri, rti0 = t2i(img_embs_shard,
                               cap_embs_shard,
                               cap_lens_shard,
                               sims,
                               return_ranks=True)
                print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

                # keep the first fold's rank tensors for the final save
                if i == 0:
                    rt, rti = rt0, rti0
                ar = (r[0] + r[1] + r[2]) / 3
                ari = (ri[0] + ri[1] + ri[2]) / 3
                rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
                print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
                results += [list(r) + list(ri) + [ar, ari, rsum]]

            print("-----------------------------------")
            print("Mean metrics: ")
            mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
            # layout: [0:5]=i2t, [5:10]=t2i, [10]=ar, [11]=ari, [12]=rsum
            # NOTE(review): index 10 is the mean *ar*, not rsum/6 — the
            # "rsum" line below prints mean_metrics[10]*6, i.e. 6*ar, and
            # the two "Average ... Recall" lines use [11]/[12]; verify
            # these indices against the intended layout.
            print("rsum: %.1f" % (mean_metrics[10] * 6))
            print("Average i2t Recall: %.1f" % mean_metrics[11])
            print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
            print("Average t2i Recall: %.1f" % mean_metrics[12])
            print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
                  mean_metrics[5:10])

        torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def main():
    """Train a SCAN model on precomputed features.

    Parses options, builds the vocabulary / data loaders / model, optionally
    resumes from a checkpoint, then runs the training loop, validating and
    checkpointing once per epoch (the best-rsum model is kept).
    """
    # Hyper Parameters
    opt = opts.parse_opt()

    device_id = opt.gpuid
    device_count = len(str(device_id).split(","))
    print("use GPU:", device_id, "GPUs_count", device_count, flush=True)
    # Restrict visible devices to the requested ones; after this, CUDA
    # device indices are renumbered starting at 0.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model (wrapped in DataParallel so multi-GPU ids work)
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # Loss and Optimizer
    criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation)
    # FIX: the original built nn.MSELoss(reduction="batchmean"), which raises
    # ValueError at startup ('batchmean' is only valid for KLDivLoss), and the
    # resulting criterion was never used — so it has been removed.
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # optionally resume from a checkpoint
    if not os.path.exists(opt.model_name):
        os.makedirs(opt.model_name)
    start_epoch = 0
    best_rsum = 0

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    # Sanity-check retrieval performance before (re)starting training.
    evalrank(model.module, val_loader, opt)

    print(opt, flush=True)

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        log_file = os.path.join(opt.logger_name, "performance.log")
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)
        adjust_learning_rate(opt, optimizer, epoch)
        run_time = 0
        for i, (images, captions, lengths, masks, ids, _) in enumerate(train_loader):
            start_time = time.time()
            model.train()

            optimizer.zero_grad()

            # DataParallel scatters inputs along dim 0; replicate the image
            # batch so every replica sees the full set of images.
            if device_count != 1:
                images = images.repeat(device_count, 1, 1)

            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)

            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            run_time += time.time() - start_time
            if i % 100 == 0:
                # FIX: average over the batches accumulated since the last
                # print — only one batch has run when i == 0, 100 afterwards
                # (the original always divided by 100).
                avg_time = run_time / (100 if i else 1)
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.item(), avg_time)
                print(log, flush=True)
                run_time = 0
            # validate at every val_step
            if (i + 1) % opt.val_step == 0:
                evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
        }, is_best, filename=filename, prefix=opt.model_name + '/')
Example #13
0
def main():
    """Train a GSMN model.

    Parses hyper-parameters, builds the vocabulary / data loaders / model,
    optionally resumes from a checkpoint, then trains, validating and
    checkpointing each epoch (best model selected by validation rsum).
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name', default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path', default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=30, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=15, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=10, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=100, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=15000, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the models.')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--bi_gru', action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_softmax', default=9., type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--feat_dim', default=16, type=int,
                        help='Dimensionality of the similarity embedding.')
    parser.add_argument('--num_block', default=16, type=int,
                        help='Number of similarity blocks.')
    parser.add_argument('--hid_dim', default=32, type=int,
                        help='Dimensionality of the hidden state during graph convolution.')
    parser.add_argument('--out_dim', default=1, type=int,
                        help='Dimensionality of the output of graph convolution.')
    parser.add_argument('--is_sparse', action='store_true',
                        help='Whether models the text as a fully connected graph.')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(
        opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the models
    model = GSMN(opt)

    # optionally resume from a checkpoint.
    # FIX: start_epoch/best_rsum loaded from the checkpoint were previously
    # discarded (the loop always restarted at epoch 0 with best_rsum = 0);
    # a resumed run now continues where the checkpoint left off.
    start_epoch = 0
    best_rsum = 0
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['models'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # FIX: makedirs handles the nested default ./runs/runX/checkpoint
        # (os.mkdir fails when the parent directory does not exist).
        if not os.path.exists(opt.model_name):
            os.makedirs(opt.model_name)
        save_checkpoint({
            'epoch': epoch + 1,
            'models': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/')
Example #14
0
def main():
    """Train a SCAN (or VSE++) model on precomputed features.

    Parses hyper-parameters, builds the vocabulary / data loaders / model,
    optionally resumes from a checkpoint, then trains, validating and
    checkpointing each epoch (best model selected by validation rsum).
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='/data2/yuanen/data/data_no_feature/',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=30,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=1347,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/f30k_scan/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/f30k_scan/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm',
        default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func',
                        default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru',
                        action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse',
                        default=6.,
                        type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--model_type',
                        default='SCAN',
                        type=str,
                        help='SCAN | vse++.')
    parser.add_argument('--pos',
                        default=False,
                        type=bool,
                        help='whether only consider viusal words.')

    #============================================================================
    # Options inherited from vse++
    parser.add_argument('--measure',
                        default='cosine',
                        help='Similarity measure used (cosine|order)')
    parser.add_argument('--use_abs',
                        action='store_true',
                        help='Take the absolute value of embedding vectors.')
    parser.add_argument('--reset_train',
                        action='store_true',
                        help='Ensure the training is always done in '
                        'train mode (Not recommended).')
    parser.add_argument('--use_restval',
                        action='store_true',
                        help='Use the restval data for training on MSCOCO.')
    parser.add_argument('--finetune',
                        action='store_true',
                        help='Fine-tune the image encoder.')
    parser.add_argument('--cnn_type',
                        default='vgg19',
                        help="""The CNN used for image encoder
                        (e.g. vgg19, resnet152)""")
    #=============================================================================
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    if opt.model_type == 'vse++':
        model = VSE(opt)
    else:
        model = SCAN(opt)

    # optionally resume from a checkpoint.
    # FIX: start_epoch/best_rsum loaded from the checkpoint were previously
    # discarded (the loop always restarted at epoch 0 with best_rsum = 0);
    # a resumed run now continues where the checkpoint left off.
    start_epoch = 0
    best_rsum = 0
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # FIX: makedirs handles nested checkpoint paths (os.mkdir fails when
        # the parent directory does not exist).
        if not os.path.exists(opt.model_name):
            os.makedirs(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Example #15
0
def evalrank(input_string,
             img_feature,
             how_many,
             model_path,
             data_path=None,
             split='dev',
             fold5=False,
             gpu_num=None):
    """
    Rank images against a single text query.

    Encodes `input_string` with a trained SCAN model, computes cross-attention
    similarities against the precomputed image embeddings in `img_feature`,
    and returns the indices of the `how_many` best-matching images. If
    `fold5=True`, the image set is processed in 10 shards (MSCOCO-style) and
    the per-shard top-10 indices are collected instead.
    """
    # load model and options
    s_t = time.time()
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    print("%s seconds taken to load checkpoint" % (time.time() - s_t))
    if data_path is not None:
        opt.data_path = data_path

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    # docker dir (vocab location baked in for the deployment container)
    opt.vocab_path = '/scan/SCAN/data/vocab'

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    print("Loading npy file")
    start_time = time.time()
    # Image embeddings are passed in precomputed instead of loaded from disk.
    img_embs = img_feature
    print("%s seconds takes to load npy file" % (time.time() - start_time))

    # FIX: tokenize the query string itself. The original tokenized
    # str([input_string]) — the list repr, brackets and quotes included —
    # and called .decode('utf-8') on a str, which raises AttributeError
    # on Python 3.
    tokens = nltk.tokenize.word_tokenize(str(input_string).lower())
    caption = [vocab('<start>')]
    caption.extend(vocab(token) for token in tokens)
    caption.append(vocab('<end>'))
    # Repeat the caption to a full batch so encode_data sees batch_size rows.
    target = torch.Tensor([list(caption)
                           for _ in range(opt.batch_size)]).long()

    print('Calculating results...')
    start_time = time.time()
    cap_embs, cap_len = encode_data(model, target, opt.batch_size)
    cap_lens = cap_len[0]
    print("%s seconds takes to calculate results" % (time.time() - start_time))
    print("Caption length with start and end index : ", cap_lens)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # FIX: final_result was previously unbound on the fold5 path, so the
    # trailing `return final_result` raised NameError; it now collects the
    # per-shard top-10 indices in that branch.
    final_result = []

    if not fold5:
        img_embs = np.array(img_embs)
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)

        # Indices of the how_many most similar images, best first.
        top_n = np.argsort(sims, axis=0)[-(how_many):][::-1].flatten()
        final_result = list(top_n)

    # 5fold cross-validation, only for MSCOCO
    else:
        for i in range(10):
            # Slice the i-th shard; the last shard absorbs the remainder.
            if i < 9:
                img_embs_shard = img_embs[i *
                                          (img_embs.shape[0] // 10):(i + 1) *
                                          (img_embs.shape[0] // 10)]
            else:
                img_embs_shard = img_embs[i * (img_embs.shape[0] // 10):]
            cap_embs_shard = cap_embs
            cap_lens_shard = cap_lens
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            top_10 = np.argsort(sims, axis=0)[-10:][::-1].flatten()

            # Offset shard-local indices back into the global image index
            # space (each shard holds 5000 images).
            print("Top 10 list for iteration #%d : " % (i + 1) +
                  str(top_10 + 5000 * i))
            final_result.extend(top_10 + 5000 * i)

    return final_result
Example #16
0
def main():
    # Hyper Parameters setting
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='/mnt/data/linkaiyi/scan/data/f30k_precomp',
                        help='path to datasets')
    parser.add_argument('--path_opt',
                        default='option/FusionNoattn_baseline.yaml',
                        type=str,
                        help='path to a yaml options file')
    parser.add_argument('--data_name',
                        default='flickr30k_splits',
                        help='{coco,f30k}_splits')
    parser.add_argument('--logger_name',
                        default='./log_2',
                        help='Path to save Tensorboard log.')
    parser.add_argument(
        '--vocab_path',
        default=
        '/home/linkaiyi/fusion_wangtan/Fusion_flickr/Fusion_10.28/vocab',
        help='Path to saved vocabulary json files.')
    parser.add_argument(
        '--model_name',
        default='/mnt/data/linkaiyi/mscoco/fusion/Fusion_flic/runs/checkpoint',
        help='Path to save the model.')
    parser.add_argument('--num_epochs',
                        default=120,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--workers',
                        default=2,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--lr_update',
                        default=20,
                        type=int,
                        help='Number of epochs to update the learning rate.')

    opt = parser.parse_args()
    if os.path.isdir(opt.logger_name):
        if click.confirm('Logs directory already exists in {}. Erase?'.format(
                opt.logger_name, default=False)):
            os.system('rm -r ' + opt.logger_name)
    tb_logger.configure(opt.logger_name, flush_secs=5)
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    #########################################################################################
    # Create options
    #########################################################################################

    options = {'logs': {}, 'coco': {}, 'model': {'seq2vec': {}}, 'optim': {}}
    if opt.path_opt is not None:
        with open(opt.path_opt, 'r') as handle:
            options_yaml = yaml.load(handle)
        options = utils.update_values(options, options_yaml)

    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    vocab_word = sorted(vocab.word2idx.items(),
                        key=lambda x: x[1],
                        reverse=False)
    vocab_word = [tup[0] for tup in vocab_word]
    opt.vocab_size = len(vocab)

    # Create dataset, model, criterion and optimizer

    train_loader, val_loader = data.get_loaders(opt.data_path, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)
    model = models.factory(options['model'],
                           vocab_word,
                           cuda=True,
                           data_parallel=False)

    criterion = nn.CrossEntropyLoss(weight=torch.Tensor([1, 128])).cuda()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=options['optim']['lr'])

    print('Model has {} parameters'.format(utils.params_count(model)))

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            engine.validate(val_loader, model, criterion, optimizer,
                            opt.batch_size)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    else:
        start_epoch = 0

    # Train the Model
    best_rsum = 0
    for epoch in range(start_epoch, opt.num_epochs):

        adjust_learning_rate(opt, options, optimizer, epoch)

        # train for one epoch

        engine.train(train_loader,
                     model,
                     criterion,
                     optimizer,
                     epoch,
                     print_freq=10)

        # evaluate on validation set
        rsum = engine.validate(val_loader, model, criterion, optimizer,
                               opt.batch_size)

        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': 'baseline',
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'options': options,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}{}.pth.tar'.format(epoch, best_rsum),
            prefix=opt.model_name + '/')
Пример #17
0
def main():
    """Train a SCAN model (with a precomputed caption-embedding matrix).

    Parses hyper-parameters from the command line, loads the word and tag
    vocabularies, builds the train/val data loaders, optionally resumes
    from a checkpoint, then runs the train/validate loop, checkpointing
    after every epoch.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=30,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm',
        default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--t2i_agg_func',
                        default="Mean",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--i2t_agg_func',
                        default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru',
                        action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse',
                        default=5.,
                        type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax_avg',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--lambda_softmax_lse',
                        default=4.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--caption_vocab',
                        default=500,
                        type=float,
                        help='caption vocabulary size')
    parser.add_argument(
        '--caption_np',
        default='/home/wangzheng/neurltalk/SCAN_t2i_nn500xin/vocab/',
        help='caption vocabulary size')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    vocab_tag = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab_tag.json' % opt.data_name))

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                vocab_tag, opt.batch_size,
                                                opt.workers, opt)

    # Precomputed caption word matrix fed to the model on the GPU.
    captions_w = numpy.load(opt.caption_np + 'caption_np.npy')
    captions_w = torch.from_numpy(captions_w)
    captions_w = captions_w.cuda()

    # Construct the model
    model = SCAN(opt, captions_w)

    # Defaults used when no checkpoint is restored.  Initialising both here
    # fixes two defects of the original flow: a NameError on start_epoch
    # when --resume pointed at a missing file, and best_rsum being reset to
    # 0 *after* a successful resume, which discarded the restored best
    # score and made early post-resume epochs look like new bests.
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)
        adjust_learning_rate1(opt, model.optimizer1, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Пример #18
0
def main():
    """Train a BFAN model on a precomputed-feature dataset.

    Parses hyper-parameters from the command line, loads the vocabulary,
    builds the train/val data loaders, optionally resumes from a
    checkpoint, then runs the train/validate loop, checkpointing after
    every epoch.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='/home/csl/DatasetsPy2/data',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=15,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')

    parser.add_argument('--embed_dim',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding.')

    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=100,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=2000000,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/f30k_precomp/',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/f30k_precomp/',
                        help='Path to save the model.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm',
                        action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--precomp_enc_type',
                        default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--lambda_softmax',
                        default=20.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--focal_type', default="equal", help='equal|prob')

    parser.add_argument('--measure',
                        default='cosine',
                        help='Similarity measure used (cosine|order)')

    parser.add_argument('--use_BatchNorm',
                        action='store_false',
                        help='Whether to use BN.')
    parser.add_argument('--activation_type',
                        default='tanh',
                        help='choose type of activation functions.')
    parser.add_argument('--dropout_rate',
                        default=0.4,
                        type=float,
                        help='dropout rate.')
    parser.add_argument(
        '--feature_fuse_type',
        default='weight_sum',
        help=
        'choose the fusing type for raw feature and attribute feature (multiple|concat|adap_sum|weight_sum))'
    )

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = BFAN(opt)

    # Defaults used when no checkpoint is restored.  Setting these before
    # the resume block fixes the original code, which read start_epoch and
    # best_rsum from the checkpoint but then ignored start_epoch entirely
    # (training always restarted at epoch 0) and clobbered the restored
    # best_rsum with 0.
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Пример #19
0
    cap_embs = np.array([cap_embs[0]]) # one text to multiply images, so cap_embs has the same

    sims = shard_xattn_t2i(img_embs, whole_img_embs, cap_embs, cap_lens, final_cap_embs, opt, shard_size=128)
    print "Sims (t2i)", sims
    sims2 = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt, shard_size=128)
    print "Sims (i2t)", sims2
    result = {'sim_t2i':reverse_sim(sims),'sim_i2t':reverse_sim(sims2)}
    return json.dumps(result)

# Module-level globals shared with the Flask request handlers defined above.
g_model = None  # trained SCAN model, loaded from the checkpoint at startup
g_split_name = 'test_server'  # dataset split this server process serves
g_vocab = None  # sentence vocabulary matching the trained model

if __name__ == "__main__":
    # Usage: <script> MODEL_CHECKPOINT_PATH [PORT]
    # NOTE(review): this block uses Python 2 print statements.
    model_path = sys.argv[1]
    port = 5091  # default serving port when none is given

    if len(sys.argv) > 2:
        port = int(sys.argv[2])

    # Rebuild the model from the checkpoint, using the training options
    # ('opt') that were serialized alongside the weights.
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print "Option", opt
    g_model = SCAN(opt)
    g_model.load_state_dict(checkpoint['model'])
    # Load the sentence vocabulary that matches the trained model.
    g_vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_sent_vocab.json' % opt.data_name))
    opt.vocab_size = len(g_vocab)
    print "Vocab size", opt.vocab_size
    # Serve on all interfaces; handlers read the g_* globals set above.
    app.run(host='0.0.0.0', port=port, debug=False)
Пример #20
0
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """Evaluate a trained SCAN checkpoint on image-text retrieval.

    Args:
        model_path: path to a checkpoint saved with torch.save; must contain
            'opt' (training options) and 'model' (state dict).
        data_path: optional override for opt.data_path.
        split: dataset split to evaluate on (e.g. 'dev', 'test').
        fold5: if True, run the COCO 5-fold protocol (five 1000-image
            shards) and report mean metrics; otherwise evaluate the whole
            split at once.

    Side effects: prints recall metrics, saves the similarity matrix to
    'sim_stage1.npy' (non-fold5 path) and ranks to 'ranks.pth.tar'.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # Vocabulary must match the one used at training time.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Precomputed caption word matrix the model was trained with.
    captions_w = np.load(opt.caption_np + 'caption_np.npy')
    captions_w = torch.from_numpy(captions_w)

    captions_w = captions_w.cuda()

    model = SCAN(opt, captions_w)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    # Each image has 5 captions, hence the /5 in the image count.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # Keep one embedding per image (every 5th row; captions repeat).
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   shard_size=128)
        elif opt.cross_attn == 'all':
            sims, label = shard_xattn_all(model,
                                          img_embs,
                                          cap_embs,
                                          cap_lens,
                                          opt,
                                          shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)
        np.save('sim_stage1', sims)

        # NOTE(review): `label` is only bound in the cross_attn == 'all'
        # branch above, but is passed to i2t/t2i unconditionally — the
        # 't2i' and 'i2t' branches will raise NameError here. Confirm the
        # intended i2t/t2i signatures (the fold5 path below calls them
        # WITHOUT a label argument).
        r, rt = i2t(label,
                    img_embs,
                    cap_embs,
                    cap_lens,
                    sims,
                    return_ranks=True)
        ri, rti = t2i(label,
                      img_embs,
                      cap_embs,
                      cap_lens,
                      sims,
                      return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5-fold protocol: five shards of 1000 images / 5000 captions.
        results = []
        for i in range(5):
            # ::5 keeps one image embedding per group of 5 captions.
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard,
                                       cap_embs_shard,
                                       cap_lens_shard,
                                       opt,
                                       shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)

            r, rt0 = i2t(img_embs_shard,
                         cap_embs_shard,
                         cap_lens_shard,
                         sims,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs_shard,
                           cap_embs_shard,
                           cap_lens_shard,
                           sims,
                           return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            # Keep the first fold's ranks for the saved output below.
            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def main():
    """Set up and (optionally) resume a VSRN model from the command line.

    Parses hyper-parameters, loads the vocabulary and a stop-word list,
    builds the train/val data loaders, constructs the model, and restores
    a checkpoint if --resume is given.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='/data',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary pickle files.')
    parser.add_argument('--margin',
                        default=0.05,
                        type=float,
                        help='loss margin.')
    parser.add_argument('--temperature',
                        default=14,
                        type=int,
                        help='loss temperature.')
    parser.add_argument('--num_epochs',
                        default=9,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=2048,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=4,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='runs/runX',
                        help='Path to save the model and Tensorboard log.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--measure',
                        default='cosine',
                        help='Similarity measure used (cosine|order)')
    parser.add_argument('--use_abs',
                        action='store_true',
                        help='Take the absolute value of embedding vectors.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--seed', default=1, type=int, help='random seed.')
    parser.add_argument('--use_atten', action='store_true', help='use_atten')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--use_box', action='store_true', help='use_box')
    parser.add_argument('--use_label', action='store_true', help='use_label')
    parser.add_argument('--use_mmd', action='store_true', help='use_mmd')
    parser.add_argument('--score_path',
                        default='../user_data/score.npy',
                        type=str)

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    #tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load the stop-word list, one word per line.  A context manager
    # guarantees the file handle is closed (the original left it open).
    stoppath = os.path.join(opt.vocab_path, 'stopwords.txt')
    with open(stoppath, 'r') as f_stop:
        stopwords = [sw.strip() for sw in f_stop.readlines()]

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                stopwords, opt.batch_size,
                                                opt.workers, opt, True)

    # Construct the model
    model = VSRN(opt)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
def main():
    """Train a VSRN image-text matching model from the command line.

    Parses hyper-parameters, loads the vocabulary and stopword list,
    builds the data loaders (a fine-tuning pipeline when resuming from a
    checkpoint), constructs the model, optionally restores a checkpoint,
    and runs the training/validation loop, checkpointing the best scores.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='/data',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary pickle files.')
    parser.add_argument('--margin',
                        default=0.05,
                        type=float,
                        help='loss margin.')
    parser.add_argument('--temperature',
                        default=14,
                        type=int,
                        help='loss temperature.')
    parser.add_argument('--num_epochs',
                        default=7,
                        type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size',
                        default=2048,
                        type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=4,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step',
                        default=500,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='runs/runX',
                        help='Path to save the model and Tensorboard log.')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--img_dim',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--measure',
                        default='cosine',
                        help='Similarity measure used (cosine|order)')
    parser.add_argument('--use_abs',
                        action='store_true',
                        help='Take the absolute value of embedding vectors.')
    parser.add_argument('--no_imgnorm',
                        action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--seed', default=1, type=int, help='random seed.')
    parser.add_argument('--use_atten', action='store_true', help='use_atten')
    parser.add_argument('--use_box', action='store_true', help='use_box')
    parser.add_argument('--use_label', action='store_true', help='use_label')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature.')
    parser.add_argument('--use_mmd', action='store_true', help='use_mmd')
    parser.add_argument('--score_path',
                        default='../user_data/score.npy',
                        type=str)

    opt = parser.parse_args()
    print(opt)

    set_seed(opt.seed)

    # makedirs (instead of mkdir) also creates missing parent directories,
    # so nested logger paths such as runs/exp1/runX work.
    os.makedirs(opt.logger_name, exist_ok=True)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    # Load the serialized vocabulary used to index caption tokens.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load the stopword list (one word per line); the context manager
    # guarantees the file handle is closed (it leaked before).
    stoppath = os.path.join(opt.vocab_path, 'stopwords.txt')
    with open(stoppath, 'r') as f_stop:
        stopwords = [sw.strip() for sw in f_stop.readlines()]

    # Load data loaders: the fine-tuning pipeline when resuming from a
    # checkpoint, otherwise the standard training pipeline.
    if opt.resume:
        train_loader, val_loader = data_finetune.get_loaders(
            opt.data_name, vocab, stopwords, opt.batch_size, opt.workers, opt)
    else:
        train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                    stopwords, opt.batch_size,
                                                    opt.workers, opt)

    # Construct the model
    model = VSRN(opt)

    # optionally resume from a checkpoint
    start_epoch = 0
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # NOTE(review): the resume epoch is hard-coded to 4 instead of
            # reading checkpoint['epoch'] -- presumably a deliberate
            # fine-tuning schedule; confirm before changing.
            start_epoch = 4
            #start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' epoch {}".format(
                opt.resume, start_epoch))
            #validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model.
    # NOTE(review): best_rsum is reset to 0 here even when a checkpoint
    # supplied one above, so a resumed run re-tracks its own best score.
    best_rsum = 0
    best_rerank_rsum = 0

    for epoch in range(start_epoch, opt.num_epochs):
        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch (re-initialize the dataset sampling first)
        train_loader.dataset.initial()
        best_rsum, best_rerank_rsum = train(opt, train_loader, model, epoch,
                                            val_loader, best_rsum,
                                            best_rerank_rsum)

        # evaluate on validation set
        rsum, rerank_rsum = validate(opt, val_loader, model)

        # remember best R@ sum (plain and re-ranked) and save checkpoint
        is_best = rsum > best_rsum
        rerank_is_best = rerank_rsum > best_rerank_rsum
        best_rsum = max(rsum, best_rsum)
        best_rerank_rsum = max(rerank_rsum, best_rerank_rsum)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'best_rerank_rsum': best_rerank_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            rerank_is_best,
            prefix=opt.logger_name + '/')
Пример #23
0
def evalrank(model_path,
             model_path2,
             data_path=None,
             split='dev',
             fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    This variant evaluates an *ensemble* of two BFAN checkpoints: each model
    is scored on its own, then the two similarity matrices are averaged
    (late fusion) and scored again. The rank tensors from the last i2t/t2i
    call are saved to 'ranks.pth.tar'.
    """
    # load model and options (one checkpoint per ensemble member)
    checkpoint = torch.load(model_path)
    checkpoint2 = torch.load(model_path2)
    opt = checkpoint['opt']
    opt2 = checkpoint2['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    # NOTE(review): only `opt` (model 1) is used to locate the vocabulary;
    # assumes both checkpoints share data_name/vocab -- confirm.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = BFAN(opt)
    model2 = BFAN(opt2)
    # load model state
    model.load_state_dict(checkpoint['model'])
    model2.load_state_dict(checkpoint2['model'])
    print('Loading dataset')
    # Both models are evaluated on the same loader built from `opt`.
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    img_embs2, cap_embs2, cap_lens2 = encode_data(model2, data_loader)
    # The division by 5 reflects 5 captions per image in the dataset.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation
        # Keep one image embedding per unique image (every 5th row).
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()

        sims = shard_xattn(img_embs, cap_embs, cap_lens, opt, shard_size=128)

        end = time.time()
        print("calculate similarity time:", end - start)

        # --- model 1 metrics ---
        batch_size = img_embs.shape[0]
        r, rt = i2t(batch_size, sims, return_ranks=True)
        ri, rti = t2i(batch_size, sims, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

        # --- model 2 metrics ---
        img_embs2 = np.array(
            [img_embs2[i] for i in range(0, len(img_embs2), 5)])
        start = time.time()

        sims2 = shard_xattn(img_embs2,
                            cap_embs2,
                            cap_lens2,
                            opt2,
                            shard_size=128)

        end = time.time()
        print("calculate similarity time:", end - start)

        batch_size = img_embs2.shape[0]
        r, rt = i2t(batch_size, sims2, return_ranks=True)
        ri, rti = t2i(batch_size, sims2, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

        # --- ensemble metrics: average the two similarity matrices ---
        simsall = (sims + sims2) / 2

        end = time.time()
        print("calculate similarity time:", end - start)

        batch_size = img_embs2.shape[0]
        r, rt = i2t(batch_size, simsall, return_ranks=True)
        ri, rti = t2i(batch_size, simsall, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            # Shard of fold i: every 5th image row, all caption rows.
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            start = time.time()
            sims = shard_xattn(img_embs_shard,
                               cap_embs_shard,
                               cap_lens_shard,
                               opt,
                               shard_size=128)

            end = time.time()
            print("calculate similarity time:", end - start)

            # --- model 1 metrics for this fold ---
            batch_size = img_embs_shard.shape[0]
            r, rt0 = i2t(batch_size, sims, return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(batch_size, sims, return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            # keep the rank tensors of the first fold for saving
            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

            # --- model 2 + ensemble for this fold ---
            img_embs_shard = img_embs2[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs2[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens2[i * 5000:(i + 1) * 5000]
            start = time.time()
            sims2 = shard_xattn(img_embs_shard,
                                cap_embs_shard,
                                cap_lens_shard,
                                opt2,
                                shard_size=128)
            simsall = (sims + sims2) / 2

            end = time.time()
            print("calculate similarity time:", end - start)

            batch_size = img_embs_shard.shape[0]
            r, rt0 = i2t(batch_size, sims2, return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(batch_size, sims2, return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            # ensemble scores are what enters `results` below
            r, rt0 = i2t(batch_size, simsall, return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(batch_size, simsall, return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)

            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        # Columns 0-4: i2t recalls, 5-9: t2i recalls, 10: ar, 11: ari, 12: rsum
        # NOTE(review): index 10 is `ar` yet is multiplied by 6 and printed
        # as rsum (and 11/12 as the average recalls) -- looks off by two
        # relative to the column layout; confirm intended indices.
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    # Persist the rank tensors from the last evaluation pass.
    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
Пример #24
0
def main():
    """Train a CARRN image-text matching model from the command line.

    Parses hyper-parameters, loads the vocabulary, builds the train/val
    data loaders, constructs the model, optionally resumes from a
    checkpoint, and runs the training/validation loop, checkpointing the
    best R@sum.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name',
                        default='precomp',
                        help='{coco, f30k}_precomp')
    parser.add_argument('--vocab_path',
                        default='./vocab/',
                        help='Path to saved vocabulary json files')
    parser.add_argument('--margin',
                        default=0.2,
                        type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs',
                        default=30,
                        type=int,
                        help='Number of training epochs.')
    # help text fixed: it previously repeated "Number of training epochs."
    parser.add_argument('--batch_size',
                        default=64,
                        type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_embed_size',
                        default=300,
                        type=int,
                        help='Dimensionality of the word embedding')
    parser.add_argument('--hidden_size',
                        default=1024,
                        type=int,
                        help='Dimensionality of the joint embedding')
    parser.add_argument('--grad_clip',
                        default=2.,
                        type=float,
                        help='Gradient clipping threshold')
    parser.add_argument('--num_layers',
                        default=1,
                        type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate',
                        default=0.0002,
                        type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update',
                        default=15,
                        type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers',
                        default=10,
                        type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step',
                        default=10,
                        type=int,
                        help='Number of steps to print and record the log')
    parser.add_argument('--val_step',
                        default=1000,
                        type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name',
                        default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name',
                        default='./runs/runX/checkpoint',
                        help='Path to save the model')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default:none)')
    parser.add_argument('--max_violation',
                        action='store_true',
                        help='Use max instead of sum in the rank loss')
    parser.add_argument('--img_size',
                        default=2048,
                        type=int,
                        help='Dimensionality of the image embedding')
    parser.add_argument('--norm',
                        action='store_true',
                        help='normalize the text and image embedding')
    parser.add_argument(
        '--norm_func',
        default='clipped_l2norm',
        help='clipped_leaky_l2norm|clipped_l2norm|l2norm|'
        'clipped_leaky_l1norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func',
                        default='Mean',
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--bi_gru',
                        action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse',
                        default=6.,
                        type=float,
                        help='LogSumExp temp')
    parser.add_argument('--lambda_softmax',
                        default=9.,
                        type=float,
                        help='Attention softmax temperature')
    parser.add_argument(
        '--activation_func',
        default='relu',
        help='activation function: relu|gelu|no_activation_fun')
    parser.add_argument(
        '--alpha',
        default=0.5,
        type=float,
        help='the weight of final score between i2t score and t2i score')
    parser.add_argument('--use_abs',
                        action='store_true',
                        help='take the absolute value of embedding vectors')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = CARRN(opt)
    best_rsum = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs (instead of mkdir) also creates missing parent dirs of
        # the nested default path ./runs/runX/checkpoint.
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
Пример #25
0
def evalrank(model_path,
             run,
             data_path=None,
             split='dev',
             fold5=False,
             vocab_path="../vocab/"):
    """
    Evaluate a trained SCAN checkpoint on either dev or test.

    Loads the checkpoint and its training options, rebuilds the matching
    vocabulary and data loader, encodes the split, scores it with
    text-to-image cross attention, prints the recall metrics, and saves
    the rank tensors for later plotting. Returns (rt, rti, r, ri).
    """
    # Restore the checkpoint together with the options it was trained with.
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # Rebuild the vocabulary the checkpoint was trained against.
    vocab = deserialize_vocab("{}{}/{}_vocab_{}.json".format(
        vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)
    print(opt.vocab_size)

    # Instantiate the model and restore its weights.
    model = SCAN(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # Similarity matrix via text-to-image cross attention.
    t2i_switch = True
    sims = shard_xattn_t2i(model,
                           img_embs,
                           cap_embs,
                           cap_lens,
                           opt,
                           shard_size=128)

    # r = (r1, r2, r5, medr, meanr), rt= (ranks, top1)
    r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ar = sum(r[:3]) / 3
    ari = sum(ri[:3]) / 3
    rsum = sum(r[:3]) + sum(ri[:3])
    print("rsum: %.1f" % rsum)
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
    print("Average t2i Recall: %.1f" % ari)
    print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)

    # Persist the rank tensors so plots can be generated afterwards.
    save_dir = "plots_laenen"
    os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'rt': rt,
        'rti': rti,
        "t2i_switch": t2i_switch
    }, '{}/ranks_{}_{}.pth.tar'.format(save_dir, run, opt.version))
    return rt, rti, r, ri
Пример #26
0
def main():
    """Train a SCAN model on a precomputed-feature dataset.

    Parses hyper-parameters, loads the vocabulary, builds the data
    loaders, constructs the model, optionally resumes from a checkpoint,
    and runs the training/validation loop, checkpointing the best R@sum.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    dataset = 'cub'  # deep_fashion, cub 'f30k' or 'coco'
    exp_type = 'mlf'  # scan or mlf
    parser.add_argument('--data_path', default='/ivi/ilps/personal/mbiriuk/repro/data',
                        help='path to datasets')
    parser.add_argument('--data_name', default=f'{dataset}_{exp_type}_precomp',
                        help='{coco,f30k, cub}_precomp')
    parser.add_argument('--vocab_path', default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=15, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=15, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=10, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=10, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default=f'./runs/{dataset}_{exp_type}/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default=f'./runs/{dataset}_{exp_type}/log',
                        help='Path to save the model.')
    parser.add_argument('--resume'
                        , default=''
                        # , default='/Users/mhendriksen/Desktop/repositories/SCAN/runs/f30k_scan/checkpoint_9.pth.tar'
                        , type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    # BUG FIX: default was the string 'bi_gru' (copy-paste error) on a
    # store_true flag. A boolean default is used instead; True is kept so
    # the previous always-truthy behavior of opt.max_violation is preserved.
    parser.add_argument('--max_violation', action='store_true', default=True,
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--raw_feature_norm', default="clipped_l2norm",
                        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func', default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i",
                        help='t2i|i2t')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru', action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse', default=6., type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax', default=9., type=float,
                        help='Attention softmax temperature.')
    opt = parser.parse_args()

    # Configure logging *before* the first logging call; previously
    # logging.info(opt) ran before basicConfig, so the options were
    # silently dropped by the unconfigured root logger.
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info(opt)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model
    model = SCAN(opt)

    best_rsum = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            # NOTE(review): `device` must be defined at module level -- it
            # is not visible in this function; confirm it exists.
            checkpoint = torch.load(opt.resume, map_location=device)
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs (instead of mkdir) also creates missing parent dirs
        # of the nested ./runs/<dataset>_<exp_type>/log path.
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint({
            'epoch': epoch,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/')