Exemplo n.º 1
0
    )
    # Read the command-line options that drive this evaluation run.
    args = parse_args()

    # Turn on cuDNN benchmark mode.
    torch.backends.cudnn.benchmark = True

    # Pick the task-specific `evaluate` function, dictionary and dataset.
    # `evaluate` is imported lazily so only the selected task's training
    # module gets loaded.
    if args.task == 'vqa':
        from train import evaluate
        dict_path = 'data/dictionary.pkl'
        dictionary = Dictionary.load_from_file(dict_path)
        eval_dset = VQAFeatureDataset('val', dictionary, adaptive=True)

    elif args.task == 'flickr':
        from train_flickr import evaluate
        dict_path = 'data/flickr30k/dictionary.pkl'
        dictionary = Dictionary.load_from_file(dict_path)
        eval_dset = Flickr30kFeatureDataset('test', dictionary)
        # The Flickr task overrides these two options regardless of what
        # was passed on the command line.
        args.op = ''
        args.gamma = 1

    else:
        # Without this branch an unknown task only surfaced further down as
        # a NameError on `eval_dset`; fail here with a clear message instead.
        raise ValueError(
            'Unsupported task %r; expected "vqa" or "flickr"' % (args.task,))

    # Multiply the requested batch size by the number of CUDA devices.
    n_device = torch.cuda.device_count()
    batch_size = args.batch_size * n_device

    # The model factory is looked up by name: base_model.build_<args.model>.
    constructor = 'build_%s' % args.model
    model = getattr(base_model, constructor)(eval_dset, args.num_hid, args.op,
                                             args.gamma, args.task).cuda()

    # Checkpoint file is '<input>/model.pth', or '<input>/model_epoch<N>.pth'
    # when a positive epoch number was requested.
    epoch_suffix = '_epoch%d' % args.epoch if args.epoch > 0 else ''
    model_path = args.input + '/model' + epoch_suffix + '.pth'
    model_data = torch.load(model_path)

    # The weights are loaded into the DataParallel wrapper, so the checkpoint
    # keys presumably carry the 'module.' prefix -- TODO confirm against the
    # code that saves the checkpoint.
    model = nn.DataParallel(model).cuda()
    # The loaded object either stores the weights under 'model_state' or is
    # used as the state dict itself.
    model.load_state_dict(model_data.get('model_state', model_data))
Exemplo n.º 2
0
    # Seed the random number generators of all CUDA devices.
    torch.cuda.manual_seed_all(args.seed)
    # NOTE(review): cuDNN benchmark mode may pick different algorithms from
    # run to run, so the seed alone may not make runs reproducible -- confirm
    # whether that matters here.
    torch.backends.cudnn.benchmark = True

    # Pick the task-specific `train` function, dictionary, datasets and the
    # path of the word-embedding initialisation file. `train` is imported
    # lazily so only the selected task's training module gets loaded.
    if args.task == 'vqa':
        from train import train
        dict_path = 'data/dictionary.pkl'
        dictionary = Dictionary.load_from_file(dict_path)
        train_dset = VQAFeatureDataset('train', dictionary, adaptive=True)
        val_dset = VQAFeatureDataset('val', dictionary, adaptive=True)
        w_emb_path = 'data/glove6b_init_300d.npy'

    elif args.task == 'flickr':
        from train_flickr import train
        dict_path = 'data/flickr30k/dictionary.pkl'
        dictionary = Dictionary.load_from_file(dict_path)
        train_dset = Flickr30kFeatureDataset('train', dictionary)
        val_dset = Flickr30kFeatureDataset('val', dictionary)
        w_emb_path = 'data/flickr30k/glove6b_init_300d.npy'
        # The Flickr task overrides these options regardless of what was
        # passed on the command line.
        args.op = ''
        args.gamma = 1
        args.tfidf = False

    else:
        # Without this branch an unknown task first created the output
        # directory and args log, then crashed with a NameError on
        # `train_dset`; fail here, before any side effect, with a clear
        # message instead.
        raise ValueError(
            'Unsupported task %r; expected "vqa" or "flickr"' % (args.task,))

    # Create the output directory and record the effective arguments in
    # '<output>/args.txt'.
    utils.create_dir(args.output)
    logger = utils.Logger(os.path.join(args.output, 'args.txt'))
    logger.write(repr(args))

    batch_size = args.batch_size

    # The model factory is looked up by name: base_model.build_<args.model>.
    constructor = 'build_%s' % args.model
    model = getattr(base_model, constructor)(train_dset, args.num_hid, args.op,
                                             args.gamma, args.task).cuda()