示例#1
0
    def __init__(self,
                 path_to_data: str,
                 save_directory: str,
                 k: int,
                 split="train",
                 input_vocabulary_file="",
                 target_vocabulary_file="",
                 generate_vocabulary=False):
        """Load a gSCAN dataset from file and prepare its vocabularies.

        :param path_to_data: path of the dataset file; it must exist.
        :param save_directory: directory handed to the dataset loader and used
            as base directory for the vocabulary files.
        :param k: forwarded unchanged to
            ``GroundedScan.load_dataset_from_file``.
        :param split: name of the split; stored on the instance.
        :param input_vocabulary_file: file name of the input vocabulary,
            relative to ``save_directory`` (only read when not generating).
        :param target_vocabulary_file: file name of the target vocabulary,
            relative to ``save_directory`` (only read when not generating).
        :param generate_vocabulary: when true, fresh vocabularies are created
            and filled through ``self.read_vocabularies()``; otherwise both
            vocabularies are loaded from the files named above.
        """
        assert os.path.exists(path_to_data), \
            "Trying to read a gSCAN dataset from a non-existing file {}.".format(path_to_data)
        if not generate_vocabulary:
            # Both vocabulary files have to be present before anything is loaded.
            assert all(
                os.path.exists(os.path.join(save_directory, vocabulary_file))
                for vocabulary_file in (input_vocabulary_file, target_vocabulary_file)
            ), "Trying to load vocabularies from non-existing files."
        if split == "test" and generate_vocabulary:
            logger.warning("WARNING: generating a vocabulary from the test set.")

        self.dataset = GroundedScan.load_dataset_from_file(path_to_data,
                                                           save_directory=save_directory,
                                                           k=k)

        # Report the verb-adverb statistics of the train and dev splits, but
        # only when the dataset carries (non-empty) "adverb_1" statistics.
        statistics = self.dataset._data_statistics
        if statistics.get("adverb_1"):
            headers_per_split = (
                ("train", "Verb-adverb combinations in training set: "),
                ("dev", "Verb-adverb combinations in dev set: "),
            )
            for split_key, header in headers_per_split:
                logger.info(header)
                combinations = statistics[split_key]["verb_adverb_combinations"]
                for adverb, verb_counts in combinations.items():
                    logger.info("Verbs for adverb: {}".format(adverb))
                    for verb, occurrences in verb_counts.items():
                        logger.info("   {}: {} occurrences.".format(verb, occurrences))

        self.image_dimensions = self.dataset.situation_image_dimension
        self.image_channels = 3
        self.split = split
        self.directory = save_directory

        # Containers for the examples and their sequence lengths; empty for now.
        self._examples = np.array([])
        self._input_lengths = np.array([])
        self._target_lengths = np.array([])

        if not generate_vocabulary:
            logger.info("Loading vocabularies...")
            input_vocabulary_path = os.path.join(save_directory, input_vocabulary_file)
            self.input_vocabulary = Vocabulary.load(input_vocabulary_path)
            target_vocabulary_path = os.path.join(save_directory, target_vocabulary_file)
            self.target_vocabulary = Vocabulary.load(target_vocabulary_path)
            logger.info("Done loading vocabularies.")
        else:
            logger.info("Generating vocabularies...")
            self.input_vocabulary = Vocabulary()
            self.target_vocabulary = Vocabulary()
            self.read_vocabularies()
            logger.info("Done generating vocabularies.")
示例#2
0
    def __init__(self,
                 path_to_data: str,
                 save_directory: str,
                 split="train",
                 input_vocabulary_file="",
                 target_vocabulary_file="",
                 generate_vocabulary=False):
        """Load a gSCAN dataset from file and prepare its vocabularies.

        :param path_to_data: path of the dataset file; it must exist.
        :param save_directory: directory handed to the dataset loader and used
            as base directory for the vocabulary files.
        :param split: name of the split; stored on the instance.
        :param input_vocabulary_file: file name of the input vocabulary,
            relative to ``save_directory`` (only read when not generating).
        :param target_vocabulary_file: file name of the target vocabulary,
            relative to ``save_directory`` (only read when not generating).
        :param generate_vocabulary: when true, fresh vocabularies are created
            and filled through ``self.read_vocabularies()``; otherwise both
            vocabularies are loaded from the files named above.
        """
        assert os.path.exists(path_to_data), \
            "Trying to read a gSCAN dataset from a non-existing file {}.".format(path_to_data)
        if not generate_vocabulary:
            # Both vocabulary files have to be present before anything is loaded.
            assert all(
                os.path.exists(os.path.join(save_directory, vocabulary_file))
                for vocabulary_file in (input_vocabulary_file, target_vocabulary_file)
            ), "Trying to load vocabularies from non-existing files."
        if split == "test" and generate_vocabulary:
            logger.warning("WARNING: generating a vocabulary from the test set.")

        self.dataset = GroundedScan.load_dataset_from_file(path_to_data,
                                                           save_directory=save_directory)
        self.image_dimensions = self.dataset.situation_image_dimension
        self.image_channels = 3
        self.split = split
        self.directory = save_directory

        # Containers for the examples and their sequence lengths; empty for now.
        self._examples = np.array([])
        self._input_lengths = np.array([])
        self._target_lengths = np.array([])

        if not generate_vocabulary:
            logger.info("Loading vocabularies...")
            input_vocabulary_path = os.path.join(save_directory, input_vocabulary_file)
            self.input_vocabulary = Vocabulary.load(input_vocabulary_path)
            target_vocabulary_path = os.path.join(save_directory, target_vocabulary_file)
            self.target_vocabulary = Vocabulary.load(target_vocabulary_path)
            logger.info("Done loading vocabularies.")
        else:
            logger.info("Generating vocabularies...")
            self.input_vocabulary = Vocabulary()
            self.target_vocabulary = Vocabulary()
            self.read_vocabularies()
            logger.info("Done generating vocabularies.")
示例#3
0
def main():
    """Command-line entry point for generating and analysing Grounded SCAN data.

    Parses the command-line flags and dispatches on ``--mode``:

    * ``generate``: build a ``GroundedScan`` instance, generate the data
      pairs, save per-split statistics and the dataset itself.
    * ``execute_commands``: visualize the predictions of each prediction file.
    * ``position_analysis``: run a position analysis per prediction file and
      save the results in one workbook.
    * ``test``: run all tests.
    * ``error_analysis``: run an error analysis per prediction file.
    * ``write_data_statistics``: save the statistics of every possible split
      of a previously saved dataset.

    Raises:
        NotImplementedError: if ``--type_grammar=full`` is requested.
        AssertionError: if a mode that is checked up front is given a
            non-existing ``--load_dataset_from``, or if
            ``--split=target_lengths`` is used without
            ``--cut_off_target_length``.
        ValueError: for an unknown ``--mode``, or when
            ``--count_equivalent_examples`` is combined with a split type it
            does not support.
    """
    parser = argparse.ArgumentParser(description="Grounded SCAN")

    # General arguments.
    parser.add_argument(
        '--mode',
        type=str,
        default='execute_commands',
        help='Generate (mode=generate) data, run tests (mode=test), analyse the end positions of'
        ' predictions compared to ground truth end positions (mode=position_analysis),'
        ' analyse the errors in predictions (mode=error_analysis), write the statistics of a'
        ' saved dataset (mode=write_data_statistics)'
        ' or execute commands from a file (mode=execute_commands).')
    parser.add_argument('--load_dataset_from',
                        type=str,
                        default='',
                        help='Path to file with dataset.')
    parser.add_argument('--output_directory',
                        type=str,
                        default='output',
                        help='Path to a folder in which all outputs should be stored.')
    parser.add_argument('--predicted_commands_files',
                        type=str,
                        default='predict.json',
                        help='Path to a file with predictions.')
    parser.add_argument('--save_dataset_as',
                        type=str,
                        default='dataset.txt',
                        help='Filename to save dataset in.')
    parser.add_argument(
        "--count_equivalent_examples",
        dest="count_equivalent_examples",
        default=False,
        action="store_true",
        help="Whether or not to count the number of equivalent examples in the "
        "training and test set at the end of generation.")
    parser.add_argument(
        "--only_save_errors",
        dest="only_save_errors",
        default=False,
        action="store_true",
        help="If mode=execute_commands, whether to only save the errors.")
    parser.add_argument("--make_dev_set",
                        dest="make_dev_set",
                        default=False,
                        action="store_true")

    # Dataset arguments.
    parser.add_argument('--max_examples',
                        type=int,
                        default=None,
                        help="Max. number of examples to generate.")
    parser.add_argument(
        '--split',
        type=str,
        default='generalization',
        choices=['uniform', 'generalization', "target_lengths"])
    parser.add_argument('--k_shot_generalization',
                        type=int,
                        default=0,
                        help="Number of examples of a particular "
                        "split to add to the training set.")
    parser.add_argument('--num_resampling',
                        type=int,
                        default=10,
                        help='Number of time to resample a semantically '
                        'equivalent situation (which will likely result'
                        ' in different situations in terms of object '
                        'locations).')
    parser.add_argument('--visualize_per_template',
                        type=int,
                        default=0,
                        help='How many visualization to generate per '
                        'command template.')
    parser.add_argument('--visualize_per_split',
                        type=int,
                        default=0,
                        help='How many visualization to generate per '
                        'test split.')
    parser.add_argument(
        '--percentage_train',
        type=float,
        default=.7,
        help='Percentage of examples to put in the training set (rest is test set).')
    parser.add_argument(
        '--percentage_dev',
        type=float,
        default=.05,
        help='Percentage of examples to put in the development set.')
    parser.add_argument('--cut_off_target_length',
                        type=int,
                        default=None,
                        help="Examples of what target length to put"
                        " in the test set for "
                        "--split=target_lengths")

    # World arguments.
    parser.add_argument('--grid_size',
                        type=int,
                        default=6,
                        help='Number of rows (and columns) in the grid world.')
    parser.add_argument('--min_other_objects',
                        type=int,
                        default=0,
                        help='Minimum amount of objects to put in the grid '
                        'world.')  # TODO: being used?
    parser.add_argument('--max_objects',
                        type=int,
                        default=2,
                        help='Maximum amount of objects to put in the grid '
                        'world.')  # TODO: being used?
    parser.add_argument('--min_object_size',
                        type=int,
                        default=1,
                        help='Smallest object size.')  # TODO: remove these?
    parser.add_argument('--max_object_size',
                        type=int,
                        default=4,
                        help='Biggest object size.')  # TODO: remove these?
    parser.add_argument(
        '--other_objects_sample_percentage',
        type=float,
        default=.5,
        help='Percentage of possible objects distinct from the target to place in the world.')

    # Grammar and Vocabulary arguments
    parser.add_argument(
        '--type_grammar',
        type=str,
        default='adverb',
        choices=['simple_intrans', 'simple_trans', 'normal', 'adverb', 'full'])
    parser.add_argument('--intransitive_verbs',
                        type=str,
                        default='walk',
                        help='Comma-separated list of '
                        'intransitive verbs.')
    parser.add_argument('--transitive_verbs',
                        type=str,
                        default='pull,push',
                        help='Comma-separated list of '
                        'transitive verbs.')
    parser.add_argument(
        '--adverbs',
        type=str,
        default='cautiously,while spinning,hesitantly,while zigzagging',
        help='Comma-separated list of adverbs.')
    parser.add_argument('--nouns',
                        type=str,
                        default='square,cylinder,circle',
                        help='Comma-separated list of nouns.')
    parser.add_argument('--color_adjectives',
                        type=str,
                        default='red,green,yellow,blue',
                        help='Comma-separated list of '
                        'colors.')
    parser.add_argument('--size_adjectives',
                        type=str,
                        default='big,small',
                        help='Comma-separated list of sizes.')
    parser.add_argument(
        '--sample_vocabulary',
        type=str,
        default='default',
        choices=['default', 'sample'],
        help="Whether to specify own vocabulary or to sample a nonsensical one.")

    # Only relevant when --sample_vocabulary='sample'
    parser.add_argument('--num_intransitive_verbs',
                        type=int,
                        default=1,
                        help='number of intransitive verbs to sample.')
    parser.add_argument('--num_transitive_verbs',
                        type=int,
                        default=1,
                        help='number of transitive verbs to sample.')
    parser.add_argument('--num_adverbs',
                        type=int,
                        default=6,
                        help='number of adverbs to sample.')
    parser.add_argument('--num_nouns',
                        type=int,
                        default=3,
                        help='number of nouns to sample.')
    parser.add_argument('--num_color_adjectives',
                        type=int,
                        default=2,
                        help='number of color adjectives to sample.')
    parser.add_argument('--num_size_adjectives',
                        type=int,
                        default=2,
                        help='number of size adjectives to sample.')

    flags = vars(parser.parse_args())

    if flags["type_grammar"] == "full":
        raise NotImplementedError(
            "Full type grammar (with conjunctions) not implemented (yet).")

    if flags['mode'] == 'execute_commands' or flags['mode'] == 'error_analysis':
        assert os.path.exists(flags['load_dataset_from']), \
            "if mode={}, please specify data location in --load_dataset_from".format(flags['mode'])

    if flags["split"] == "target_lengths":
        assert flags[
            "cut_off_target_length"], "Specify --cut_off_target_length if --split=target_lengths."

    # Create directory for visualizations if it doesn't exist. makedirs also
    # creates missing parent directories, so a nested --output_directory works.
    if flags['output_directory']:
        visualization_path = os.path.join(os.getcwd(),
                                          flags['output_directory'])
        os.makedirs(visualization_path, exist_ok=True)

    if flags['mode'] == 'generate':
        if flags["sample_vocabulary"] == 'sample':
            # When sampling a vocabulary, only the number of words per
            # category is passed on instead of the words themselves.
            intransitive_verbs = flags["num_intransitive_verbs"]
            transitive_verbs = flags["num_transitive_verbs"]
            adverbs = flags["num_adverbs"]
            nouns = flags["num_nouns"]
            color_adjectives = flags["num_color_adjectives"]
            size_adjectives = flags["num_size_adjectives"]
        else:
            intransitive_verbs = flags["intransitive_verbs"].split(',')
            transitive_verbs = flags["transitive_verbs"].split(',')
            adverbs = flags["adverbs"].split(',')
            nouns = flags["nouns"].split(',')
            # Special case when no color or size adjectives specified.
            color_adjectives = flags["color_adjectives"].split(
                ',') if flags["color_adjectives"] else []
            size_adjectives = flags["size_adjectives"].split(
                ',') if flags["size_adjectives"] else []

        # Sample a vocabulary and a grammar with rules of form NT -> T and T -> {words from vocab}.
        grounded_scan = GroundedScan(
            intransitive_verbs=intransitive_verbs,
            transitive_verbs=transitive_verbs,
            adverbs=adverbs,
            nouns=nouns,
            color_adjectives=color_adjectives,
            size_adjectives=size_adjectives,
            min_object_size=flags["min_object_size"],
            max_object_size=flags["max_object_size"],
            percentage_train=flags["percentage_train"],
            percentage_dev=flags["percentage_dev"],
            sample_vocabulary=flags["sample_vocabulary"],
            save_directory=flags["output_directory"],
            grid_size=flags["grid_size"],
            type_grammar=flags["type_grammar"])

        # Generate all possible commands from the grammar and pair them with relevant situations.
        grounded_scan.get_data_pairs(
            max_examples=flags["max_examples"],
            num_resampling=flags['num_resampling'],
            other_objects_sample_percentage=flags[
                'other_objects_sample_percentage'],
            visualize_per_template=flags['visualize_per_template'],
            visualize_per_split=flags['visualize_per_split'],
            split_type=flags["split"],
            train_percentage=flags['percentage_train'],
            min_other_objects=flags['min_other_objects'],
            k_shot_generalization=flags['k_shot_generalization'],
            make_dev_set=flags["make_dev_set"],
            cut_off_target_length=flags["cut_off_target_length"])
        logger.info("Gathering dataset statistics...")
        grounded_scan.save_dataset_statistics(split="train")
        if flags["split"] == "uniform" or flags["split"] == "target_lengths":
            if flags["make_dev_set"]:
                grounded_scan.save_dataset_statistics(split="dev")
            grounded_scan.save_dataset_statistics(split="test")
            if flags["split"] == "target_lengths":
                grounded_scan.save_dataset_statistics(split="target_lengths")
        elif flags["split"] == "generalization":
            splits = [
                "test", "visual", "situational_1", "situational_2",
                "contextual", "adverb_1", "adverb_2", "visual_easier"
            ]
            if flags["make_dev_set"]:
                splits += ["dev"]
            for split in splits:
                grounded_scan.save_dataset_statistics(split=split)
        dataset_path = grounded_scan.save_dataset(flags['save_dataset_as'])
        grounded_scan.visualize_data_examples()
        logger.info("Saved dataset to {}".format(dataset_path))
        if flags['count_equivalent_examples']:
            # NOTE(review): --split=target_lengths is a valid choice but is
            # not supported here and ends in the ValueError below.
            if flags["split"] == "uniform":
                splits_to_count = ["test"]
            elif flags["split"] == "generalization":
                splits_to_count = [
                    "visual", "situational_1", "situational_2", "contextual"
                ]
            else:
                raise ValueError("Unknown option for flag --split: {}".format(
                    flags["split"]))
            for split in splits_to_count:
                logger.info(
                    "Equivalent examples in train and testset: {}".format(
                        grounded_scan.count_equivalent_examples(
                            "train", split)))
    elif flags['mode'] == 'execute_commands':
        files = flags["predicted_commands_files"].split(",")
        for file in files:
            logger.info(
                "Visualizing predictions from file: {}".format(file))
            grounded_scan = GroundedScan.load_dataset_from_file(
                flags["load_dataset_from"], flags["output_directory"])
            grounded_scan.visualize_prediction(
                os.path.join(flags["output_directory"], file),
                only_save_errors=flags["only_save_errors"])
            logger.info("Saved visualizations in directory: {}.".format(
                flags["output_directory"]))
    elif flags['mode'] == 'position_analysis':
        files = flags["predicted_commands_files"].split(",")
        # One workbook is shared by all prediction files and saved at the end.
        workbook = Workbook()
        for file in files:
            logger.info(
                "Performing position analysis on file with predictions: {}".
                format(file))
            grounded_scan = GroundedScan.load_dataset_from_file(
                flags["load_dataset_from"], flags["output_directory"])
            grounded_scan.position_analysis(
                os.path.join(flags["output_directory"], file),
                workbook=workbook)
            logger.info("Wrote position analysis for {}".format(file))
        outfile_excel = os.path.join(flags["output_directory"],
                                     "position_analysis.xls")
        workbook.save(outfile_excel)
        logger.info("Done.")
    elif flags['mode'] == 'test':
        logger.info("Running all tests..")
        run_all_tests()
    elif flags['mode'] == 'error_analysis':
        files = flags["predicted_commands_files"].split(",")
        for file in files:
            # Plots go into a sub-directory named after the prediction file.
            file_name = file.split(".json")[0]
            logger.info(
                "Performing error analysis on file with predictions: {}".
                format(file))
            grounded_scan = GroundedScan.load_dataset_from_file(
                flags["load_dataset_from"], flags["output_directory"])
            save_plots_in = os.path.join(flags["output_directory"], file_name)
            os.makedirs(save_plots_in, exist_ok=True)
            grounded_scan.error_analysis(
                predictions_file=os.path.join(flags["output_directory"], file),
                output_file=os.path.join(save_plots_in, "error_analysis.txt"),
                save_directory=save_plots_in)
            logger.info("Wrote data to path: {}.".format(
                os.path.join(save_plots_in, "error_analysis.txt")))
            logger.info("Saved plots in directory: {}.".format(save_plots_in))
    elif flags['mode'] == 'write_data_statistics':
        grounded_scan = GroundedScan.load_dataset_from_file(
            flags["load_dataset_from"], flags["output_directory"])
        logger.info("Writing statistics to {}".format(
            flags["output_directory"]))
        for split in grounded_scan._possible_splits:
            grounded_scan.save_dataset_statistics(split=split)
    else:
        raise ValueError(
            "Unknown value for command-line argument 'mode'={}.".format(
                flags['mode']))
def train(train_data_path: str,
          val_data_paths: dict,
          use_cuda: bool,
          original_dataset_path: str = "/root/multimodal_seq2seq_gSCAN/data/compositional_splits/dataset.txt"):
    """Compare a trained model against a baseline and write out selected predictions.

    Despite its name, this function performs no training: it loads two
    checkpoints, computes exact-match indicators for both models on every
    validation split, and writes predictions for two groups of examples per
    split:

    * ``model_good_bl_fail/<split>_predict.json``: baseline predictions on
      examples the model gets right and the baseline gets wrong.
    * ``model_bad/<split>_predict.json``: model predictions on examples the
      model gets wrong.

    NOTE(review): ``model_file`` and ``baseline_file`` are not defined in this
    function; they are presumably module-level names. Confirm they are set
    before this function is called.

    :param train_data_path: path of the training data; its vocabularies are
        reused for all validation splits.
    :param val_data_paths: mapping from split name to the path of that
        split's data.
    :param use_cuda: whether to move both models to the GPU.
    :param original_dataset_path: path of the gSCAN dataset file handed to
        ``predict_and_write``; defaults to the previously hard-coded location.
    :raises AssertionError: if either checkpoint file does not exist.
    """
    logger.info("Loading Training set...")
    train_iter, train_input_vocab, train_target_vocab = dataloader(
        train_data_path, batch_size=cfg.TRAIN.BATCH_SIZE, use_cuda=use_cuda)

    logger.info("Loading Dev. set...")
    val_iters = {}
    for split_name, path in val_data_paths.items():
        # Validation data is encoded with the training vocabularies and is
        # iterated without shuffling.
        val_iters[split_name], _, _ = dataloader(
            path,
            batch_size=cfg.VAL_BATCH_SIZE,
            use_cuda=use_cuda,
            input_vocab=train_input_vocab,
            target_vocab=train_target_vocab,
            random_shuffle=False)
    logger.info("Done Loading Dev. set.")

    pad_idx = train_target_vocab.stoi['<pad>']
    sos_idx = train_target_vocab.stoi['<sos>']
    eos_idx = train_target_vocab.stoi['<eos>']

    train_input_vocab_size = len(train_input_vocab.itos)
    train_target_vocab_size = len(train_target_vocab.itos)

    model = GSCAN_model(pad_idx,
                        eos_idx,
                        train_input_vocab_size,
                        train_target_vocab_size,
                        is_baseline=False)
    model = model.cuda() if use_cuda else model
    assert os.path.isfile(
        model_file), "No model checkpoint found at {}".format(model_file)
    logger.info(
        "Loading model checkpoint from file at '{}'".format(model_file))
    _ = model.load_model(model_file)

    baseline = GSCAN_model(pad_idx,
                           eos_idx,
                           train_input_vocab_size,
                           train_target_vocab_size,
                           is_baseline=True)
    baseline = baseline.cuda() if use_cuda else baseline
    assert os.path.isfile(
        baseline_file), "No baseline checkpoint found at {}".format(
            baseline_file)
    logger.info(
        "Loading model checkpoint from file at '{}'".format(baseline_file))
    _ = baseline.load_model(baseline_file)

    original_dataset = GroundedScan.load_dataset_from_file(
        original_dataset_path,
        save_directory="stat/",
        k=10)

    with torch.no_grad():
        # Both networks are evaluated, so both must be in eval mode; leaving
        # the baseline in train mode would keep train-only layers active.
        model.eval()
        baseline.eval()
        logger.info("Evaluating..")
        print(val_iters)
        for split_name, val_iter in val_iters.items():
            model_exact_match = exact_match_indicator(
                val_iter,
                model=model,
                max_decoding_steps=30,
                pad_idx=pad_idx,
                sos_idx=sos_idx,
                eos_idx=eos_idx,
                max_examples_to_evaluate=None)
            baseline_exact_match = exact_match_indicator(
                val_iter,
                model=baseline,
                max_decoding_steps=30,
                pad_idx=pad_idx,
                sos_idx=sos_idx,
                eos_idx=eos_idx,
                max_examples_to_evaluate=None)
            # Examples where the two models disagree AND the model is right,
            # i.e. the model succeeds while the baseline fails.
            model_diff = torch.bitwise_xor(model_exact_match,
                                           baseline_exact_match)
            model_better_exs = torch.bitwise_and(model_diff, model_exact_match)
            predict_and_write(val_iter,
                              baseline,
                              model_better_exs,
                              30,
                              input_vocab=train_input_vocab,
                              target_vocab=train_target_vocab,
                              out='model_good_bl_fail/' + split_name +
                              '_predict.json',
                              split_name=split_name,
                              original_dataset=original_dataset,
                              max_examples_to_output=200)
            # Examples the model itself gets wrong.
            predict_and_write(val_iter,
                              model,
                              ~model_exact_match,
                              30,
                              input_vocab=train_input_vocab,
                              target_vocab=train_target_vocab,
                              out='model_bad/' + split_name + '_predict.json',
                              split_name=split_name,
                              original_dataset=original_dataset,
                              max_examples_to_output=200)