# Exemplo n.º 1 (0)
def main(args):
    """Extract Mask R-CNN pool5/fc7 features for every annotation and write them to HDF5.

    Expects preprocessed data under cache/prepro/<dataset>_<splitBy>/ and
    per-image head features under cache/feats/<dataset>_<splitBy>/mrcn/.
    """
    dataset_splitBy = args.dataset + '_' + args.splitBy
    out_root = osp.join('cache/feats/', dataset_splitBy)
    if not osp.isdir(out_root):
        os.makedirs(out_root)

    # Choose the raw-image directory from the dataset family name.
    if 'coco' in dataset_splitBy or 'combined' in dataset_splitBy:
        IMAGE_DIR = 'data/images/mscoco/images/train2014'
    elif 'clef' in dataset_splitBy:
        IMAGE_DIR = 'data/images/saiapr_tc-12'
    elif 'sunspot' in dataset_splitBy:
        IMAGE_DIR = 'data/images/SUNRGBD'
    else:
        print('No image directory prepared for ', args.dataset)
        sys.exit(0)

    # Load the preprocessed dataset (json metadata + h5 tensors).
    data_json = osp.join('cache/prepro', dataset_splitBy, 'data.json')
    data_h5 = osp.join('cache/prepro', dataset_splitBy, 'data.h5')
    loader = Loader(data_json, data_h5)
    images = loader.images
    num_anns = len(loader.anns)
    # Sanity check: annotation ids listed per image account for every annotation.
    counted = 0
    for image in images:
        counted += len(image['ann_ids'])
    assert counted == num_anns

    # Mask R-CNN model used to pool per-annotation features.
    mrcn = inference_no_imdb.Inference(args)

    # Output file: cache/feats/<dataset_splitBy>/mrcn/<net>_<imdb>_<tag>_ann_feats.h5
    file_name = '%s_%s_%s_ann_feats.h5' % (args.net_name, args.imdb_name,
                                           args.tag)
    feats_h5 = osp.join('cache/feats', dataset_splitBy, 'mrcn', file_name)

    f = h5py.File(feats_h5, 'w')
    pool5_set = f.create_dataset('pool5', (num_anns, 1024), dtype=np.float32)
    fc7_set = f.create_dataset('fc7', (num_anns, 2048), dtype=np.float32)

    # Precomputed head (backbone) features live in a directory keyed by net/imdb/tag.
    feats_dir = '%s_%s_%s' % (args.net_name, args.imdb_name, args.tag)
    head_feats_dir = osp.join('cache/feats/', dataset_splitBy, 'mrcn',
                              feats_dir)
    for idx, image in enumerate(images):
        net_conv, im_info = image_to_head(head_feats_dir, image['image_id'])
        for ann_id in image['ann_ids']:
            ann = loader.Anns[ann_id]
            pool5_feat, fc7_feat = ann_to_pool5_fc7(mrcn, ann, net_conv, im_info)
            row = ann['h5_id']
            pool5_set[row] = pool5_feat.data.cpu().numpy()
            fc7_set[row] = fc7_feat.data.cpu().numpy()
        if idx % 20 == 0:
            print('%s/%s done.' % (idx + 1, len(images)))

    f.close()
    print('%s written.' % feats_h5)
# Exemplo n.º 2 (0)
def main():
    """Train SubjectParser to predict the referred object's category from
    short (<= 2 token) referring expressions in refcoco_unc.

    Fixes over the original:
    - the token-filtering list comprehension was a SyntaxError
      (`[t for t in xs if cond else y]` is invalid; the conditional
      expression must precede the `for`),
    - `torch.cuda.LongTensor` was given a bare int for the class label, and
      the loop-invariant integer label was overwritten by a Variable so the
      second sentence of each ref would crash,
    - the progress print referenced an undefined name `word`.
    """
    data_json = '/data/ryli/kcli/refer_seg/MAttNet/cache/prepro/refcoco_unc/data.json'
    data_h5 = '/data/ryli/kcli/refer_seg/MAttNet/cache/prepro/refcoco_unc/data.h5'
    loader = Loader(data_json, data_h5)

    parser = SubjectParser(len(loader.word_to_ix), 512, 512, 90)
    parser.cuda()

    learning_rate = 1e-4
    optimizer = torch.optim.Adam(parser.parameters(), lr=learning_rate)

    def lossFun(loader, optimizer, parser, input_label, class_label):
        """Run one optimization step; return (loss value, cls_error, class logits)."""
        parser.train()
        optimizer.zero_grad()

        class_pred, rnn_output, rnn_hidden = parser(input_label)

        cls_error = F.cross_entropy(class_pred, class_label)
        loss = cls_error

        loss.backward()
        optimizer.step()

        return loss.data[0], cls_error, class_pred

    sent_count = 0
    avg_accuracy = 0
    for ref_id in loader.Refs:
        ref = loader.Refs[ref_id]
        # category_id is 1-based; shift to a 0-based class index.
        category = loader.Anns[ref['ann_id']]['category_id'] - 1
        for sent_id in ref['sent_ids']:
            sent = loader.sentences[sent_id]
            if len(sent['tokens']) > 2:
                continue
            sent_count += 1
            # Map out-of-vocabulary tokens to '<UNK>' (fixed comprehension syntax).
            input_sent = [t if t in loader.word_to_ix else '<UNK>'
                          for t in sent['tokens']]
            input_label = [[loader.word_to_ix[word] for word in input_sent]]
            input_label = Variable(torch.cuda.LongTensor(input_label))
            # Wrap the int in a list so LongTensor gets a sequence; keep
            # `category` itself untouched for the next sentence.
            class_label = Variable(torch.cuda.LongTensor([category]))
            loss, cls_error, cls_pred = lossFun(loader, optimizer, parser,
                input_label, class_label)
            _, pred = torch.max(cls_pred, 1)
            # Exponential moving average of accuracy.
            if pred == class_label:
                avg_accuracy = avg_accuracy * 0.99 + 0.01
            else:
                avg_accuracy *= 0.999
            if sent_count % 100 == 0:
                sentence = ' '.join(sent['tokens'])
                print('Sentence %d: id(%d)' % (sent_count, sent_id))
                # Original printed undefined `word`; show the sentence instead.
                print('  %-12s: loss = %f, cls_error = %f, avg_accuracy = %.4f, lr = %.2E' %
                    (sentence, loss, cls_error, avg_accuracy, learning_rate))
# Exemplo n.º 3 (0)
def Loader(*args, **kwargs):
    """Construct the project Loader, deferring the import of
    loaders.loader until the first call."""
    from loaders.loader import Loader as _Loader
    return _Loader(*args, **kwargs)