Example #1
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
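# N.b. torch.autograd.Variable has been deprecated since PyTorch 0.4; plain
# tensors now carry autograd state, so this import is kept only for
# backward compatibility with the original code.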
from torch.autograd import Variable
from torch_geometric.data import Data, Batch

from . import get_vocab
from . import create_seq2seq_gnn_net

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
from rsmlkit.logging import set_default_level

set_default_level(logging.INFO)
# logger = get_logger(__file__)

class Seq2seqParser:
    """Model interface for seq2seq parser"""

    def __init__(self, opt):
        """Initialize a Seq2seq model by either new initialization
        for pre-training or loading a pre-trained checkpoint for
        fine-tuning using reinforce
        """
        self.opt = opt              # See class TrainOptions for details
        self.vocab = get_vocab(opt)

        if opt.load_checkpoint_path is not None:
            self.load_checkpoint(opt)
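
A minimal usage sketch (hypothetical: any `opt` fields beyond `load_checkpoint_path` depend on TrainOptions and on what get_vocab expects):

from argparse import Namespace

# Hypothetical options object; a real run would use TrainOptions().parse().
opt = Namespace(load_checkpoint_path=None)  # None -> fresh init for pre-training
model = Seq2seqParser(opt)                  # vocab is loaded via get_vocab(opt)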
Example #2
import os

import h5py
import numpy as np

import clevr_parser

# N.b. project-local names used below (ClevrData, logger, set_default_level,
# get_output_filename, get_questions_and_parsed_scenes,
# get_question_fam_to_indices, save_graph_pairdata) are assumed to be
# imported at module level in the original script.


def main(args):
    """
    Save nx.graph (Gss, Gts, ...) and the corresponding torch_geometric.data.PairData
    (via the clevr_parser embedder API).
    """
    if args.is_debug:
        set_default_level(10)  # 10 == logging.DEBUG
    is_directed_graph = args.is_directed_graph
    logger.debug(f"Parser flag is_directed_graph = {is_directed_graph}")
    graph_parser = clevr_parser.Parser(
        backend="spacy",
        model='en_core_web_sm',
        has_spatial=True,
        has_matching=True).get_backend(identifier='spacy')
    embedder = clevr_parser.Embedder(
        backend='torch', parser=graph_parser).get_backend(identifier='torch')
    raw_questions, img_scenes = get_questions_and_parsed_scenes(
        args.input_questions_json, args.input_parsed_img_scenes_json)
    logger.info('| importing questions from %s' % args.input_question_h5)
    input_questions = h5py.File(args.input_question_h5, 'r')
    #N = len(input_questions['questions'])

    # Baseline Entities #
    questions, programs, answers, question_families, orig_idxs, img_idxs = [], [], [], [], [], []

    # Graphs and Embeddings #
    data_s_list = []  # List [torch_geometric.data.Data]
    data_t_list = []  # List [torch_geometric.data.Data]

    filename = get_output_filename(args)
    __all_question_families: np.ndarray = input_questions['question_families'][()]
    __all_enc_questions: np.ndarray = input_questions['questions'][()]
    __all_img_indices: np.ndarray = input_questions['image_idxs'][()]
    logger.debug(f"__all_question_families len {len(__all_question_families)}")

    # Sample N items from each of the M (= 90) question families #
    fam2indices = get_question_fam_to_indices(args)
    M = len(fam2indices)  # number of question families, e.g. 90
    N = args.n_questions_per_family  # e.g. 50
    max_sample = N * M  # e.g. 90 * 50 = 4500
    family_count = np.zeros(M)
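    # Assumed structure of fam2indices (illustrative only):
    #   {0: array([ 12,  87, 453, ...]),  # question indices in family 0
    #    1: array([  3,  91, 101, ...]),
    #    ...}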

    # TODO: accumulating values here need to be parallelized, and joined write ex-post
    num_skipped = 0  # Counter for tracking num of samples skipped
    for fam_idx, all_fam_samples in fam2indices.items():
        logger.debug(
            f"Question family {fam_idx}: choosing {N} of {len(all_fam_samples)} available samples"
        )
        N_question_sample_indices = np.random.choice(
            all_fam_samples, N, replace=False)  # N.b. seed is fixed
        assert len(N_question_sample_indices) == N
        # TODO: parallelize this iteration loop
        for i in N_question_sample_indices:
            img_idx = __all_img_indices[i]  # resolved before the try so the except can log it
            logger.debug(
                f"\tProcessing Image - {img_idx} from fam_idx {fam_idx} (question idx {i})"
            )
            try:
                img_scene = list(
                    filter(lambda x: x['image_index'] == img_idx,
                           img_scenes))[0]
            except IndexError as ie:
                logger.warning(f"No parsed scene found for image {img_idx}: {ie}")
                num_skipped += 1
                continue
            try:
                Gt, t_doc = graph_parser.get_doc_from_img_scene(
                    img_scene, is_directed_graph=is_directed_graph)
                X_t, ei_t, e_attr_t = embedder.embed_t(
                    img_idx, args.input_parsed_img_scenes_json)
            except AssertionError as ae:
                logger.warning(f"AssertionError Encountered: {ae}")
                logger.warning(
                    f"[{img_idx}] Excluding images with > 10 objects")
                num_skipped += 1
                continue
            if Gt is None and ("SKIP" in t_doc):
                # If the derendering pipeline failed, then just skip the
                # scene, don't process the labels (and text_scenes) for the image
                logger.warning(f"Got None img_doc at image_index: {img_idx}")
                print(f"Skipping all text_scenes for imgage idx: {img_idx}")
                num_skipped += 1
                continue
            q_idx = input_questions['orig_idxs'][i]
            q_obj = list(
                filter(lambda x: x['question_index'] == q_idx,
                       raw_questions))[0]
            assert q_obj['image_index'] == img_idx
            s = q_obj['question']
            try:
                Gs, s_doc = graph_parser.parse(
                    s, return_doc=True, is_directed_graph=is_directed_graph)
                X_s, ei_s, e_attr_s = embedder.embed_s(s)
            except ValueError as ve:
                logger.warning(f"ValueError Encountered: {ve}")
                logger.warning(f"Skipping question: {s} for {img_fn}")
                num_skipped += 1
                continue
            if Gs is None and ("SKIP" in s_doc):
                logger.warning(
                    "Got None as Gs and 'SKIP' in Gs_embd (likely plural with CLEVR_OBJS label)"
                )
                logger.warning(f"SKIPPING processing {s} at image {img_idx}")
                num_skipped += 1
                continue

            data_s = ClevrData(x=X_s, edge_index=ei_s, edge_attr=e_attr_s)
            data_t = ClevrData(x=X_t, edge_index=ei_t, edge_attr=e_attr_t)
            data_s_list.append(data_s)
            data_t_list.append(data_t)

            family_count[fam_idx] += 1
            questions.append(input_questions['questions'][i])
            programs.append(input_questions['programs'][i])
            answers.append(input_questions['answers'][i])
            question_families.append(input_questions['question_families'][i])
            orig_idxs.append(input_questions['orig_idxs'][i])
            img_idxs.append(img_idx)

            logger.info(f"\nCount = {family_count.sum()}\n")

        if family_count.sum() >= max_sample:
            break

    logger.debug(
        f"Total samples skipped (due to errors/exceptions) = {num_skipped}")
    # ---------------------------------------------------------------------------#
    ## SAVE .H5
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)  # makedirs: also creates intermediate dirs
    output_file = os.path.join(args.output_dir, filename)
    out_dir = args.output_dir
    out_f_prefix = os.path.splitext(filename)[0]  # robust to dots in the name
    out_fpp = f"{out_dir}/{out_f_prefix}"
    logger.debug(f"out_fpp = {out_fpp}")

    print('sampled question family distribution')
    print(family_count)
    print('| saving output file to %s' % output_file)
    with h5py.File(output_file, 'w') as f:
        f.create_dataset('questions',
                         data=np.asarray(questions, dtype=np.int32))
        f.create_dataset('programs', data=np.asarray(programs, dtype=np.int32))
        f.create_dataset('answers', data=np.asarray(answers))
        f.create_dataset('image_idxs', data=np.asarray(img_idxs))
        f.create_dataset('orig_idxs', data=np.asarray(orig_idxs))
        f.create_dataset('question_families',
                         data=np.asarray(question_families))

    ## ------------  SAVE GRAPH DATA ------------ ##
    save_graph_pairdata(out_fpp,
                        data_s_list,
                        data_t_list,
                        is_directed_graph=is_directed_graph)
    logger.info(f"Saved Graph Data in: {out_fpp}_*.[h5|.gpickle|.npz|.pt] ")
    print('| done')
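
For reference, PyTorch Geometric's documented PairData pattern shows how the source/target Data pairs collected above can be batched together. A minimal self-contained sketch follows; the attribute names x_s/edge_index_s etc. come from the PyG docs and are not necessarily this repo's ClevrData layout:

import torch
from torch_geometric.data import Data, Batch

class PairData(Data):
    """One source graph (question, Gs) plus one target graph (scene, Gt)."""
    def __inc__(self, key, value, *args, **kwargs):
        # When batching, offset each edge_index by its own graph's node count.
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value, *args, **kwargs)

# Toy tensors standing in for the embedder outputs (X_s, ei_s / X_t, ei_t).
x_s, edge_index_s = torch.randn(5, 16), torch.tensor([[0, 1, 2], [1, 2, 3]])
x_t, edge_index_t = torch.randn(10, 16), torch.tensor([[0, 1], [1, 2]])

pair = PairData(x_s=x_s, edge_index_s=edge_index_s,
                x_t=x_t, edge_index_t=edge_index_t)
batch = Batch.from_data_list([pair, pair], follow_batch=['x_s', 'x_t'])
print(batch.edge_index_s)  # second pair's edges are offset by x_s.size(0) == 5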
Example #3
def main(args):
    """
    Save nx.graph (Gss, Gts, ...) and the corresponding torch_geometric.data.PairData
    (via the clevr_parser embedder API).

    N.b. module-level imports (clevr_parser, ClevrData, preprocess_utils,
    utils, save_h5, program_to_str, logger, ...) are assumed from the
    surrounding script, as in Example #2.
    """
    if (args.input_vocab_json == '') and (args.output_vocab_json == ''):
        logger.error(
            'Must give one of --input_vocab_json or --output_vocab_json')
        return
    graph_parser = clevr_parser.Parser(
        backend='spacy',
        model=args.parser_lm,
        has_spatial=True,
        has_matching=True).get_backend(identifier='spacy')
    embedder = clevr_parser.Embedder(
        backend='torch', parser=graph_parser).get_backend(identifier='torch')
    is_directed_graph = args.is_directed_graph  # Parse graphs as nx.MultiDiGraph

    out_dir, out_f_prefix = _get_out_dir_and_file_prefix(args)
    checkpoint_dir = f"{out_dir}/checkpoints"
    utils.mkdirs(checkpoint_dir)

    questions, img_scenes = get_questions_and_parsed_scenes(
        args.input_questions_json, args.input_parsed_img_scenes_json)
    if args.is_debug:
        set_default_level(10)  # 10 == logging.DEBUG
        questions = questions[:128]  # default BSZ is 64; keep enough for batch iteration
        logger.debug(
            f"In DEBUG mode, sampling {len(questions)} questions only ...")
    # Process Vocab #
    vocab = _process_vocab(args, questions)

    # Encode all questions and programs
    logger.info('Encoding data')
    questions_encoded, programs_encoded, answers, image_idxs = [], [], [], []
    question_families = []
    orig_idxs = []

    # Graphs and Embeddings #
    data_s_list = []  # List [torch_geometric.data.Data]
    data_t_list = []  # List [torch_geometric.data.Data]
    num_samples = 0  # Counter for keeping track of processed samples
    num_skipped = 0  # Counter for tracking num of samples skipped
    for q in questions:  # enumerate index was unused; orig_idx comes from the question itself
        # First see if Gs, Gt can be extracted; if not (e.g., edge cases
        # like plurality), skip the data sample.
        img_idx = q['image_index']
        img_fn = q['image_filename']
        logger.debug(f"\tProcessing Image - {img_idx}: {img_fn} ...")
        # q_idx = q['question_index']
        # q_fam_idx = q['question_family_index']
        ## 1: Ensure both Gs,Gt is parseable for this question sample, o.w. skip
        img_scene = list(
            filter(lambda x: x['image_index'] == img_idx, img_scenes))[0]
        try:
            Gt, t_doc = graph_parser.get_doc_from_img_scene(
                img_scene, is_directed_graph=is_directed_graph)
            X_t, ei_t, e_attr_t = embedder.embed_t(
                img_idx, args.input_parsed_img_scenes_json)
        except AssertionError as ae:
            logger.warning(f"AssertionError Encountered: {ae}")
            logger.warning(f"[{img_fn}] Excluding images with > 10 objects")
            num_skipped += 1
            continue
        if Gt is None and ("SKIP" in t_doc):
            # If the derendering pipeline failed, then just skip the
            # scene, don't process the labels (and text_scenes) for the image
            print(f"Got None img_doc at image_index: {img_idx}")
            print(f"Skipping all text_scenes for imgage idx: {img_idx}")
            num_skipped += 1
            continue
        s = q['question']
        orig_idx = q['question_index']
        try:
            Gs, s_doc = graph_parser.parse(s,
                                           return_doc=True,
                                           is_directed_graph=is_directed_graph)
            X_s, ei_s, e_attr_s = embedder.embed_s(s)
        except ValueError as ve:
            logger.warning(f"ValueError Encountered: {ve}")
            logger.warning(f"Skipping question: {s} for {img_fn}")
            num_skipped += 1
            continue
        if Gs is None and ("SKIP" in s_doc):
            logger.warning(
                "Got None as Gs and 'SKIP' in Gs_embd (likely plural with CLEVR_OBJS label)"
            )
            logger.warning(
                f"SKIPPING processing {s} for {img_fn} at image {img_idx}")
            num_skipped += 1
            continue

        # Using ClevrData gives us a debug-friendly extension of torch_geometric's Data
        data_s = ClevrData(x=X_s, edge_index=ei_s, edge_attr=e_attr_s)
        data_t = ClevrData(x=X_t, edge_index=ei_t, edge_attr=e_attr_t)
        data_s_list.append(data_s)
        data_t_list.append(data_t)

        question = q['question']
        orig_idxs.append(orig_idx)
        image_idxs.append(img_idx)
        if 'question_family_index' in q:
            question_families.append(q['question_family_index'])
        question_tokens = preprocess_utils.tokenize(question,
                                                    punct_to_keep=[';', ','],
                                                    punct_to_remove=['?', '.'])
        question_encoded = preprocess_utils.encode(
            question_tokens,
            vocab['question_token_to_idx'],
            allow_unk=args.encode_unk == 1)
        questions_encoded.append(question_encoded)
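        # Illustrative flow (tokens/ids are hypothetical and vocab-dependent):
        #   "How many red cubes are there?"
        #     -> ['how', 'many', 'red', 'cubes', 'are', 'there']  # tokenize: '?' removed
        #     -> [34, 29, 41, 17, 8, 22]                          # encode via question_token_to_idx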

        has_prog_seq = 'program' in q
        if has_prog_seq:
            program = q['program']
            program_str = program_to_str(program, args.mode)
            program_tokens = preprocess_utils.tokenize(program_str)
            program_encoded = preprocess_utils.encode(
                program_tokens, vocab['program_token_to_idx'])
            programs_encoded.append(program_encoded)

        if 'answer' in q:
            ans = q['answer']
            answers.append(vocab['answer_token_to_idx'][ans])

        num_samples += 1
        logger.info("-" * 50)
        logger.info(f"Samples processed count = {num_samples}")
        if has_prog_seq:
            logger.info(f"\n[{orig_idx}]: question: {question} \n"
                        f"\tprog_str: {program_str} \n"
                        f"\tanswer: {ans}")
        logger.info("-" * 50)

        # ---- CHECKPOINT ---- #
        if num_samples % args.checkpoint_every == 0:
            logger.info(f"Checkpointing at {num_samples}")
            checkpoint_fn_prefix = f"{out_f_prefix}_{num_samples}"
            _out_dir = f"{checkpoint_dir}/{out_f_prefix}_{num_samples}"
            utils.mkdirs(_out_dir)
            out_fpp = f"{_out_dir}/{checkpoint_fn_prefix}"
            # ------------ Checkpoint .H5 ------------#
            logger.info(
                f"CHECKPOINT: Saving checkpoint files at directory: {out_fpp}")
            save_h5(f"{out_fpp}.h5", vocab, questions_encoded, image_idxs,
                    orig_idxs, programs_encoded, question_families, answers)
            # ------------ Checkpoint GRAPH DATA ------------#
            save_graph_pairdata(out_fpp,
                                data_s_list,
                                data_t_list,
                                is_directed_graph=is_directed_graph)
            logger.info(f"-------------- CHECKPOINT: COMPLETED --------")

        if (args.max_sample > 0) and (num_samples >= args.max_sample):
            logger.info(f"len(questions_encoded = {len(questions_encoded)}")
            logger.info("args.max_sample reached: Completing ... ")
            break

    logger.debug(f"Total samples skipped = {num_skipped}")
    logger.debug(f"Total samples processed = {num_samples}")
    out_fpp = f"{out_dir}/{out_f_prefix}"
    ## SAVE .H5: Baseline {dataset}_h5.h5 file (q,p,ans,img_idx) as usual
    logger.info(f"Saving baseline (processed) data in: {out_fpp}.h5")
    save_h5(f"{out_fpp}.h5", vocab, questions_encoded, image_idxs, orig_idxs,
            programs_encoded, question_families, answers)
    ## ------------  SAVE GRAPH DATA ------------ ##
    ## N.b. Ensure the lengths of these lists are all equal
    save_graph_pairdata(out_fpp,
                        data_s_list,
                        data_t_list,
                        is_directed_graph=is_directed_graph)
    logger.info(f"Saved Graph Data in: {out_fpp}_*.[h5|.gpickle|.npz|.pt] ")
Example #4
import os, sys
import json
from typing import *

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
logging.getLogger('imported_module').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)

from rsmlkit.logging import get_logger, set_default_level
logger = get_logger(__file__)
set_default_level(logging.DEBUG)

_dir = os.getcwd()
if _dir not in sys.path:
    sys.path.insert(0, _dir)
    # sys.path.append(os.path.abspath("."))
from options.test_options import TestOptions
from datasets import get_dataloader
from executors import get_executor
from models.parser import Seq2seqParser
import utils.utils as utils

import wandb
import clevr_parser

graph_parser = clevr_parser.Parser(
    backend='spacy',
    model='en_core_web_sm',
    has_spatial=True,
    has_matching=True).get_backend(identifier='spacy')
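
A usage sketch for the parser configured above; the keyword arguments mirror graph_parser.parse as called in Examples #2 and #3, and the question text is illustrative:

# Parse one CLEVR-style question into a graph (Gs) and a spaCy doc.
question = "Is there a red cube behind the large sphere?"  # illustrative text
Gs, s_doc = graph_parser.parse(question,
                               return_doc=True,
                               is_directed_graph=True)
if Gs is not None:
    # Gs is an nx graph (per the docstrings above), so standard nx APIs apply.
    print(f"nodes={Gs.number_of_nodes()}, edges={Gs.number_of_edges()}")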