示例#1
0
                                    dataroot=args.data_folder)
    else:
        val_dset = VQAFeatureDataset('val',
                                     dictionary,
                                     args.relation_type,
                                     adaptive=args.adaptive,
                                     pos_emb_dim=args.imp_pos_emb_dim,
                                     dataroot=args.data_folder)
        train_dset = VQAFeatureDataset('train',
                                       dictionary,
                                       args.relation_type,
                                       adaptive=args.adaptive,
                                       pos_emb_dim=args.imp_pos_emb_dim,
                                       dataroot=args.data_folder)

    model = build_regat(val_dset, args).to(device)

    tfidf = None
    weights = None
    if args.tfidf:
        tfidf, weights = tfidf_from_questions(['train', 'val', 'test2015'],
                                              dictionary)
    model.w_emb.init_embedding(
        join(args.data_folder, 'glove/glove6b_init_300d.npy'), tfidf, weights)

    model = nn.DataParallel(model).to(device)

    if args.checkpoint != "":
        print("Loading weights from %s" % (args.checkpoint))
        if not os.path.exists(args.checkpoint):
            raise ValueError("No such checkpoint exists!")
示例#2
0
文件: eval.py 项目: ych133/VQA_ReGAT
                            'val', model_hps.relation_type,
                            adaptive=model_hps.adaptive,
                            dataroot=model_hps.data_folder)
        eval_dset = VQA_cp_Dataset(
                    args.split, dictionary, coco_train_features,
                    coco_val_features, adaptive=model_hps.adaptive,
                    pos_emb_dim=model_hps.imp_pos_emb_dim,
                    dataroot=model_hps.data_folder)
    else:
        eval_dset = VQAFeatureDataset(
                args.split, dictionary, model_hps.relation_type,
                adaptive=model_hps.adaptive,
                pos_emb_dim=model_hps.imp_pos_emb_dim,
                dataroot=model_hps.data_folder)

    model = build_regat(eval_dset, model_hps).to(device)

    model = nn.DataParallel(model).to(device)

    if args.checkpoint > 0:
        checkpoint_path = os.path.join(
                            args.output_folder,
                            f"model_{args.checkpoint}.pth")
    else:
        checkpoint_path = os.path.join(args.output_folder,
                                       f"model.pth")
    print("Loading weights from %s" % (checkpoint_path))
    if not os.path.exists(checkpoint_path):
        raise ValueError("No such checkpoint exists!")
    checkpoint = torch.load(checkpoint_path)
    state_dict = checkpoint.get('model_state', checkpoint)
示例#3
0
文件: test.py 项目: kanji95/VQA_ReGAT
                        action='store_true',
                        help='Enable bias term for relation labels \
                              in relation encoder')

    # can use config files
    parser.add_argument('--config', help='JSON config files')

    args = parse_with_config(parser)
    return args


if __name__ == '__main__':
    # Script entry point: parse the arguments, load the dictionary and the
    # 'val' dataset from args.data_folder, build the model and print it.
    args = parse_args()

    # Number of CUDA devices visible to PyTorch; this is 0 on a CPU-only
    # machine.
    n_device = torch.cuda.device_count()
    print("Found %d GPU cards for training" % (n_device))

    # The device is fixed to the CPU here, independent of n_device.
    device = torch.device("cpu")

    # Scale the per-device batch size by the device count. The multiplier is
    # clamped to at least 1 so that a machine without GPUs (the expected case
    # for a CPU run) does not end up with an unusable batch size of 0.
    batch_size = args.batch_size * max(n_device, 1)

    dictionary = Dictionary.load_from_file(
        join(args.data_folder, 'glove/dictionary.pkl'))
    val_dset = VQAFeatureDataset('val',
                                 dictionary,
                                 args.relation_type,
                                 adaptive=args.adaptive,
                                 pos_emb_dim=args.imp_pos_emb_dim,
                                 dataroot=args.data_folder)

    # The model is built from the dataset and the parsed arguments; it is
    # not moved to `device` in this part of the script.
    model = build_regat(val_dset, args)

    print(model)