def main(_):
    """Process ground-truth COCO val2014 captions that mention 'man'/'woman'.

    Image ids are read one per line from ``FLAGS.img_path``. For each listed
    id, the first ground-truth val2014 caption (lower-cased) whose
    space-separated tokens contain the target word is kept. The target word is
    ``'man'`` when ``FLAGS.img_path`` contains ``'_man'`` and ``'woman'`` when
    it contains ``'_woman'``. At most 500 images are selected.

    Every occurrence of ``'man'`` and then ``'woman'`` in each selected caption
    is passed once to ``model.process_image``. When ``FLAGS.save_path`` is
    non-empty, outputs go under ``<save_path>/<basename(model_name)>_gt/`` with
    a zero-padded running-index prefix.

    Args:
      _: Unused positional argument (presumably argv from the app runner).

    Raises:
      ValueError: If ``FLAGS.img_path`` contains neither ``'_man'`` nor
        ``'_woman'``.
      RuntimeError: If a selected caption, once encoded, contains neither the
        'man' nor the 'woman' word id.
    """
    import gc

    max_images = 500  # cap on the number of captioned images selected

    # The target word depends only on the flag: decide it once, and fail fast
    # (before building the graph) on an unexpected path.
    if '_man' in FLAGS.img_path:
        look_for = 'man'
    elif '_woman' in FLAGS.img_path:
        look_for = 'woman'
    else:
        raise ValueError(
            "FLAGS.img_path must contain '_man' or '_woman': %r" %
            (FLAGS.img_path,))

    # Build the inference graph.
    g = tf.Graph()
    with g.as_default():
        model = gradcam_wrapper.GradCamWrapper()
        restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                   FLAGS.checkpoint_path)

    save_path = osp.join(FLAGS.save_path,
                         osp.basename(FLAGS.model_name) + '_gt')
    if FLAGS.save_path != "" and not osp.isdir(save_path):
        os.makedirs(save_path)

    # Create the vocabulary.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
    man_id = vocab.word_to_id('man')
    woman_id = vocab.word_to_id('woman')

    # One image id per line; drop the empty string a trailing newline leaves.
    with open(FLAGS.img_path, 'r') as of:
        image_ids = of.read().split('\n')
    if image_ids[-1] == '':
        image_ids = image_ids[0:-1]
    wanted_ids = set(image_ids)  # O(1) membership tests in the scan below

    json_path = coco_dir + '/annotations/captions_val2014.json'
    with open(json_path, 'r') as jf:
        json_data = json.load(jf)

    # image_id -> first lower-cased caption that mentions `look_for`.
    json_dict = {}
    for entry in json_data['annotations']:
        image_id = entry['image_id']
        if str(image_id) not in wanted_ids:
            continue
        if image_id not in json_dict:
            caption = entry['caption'].lower()
            if look_for in caption.split(' '):
                json_dict[image_id] = caption
        if len(json_dict) == max_images:
            break

    image_ids = list(json_dict)

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        restore_fn(sess)

        global_index = 0  # running counter used as the output-file prefix
        for i, image_id in enumerate(image_ids):
            image_id = int(image_id)
            sys.stdout.write('\r%d/%d' % (i, len(image_ids)))
            filename = coco_dir + '/images/val2014/COCO_val2014_' + "%012d" % (
                image_id) + '.jpg'
            # Binary mode: JPEG bytes are not valid UTF-8 text.
            with tf.gfile.GFile(filename, "rb") as f:
                image = f.read()
            caption = json_dict[image_id]
            print(caption)
            if caption.endswith('.'):
                caption = caption[:-1]
            # Position 0 is the '<S>' start token, so a word at token position
            # `pos` is passed as word_index=pos - 1 (its index within the
            # caption proper; exact convention is defined by process_image).
            tokens = ['<S>'] + caption.split(' ')
            encoded_tokens = [vocab.word_to_id(w) for w in tokens]
            man_ids = [pos for pos, c in enumerate(encoded_tokens)
                       if c == man_id]
            woman_ids = [pos for pos, c in enumerate(encoded_tokens)
                         if c == woman_id]
            if not (man_ids or woman_ids):
                raise RuntimeError(
                    'Caption for image %d contains neither "man" nor "woman": '
                    '%r' % (image_id, caption))

            # 'man' occurrences first, then 'woman', one call per occurrence.
            for positions, word_id in ((man_ids, man_id),
                                       (woman_ids, woman_id)):
                for wid in positions:
                    if FLAGS.save_path != "":
                        save_path_pre = (save_path + '/' +
                                         "%06d" % global_index + '_')
                    else:
                        save_path_pre = ""
                    model.process_image(sess,
                                        image,
                                        encoded_tokens,
                                        filename,
                                        vocab,
                                        word_index=wid - 1,
                                        word_id=word_id,
                                        save_path=save_path_pre)
                    global_index += 1
            # Collect garbage after every image; presumably bounds memory
            # across many process_image calls — TODO confirm it is needed.
            gc.collect()
def main(_):
    """Process captions from a JSON file that mention 'man', 'woman' or 'person'.

    Captions come from ``FLAGS.json_path``, a JSON list of entries with
    ``'image_id'`` and ``'caption'`` keys. When an image id appears more than
    once, the first caption is kept and a warning is written to stderr.

    Image ids to visit are read one per line from ``FLAGS.img_path``. Ids with
    no caption are skipped before their image file is read. For every
    occurrence of ``'man'``, ``'woman'`` and ``'person'`` (in that order) in a
    caption, ``model.process_image`` is called once. When ``FLAGS.save_path``
    is non-empty, outputs go under ``<save_path>/<json basename minus its
    last 5 characters>/`` with a zero-padded running-index prefix.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    import gc

    # Build the inference graph.
    g = tf.Graph()
    with g.as_default():
        model = gradcam_wrapper.GradCamWrapper()
        restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                   FLAGS.checkpoint_path)

    # [0:-5] drops the last five characters of the basename, presumably the
    # '.json' extension.
    save_path = osp.join(FLAGS.save_path, osp.basename(FLAGS.json_path)[0:-5])
    if FLAGS.save_path != "" and not osp.isdir(save_path):
        os.makedirs(save_path)

    # Create the vocabulary.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
    man_id = vocab.word_to_id('man')
    woman_id = vocab.word_to_id('woman')
    person_id = vocab.word_to_id('person')

    with open(FLAGS.json_path, 'r') as jf:
        json_data = json.load(jf)
    json_dict = {}  # image_id -> first caption seen for it
    for entry in json_data:
        image_id = entry['image_id']
        if image_id in json_dict:
            # Report duplicates instead of dropping into a debugger.
            sys.stderr.write('Duplicate caption for image_id %s in %s; '
                             'keeping the first one.\n' %
                             (image_id, FLAGS.json_path))
            continue
        json_dict[image_id] = entry['caption']

    # One image id per line; drop the empty string a trailing newline leaves.
    with open(FLAGS.img_path, 'r') as of:
        image_ids = of.read().split('\n')
    if image_ids[-1] == '':
        image_ids = image_ids[0:-1]

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        restore_fn(sess)

        global_index = 0  # running counter used as the output-file prefix
        for i, image_id in enumerate(image_ids):
            image_id = int(image_id)
            sys.stdout.write('\r%d/%d' % (i, len(image_ids)))
            # Skip uncaptioned ids before touching (or requiring) the file.
            if image_id not in json_dict:
                continue
            filename = 'im2txt/data/mscoco/images/val2014/COCO_val2014_' + "%012d" % (
                image_id) + '.jpg'
            # Binary mode: JPEG bytes are not valid UTF-8 text.
            with tf.gfile.GFile(filename, "rb") as f:
                image = f.read()
            caption = json_dict[image_id]
            print(caption)
            if caption.endswith('.'):
                caption = caption[:-1]
            # Position 0 is the '<S>' start token, so a word at token position
            # `pos` is passed as word_index=pos - 1 (its index within the
            # caption proper; exact convention is defined by process_image).
            tokens = ['<S>'] + caption.split(' ')
            encoded_tokens = [vocab.word_to_id(w) for w in tokens]
            man_ids = [pos for pos, c in enumerate(encoded_tokens)
                       if c == man_id]
            woman_ids = [pos for pos, c in enumerate(encoded_tokens)
                         if c == woman_id]
            person_ids = [pos for pos, c in enumerate(encoded_tokens)
                          if c == person_id]
            if not (man_ids or woman_ids or person_ids):
                # nothing to do
                continue

            # 'man', then 'woman', then 'person'; one call per occurrence.
            for positions, word_id in ((man_ids, man_id),
                                       (woman_ids, woman_id),
                                       (person_ids, person_id)):
                for wid in positions:
                    if FLAGS.save_path != "":
                        save_path_pre = (save_path + '/' +
                                         "%06d" % global_index + '_')
                    else:
                        save_path_pre = ""
                    model.process_image(sess,
                                        image,
                                        encoded_tokens,
                                        filename,
                                        vocab,
                                        word_index=wid - 1,
                                        word_id=word_id,
                                        save_path=save_path_pre)
                    global_index += 1
            # Collect garbage after every processed image; presumably bounds
            # memory across many process_image calls — TODO confirm.
            gc.collect()
# Example #3
# (listing score: 0)
def main(_):
    """Process per-file captions that mention the configured male/female words.

    Captions come from ``FLAGS.json_path``, a JSON list of entries with
    ``'filename'`` and ``'caption'`` keys. Captions are lower-cased and keyed
    by ``str(filename)``, and the first caption per filename wins.

    Images are the files matched by the glob pattern ``FLAGS.img_path``,
    visited in sorted order so output numbering is reproducible. Files with no
    caption are reported on stderr and skipped. For every occurrence of
    ``FLAGS.male_word`` and then ``FLAGS.female_word`` in a caption,
    ``model.process_image`` is called once. When ``FLAGS.save_path`` is
    non-empty, outputs go under ``<save_path>/<json basename minus its last 5
    characters>/`` with a zero-padded running-index prefix.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    import gc
    import sys

    # Build the inference graph.
    g = tf.Graph()
    with g.as_default():
        model = gradcam_wrapper.GradCamWrapper()
        restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                   FLAGS.checkpoint_path)

    # [0:-5] drops the last five characters of the basename, presumably the
    # '.json' extension.
    save_path = osp.join(FLAGS.save_path, osp.basename(FLAGS.json_path)[0:-5])
    if FLAGS.save_path != "" and not osp.isdir(save_path):
        os.makedirs(save_path)

    # Create the vocabulary.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
    man_id = vocab.word_to_id(FLAGS.male_word)
    woman_id = vocab.word_to_id(FLAGS.female_word)

    with open(FLAGS.json_path, 'r') as jf:
        json_data = json.load(jf)
    json_dict = {}  # str(filename) -> first lower-cased caption
    for entry in json_data:
        # Normalise the key once so the duplicate check and the later lookup
        # (glob yields str paths) agree on the key type.
        file_id = str(entry['filename'])
        if file_id not in json_dict:
            json_dict[file_id] = entry['caption'].lower()

    # glob order is arbitrary; sort so global_index prefixes are stable.
    filenames = sorted(glob.glob(FLAGS.img_path))

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        restore_fn(sess)

        global_index = 0  # running counter used as the output-file prefix
        for filename in filenames:
            caption = json_dict.get(filename)
            if caption is None:
                sys.stderr.write('No caption for %s in %s; skipping.\n' %
                                 (filename, FLAGS.json_path))
                continue
            # Binary mode: image bytes are not valid UTF-8 text.
            with tf.gfile.GFile(filename, "rb") as f:
                image = f.read()
            print(caption)
            if caption.endswith('.'):
                caption = caption[:-1]
            # Position 0 is the '<S>' start token, so a word at token position
            # `pos` is passed as word_index=pos - 1 (its index within the
            # caption proper; exact convention is defined by process_image).
            tokens = ['<S>'] + caption.split(' ')
            encoded_tokens = [vocab.word_to_id(w) for w in tokens]
            man_ids = [pos for pos, c in enumerate(encoded_tokens)
                       if c == man_id]
            woman_ids = [pos for pos, c in enumerate(encoded_tokens)
                         if c == woman_id]
            if not (man_ids or woman_ids):
                # nothing to do
                continue

            # Male word first, then female word; one call per occurrence.
            # A neutral word could be supported by adding a
            # (positions, word_id) pair here.
            for positions, word_id in ((man_ids, man_id),
                                       (woman_ids, woman_id)):
                for wid in positions:
                    if FLAGS.save_path != "":
                        save_path_pre = (save_path + '/' +
                                         "%06d" % global_index + '_')
                    else:
                        save_path_pre = ""
                    model.process_image(sess,
                                        image,
                                        encoded_tokens,
                                        filename,
                                        vocab,
                                        word_index=wid - 1,
                                        word_id=word_id,
                                        save_path=save_path_pre)
                    global_index += 1
            # Collect garbage after every processed image; presumably bounds
            # memory across many process_image calls — TODO confirm.
            gc.collect()