dataroot=args.data_folder) else: val_dset = VQAFeatureDataset('val', dictionary, args.relation_type, adaptive=args.adaptive, pos_emb_dim=args.imp_pos_emb_dim, dataroot=args.data_folder) train_dset = VQAFeatureDataset('train', dictionary, args.relation_type, adaptive=args.adaptive, pos_emb_dim=args.imp_pos_emb_dim, dataroot=args.data_folder) model = build_regat(val_dset, args).to(device) tfidf = None weights = None if args.tfidf: tfidf, weights = tfidf_from_questions(['train', 'val', 'test2015'], dictionary) model.w_emb.init_embedding( join(args.data_folder, 'glove/glove6b_init_300d.npy'), tfidf, weights) model = nn.DataParallel(model).to(device) if args.checkpoint != "": print("Loading weights from %s" % (args.checkpoint)) if not os.path.exists(args.checkpoint): raise ValueError("No such checkpoint exists!")
'val', model_hps.relation_type, adaptive=model_hps.adaptive, dataroot=model_hps.data_folder) eval_dset = VQA_cp_Dataset( args.split, dictionary, coco_train_features, coco_val_features, adaptive=model_hps.adaptive, pos_emb_dim=model_hps.imp_pos_emb_dim, dataroot=model_hps.data_folder) else: eval_dset = VQAFeatureDataset( args.split, dictionary, model_hps.relation_type, adaptive=model_hps.adaptive, pos_emb_dim=model_hps.imp_pos_emb_dim, dataroot=model_hps.data_folder) model = build_regat(eval_dset, model_hps).to(device) model = nn.DataParallel(model).to(device) if args.checkpoint > 0: checkpoint_path = os.path.join( args.output_folder, f"model_{args.checkpoint}.pth") else: checkpoint_path = os.path.join(args.output_folder, f"model.pth") print("Loading weights from %s" % (checkpoint_path)) if not os.path.exists(checkpoint_path): raise ValueError("No such checkpoint exists!") checkpoint = torch.load(checkpoint_path) state_dict = checkpoint.get('model_state', checkpoint)
action='store_true', help='Enable bias term for relation labels \ in relation encoder') # can use config files parser.add_argument('--config', help='JSON config files') args = parse_with_config(parser) return args if __name__ == '__main__': args = parse_args() n_device = torch.cuda.device_count() print("Found %d GPU cards for training" % (n_device)) device = torch.device("cpu") batch_size = args.batch_size * n_device dictionary = Dictionary.load_from_file( join(args.data_folder, 'glove/dictionary.pkl')) val_dset = VQAFeatureDataset('val', dictionary, args.relation_type, adaptive=args.adaptive, pos_emb_dim=args.imp_pos_emb_dim, dataroot=args.data_folder) model = build_regat(val_dset, args) print(model)