Example #1
def test(args, data, test_csv, hdf5):
    data_all = test_csv.values
    test_len = data_all.shape[0]
    test_set = TrainSet(data_all, hdf5, data.word_matrix, data.word2idx,
                        data.ans2idx)
    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             pin_memory=True)

    # build the model, wrap it for multi-GPU, and restore the best checkpoint
    model = TgifModel()
    model = nn.DataParallel(model).cuda()
    checkpoint = load_checkpoint(
        osp.join(args.model_dir, 'checkpoint_best.pth.tar'))
    model.module.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # count correct predictions over the whole test set
    acc_all = 0
    for j, d in enumerate(test_loader):
        video_features, question_embeds, ql, ans_labels = d
        imgs = Variable(video_features.cuda())
        # question_embeds, ql = pack_paded_questions(question_embeds, ql)
        question_embeds = torch.stack(question_embeds, 1).cuda()
        questions_embed = Variable(question_embeds)
        ans_labels = Variable(ans_labels.cuda())
        ans_scores = model(imgs, questions_embed, ql)
        _, preds = torch.max(ans_scores, 1)
        acc = torch.sum((preds == ans_labels).data)
        acc_all += acc
        if j % args.print_freq == 0:
            print('test img {} acc is : {}'.format(j, acc))
    print('test acc is : {:.6f}'.format(int(acc_all) / int(test_len)))
    return acc_all
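
The loop above uses the pre-0.4 PyTorch idiom (Variable wrappers, reading results through .data). For reference, a minimal sketch of the same evaluation pattern in current PyTorch, with a toy model and random data standing in for TgifModel and the real loader (every name below is illustrative, not from the original code):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# toy stand-ins for TgifModel and the real test set
model = nn.Linear(16, 4)
loader = DataLoader(TensorDataset(torch.randn(32, 16),
                                  torch.randint(0, 4, (32,))),
                    batch_size=8, shuffle=False)

model.eval()
correct, total = 0, 0
with torch.no_grad():  # replaces wrapping tensors in Variable
    for inputs, labels in loader:
        scores = model(inputs)
        preds = scores.argmax(dim=1)
        correct += (preds == labels).sum().item()  # replaces reading .data
        total += labels.size(0)
print('test acc is : {:.6f}'.format(correct / total))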
Example #2
def get_data(data_dir, batch_size):

    root_train = osp.join(data_dir, 'train')
    # attrs_file = open(osp.join(data_dir, 'DatasetA/attrs.json'))
    cls_file = open(osp.join(data_dir, 'cls_0919.json'))
    train_file = open(osp.join(data_dir, 'train_list_0919.txt'))
    # attrs = json.load(attrs_file)
    cls_labels = json.load(cls_file)
    # data_set = json.load(data_file)
    train_set = train_file.readlines()
    print('train imgs', len(train_set))

    root_test = osp.join(data_dir, 'DatasetA/test')
    val_file = open(osp.join(data_dir, 'DatasetA/submit.txt'), 'r')
    val_set = val_file.readlines()
    val_cls_list, val_attrs = get_val_attrs()
    attrs = get_word_embed()

    # pack one 300-d word embedding per class index into a single matrix
    attr_mat = torch.zeros((164, 300))
    for key in cls_labels:
        a = attrs[key]
        a = list(map(float, a))  # map() is lazy in Python 3; materialize it
        attr_mat[cls_labels[key]] = torch.Tensor(a)

    train_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        # transforms.Resize(225),
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
    ])

    test_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        # transforms.Resize(225),
        transforms.Resize(128),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_loader = DataLoader(TrainSet(train_set,
                                       attrs=attrs,
                                       labels=cls_labels,
                                       root=root_train,
                                       transform=train_transforms),
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(ValSet(val_set,
                                   root=root_test,
                                   transform=test_transforms),
                            batch_size=batch_size // 2,  # DataLoader needs an int; / gives a float in Python 3
                            shuffle=False,
                            pin_memory=True)
    print('get train data loader')
    return train_loader, val_loader, val_cls_list, val_attrs, attr_mat
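
The attr_mat loop in this example packs one 300-d embedding per class index into a single tensor. A self-contained sketch of the same pattern with toy dimensions (the names and values are made up for illustration):

import torch

cls_labels = {'cat': 0, 'dog': 1}                       # class name -> row index
attrs = {'cat': ['0.1', '0.2'], 'dog': ['0.3', '0.4']}  # class name -> embedding strings

attr_mat = torch.zeros((len(cls_labels), 2))
for key in cls_labels:
    # list() is required because map() returns a lazy iterator in Python 3
    attr_mat[cls_labels[key]] = torch.Tensor(list(map(float, attrs[key])))
print(attr_mat)  # rows ordered by class index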
Example #3
def get_data(data_dir, batch_size):

    root_train = osp.join(data_dir, 'train')
    attrs_file = open(osp.join(data_dir, 'DatasetA/attrs.json'))
    cls_file = open(osp.join(data_dir, 'DatasetA/cls.json'))
    data_file = open(osp.join(data_dir, 'meta.json'))
    attrs = json.load(attrs_file)
    cls_labels = json.load(cls_file)
    data_set = json.load(data_file)
    train_set = data_set['images']
    print('train imgs', len(train_set))

    root_test = osp.join(data_dir, 'DatasetA/test')
    val_file = open(osp.join(data_dir, 'DatasetA/submit.txt'), 'r')
    val_set = val_file.readlines()
    val_cls_list, val_attrs = get_val_attrs()

    train_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        # transforms.Resize(225),
        transforms.Resize(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
    ])

    test_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        # transforms.Resize(225),
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_loader = DataLoader(TrainSet(train_set,
                                       attrs=attrs,
                                       labels=cls_labels,
                                       root=root_train,
                                       transform=train_transforms),
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(ValSet(val_set,
                                   root=root_test,
                                   transform=test_transforms),
                            batch_size=batch_size // 2,  # DataLoader needs an int; / gives a float in Python 3
                            shuffle=False,
                            pin_memory=True)
    print('get train data loader')
    return train_loader, val_loader, val_cls_list, val_attrs
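
As in Example #2, the validation loader halves the batch size. DataLoader requires an integer batch_size, so under Python 3 the halving must use floor division, as in the corrected code above. A quick illustration with a dummy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(10, 3))
loader = DataLoader(ds, batch_size=8 // 2)  # 8 / 2 would give the float 4.0 and raise
print(len(loader))  # 3 batches: 4 + 4 + 2 samples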
Example #4
def main(args):
    data = Dataset(task_type='FrameQA',
                   data_dir='/home/stage/yuan/tgif-qa/code/dataset/tgif/')
    hdf5 = '/home/stage/yuan/tgif-qa/code/dataset/tgif/features/TGIF_RESNET_pool5.hdf5'
    (data_csv, test_df, ans2idx, idx2ans, word2idx, idx2word,
     word_matrix) = data.get_train_data()
    data_all = data_csv.values
    np.random.shuffle(data_all)
    # hold out the first 7392 shuffled rows for validation
    val_data = data_all[:7392]
    train_data = data_all[7392:]
    val_len = val_data.shape[0]

    train_set = TrainSet(train_data, hdf5, data.word_matrix, data.word2idx,
                         data.ans2idx)
    val_set = TrainSet(val_data, hdf5, data.word_matrix, data.word2idx,
                       data.ans2idx)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=False,  # rows were already shuffled once above
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True)

    # run the test pass once before any training
    test(args, data, test_df, hdf5)
    model = TgifModel()
    print(model)
    model = nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    best = 0
    for epoch in range(0, args.epochs):
        model.train()
        losses = AverageMeter()
        corrects = AverageMeter()
        for i, d in enumerate(train_loader):
            video_features, question_embeds, ql, ans_labels = d
            imgs = Variable(video_features.cuda())
            # question_embeds, ql = pack_paded_questions(question_embeds, ql)
            question_embeds = torch.stack(question_embeds, 1).cuda()
            questions_embed = Variable(question_embeds)
            ans_labels = Variable(ans_labels.cuda())
            ans_scores = model(imgs, questions_embed, ql)
            _, preds = torch.max(ans_scores, 1)
            loss = criterion(ans_scores, ans_labels)

            losses.update(loss.data[0], ans_labels.size(0))  # loss.item() in PyTorch >= 0.4
            corrects.update(torch.sum((preds == ans_labels).data))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.print_freq == 0:
                print(
                    'Epoch: [{}][{}/{}]\t Loss {:.6f} ({:.6f})\t acc {} ({})\t'
                    .format(epoch, i + 1, len(train_loader), losses.val,
                            losses.avg, corrects.val, corrects.avg))
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': 0,
            },
            False,
            fpath=osp.join(args.model_dir, 'checkpoint.pth.tar'))
        acc = valid(args, val_loader)
        print('valid acc {:.6f}'.format(int(acc) / int(val_len)))
        if acc > best:
            best = acc
            save_checkpoint(
                {
                    'state_dict': model.module.state_dict(),
                    'epoch': epoch + 1,
                    'best_top1': 0,
                },
                False,
                fpath=osp.join(args.model_dir, 'checkpoint_best.pth.tar'))
            print('save model best at ep {}'.format(epoch))

    test(args, data, test_df, hdf5)
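
All four examples depend on helpers that are not shown here (AverageMeter, save_checkpoint, load_checkpoint, get_val_attrs, get_word_embed). For AverageMeter, a common minimal implementation that matches how it is used above is the pattern from the official PyTorch ImageNet example; the original helper may differ:

class AverageMeter:
    """Track the latest value and a running average."""

    def __init__(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val        # most recent value
        self.sum += val * n   # weighted by batch size n
        self.count += n
        self.avg = self.sum / self.count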