Example #1
def trainingProgressTest(testNo, appConfig, modelArgs, device):
    # Instantiate trainer object.
    modelArgs = deepcopy(modelArgs)
    appConfig = deepcopy(appConfig)
    modelArgs.spotlightThreshold = -1000000  # Pick everything.
    modelArgs.disable_batch_norm = True
    modelArgs.dropout_p = 0
    modelArgs.input_dropout_p = 0
    modelArgs.teacher_forcing_ratio = 0.5
    appConfig.epochs = 10
    trainer = SupervisedTrainer(appConfig, modelArgs, device)

    # Create test data.
    generatorArgs = {
        "node_count_range": (2, 6),
        "max_child_count": 4,
        "tag_gen_params": (50, (2, 5)),
        "attr_gen_params": (50, (2, 5)),
        "attr_value_gen_params": (50, (1, 7)),
        "attr_count_range": (0, 4),
        "text_len_range": (-10, 7),
        "tail_len_range": (-20, 10),
    }
    sampleCount = 5
    test_data = GeneratedXmlDataset((sampleCount, generatorArgs),
                                    fields=trainer.fields)

    # Run tests.
    trainer.train(test_data)
    # Drop into the debugger to inspect the trained model interactively.
    import pdb
    pdb.set_trace()
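These test functions share the signature (testNo, appConfig, modelArgs, device), so they can be dispatched from a common driver. A minimal sketch (hypothetical: the real suite presumably builds appConfig and modelArgs from the project's argument parser first):

import torch

def runSelectedTests(tests, appConfig, modelArgs):
    # Prefer the GPU when available, as the scripts below do.
    device = torch.device("cuda") if torch.cuda.is_available() else None
    for testNo, test in enumerate(tests):
        test(testNo, appConfig, modelArgs, device)

runSelectedTests([trainingProgressTest], appConfig, modelArgs)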
Example #2
def smallDataTest(testNo, appConfig, modelArgs, device):
    # Instantiate trainer object.
    trainer = SupervisedTrainer(deepcopy(appConfig), deepcopy(modelArgs),
                                device)

    # Create test data.
    generatorArgs = {
        "node_count_range": (1, 2),
        "max_child_count": 4,
        "tag_gen_params": (50, (0, 5)),
        "attr_gen_params": (50, (0, 5)),
        "attr_value_gen_params": (50, (0, 20)),
        "attr_count_range": (0, 3),
        "text_len_range": (1, 7),
        "tail_len_range": (-20, 10),
    }
    sampleCount = random.randint(1, 3)
    test_data = GeneratedXmlDataset((sampleCount, generatorArgs),
                                    fields=trainer.fields)

    # Run tests.
    trainer.train(test_data)
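Because sampleCount is drawn with random.randint, each run of smallDataTest sees a different dataset size. Seeding Python's RNG first makes a failing run reproducible (assuming the XML generator draws only from the random module):

import random

random.seed(0)  # Fix the RNG so the generated samples reproduce exactly.
smallDataTest(0, appConfig, modelArgs, device)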
Example #3
def hier2hierBatchTest(testNo, appConfig, modelArgs, device):
    # Instantiate trainer object.
    trainer = SupervisedTrainer(deepcopy(appConfig), deepcopy(modelArgs),
                                device)

    # Create test data.
    generatorArgs = {
        "node_count_range": (3, 10),
        "max_child_count": 4,
        "tag_gen_params": (50, (0, 5)),
        "attr_gen_params": (50, (0, 5)),
        "attr_value_gen_params": (50, (0, 20)),
        "attr_count_range": (0, 3),
        "text_len_range": (-4, 7),
        "tail_len_range": (-20, 10),
    }
    sampleCount = 5
    test_data = GeneratedXmlDataset((sampleCount, generatorArgs),
                                    fields=trainer.fields)

    # Run tests.
    batchDataUnitTest(trainer, test_data)
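batchDataUnitTest is defined elsewhere in the suite. For reference, the sketch below builds a single preprocessed batch from test_data by hand, mirroring the iterator usage in Example #7 (the semantics of the mode argument are project-specific):

batch_iterator = Hier2HierIterator(
    preprocess_batch=trainer.model.preprocess_batch,
    dataset=test_data, batch_size=len(test_data),
    sort=False, shuffle=False, sort_within_batch=False,
    device=device, repeat=False)
test_batch = next(batch_iterator.__iter__(mode=appConfig.mode))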
Example #4
os.makedirs(runFolder, exist_ok=True)
logging.basicConfig(
        filename=os.path.join(runFolder, "evaluation.log"),
        format=LOG_FORMAT,
        level=getattr(logging, appConfig.log_level.upper()))


# Log config info.
logging.info("Application Config: {0}".format(json.dumps(vars(appConfig), indent=2)))
logging.info("Unprocessed Model Arguments: {0}".format(json.dumps(modelArgs, indent=2)))

# Pick the device on which to run, preferring the GPU when available.
device = torch.device("cuda") if torch.cuda.is_available() else None

# Trainer object encapsulates the model and helps test/train/evaluate the model.
trainer = SupervisedTrainer(appConfig, modelArgs, device)
trainer.load()

# Load test dataset.
test_dataset = Hier2HierDataset(baseFolder=appConfig.test_path,
                                fields=trainer.fields,
                                selectPercent=appConfig.input_select_percent)

# Get model from the trainer and put it into eval mode.
h2hModel = trainer.model
h2hModel.eval()

# Batching test inputs.
test_batch_iterator = Hier2HierIterator(
    preprocess_batch=h2hModel.preprocess_batch,
    dataset=test_dataset, batch_size=appConfig.batch_size,
    sort=False, shuffle=True, sort_within_batch=False,
    device=device, repeat=False)
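With batching in place, a minimal evaluation loop might look as follows. This is a sketch: it assumes the model is callable on a preprocessed batch (as in Example #7) and that appConfig.mode selects evaluation-time batching:

with torch.no_grad():  # Gradients are not needed during evaluation.
    for test_batch in test_batch_iterator.__iter__(mode=appConfig.mode):
        outputs = h2hModel(test_batch)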
Example #5
# Set up logging.
LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
runFolder = appConfig.training_dir + appConfig.runFolder
os.makedirs(runFolder, exist_ok=True)
logging.basicConfig(filename=os.path.join(runFolder, "training.log"),
                    format=LOG_FORMAT,
                    level=getattr(logging, appConfig.log_level.upper()))

# Log config info.
logging.info("Application Config: {0}".format(
    json.dumps(vars(appConfig), indent=2)))
logging.info("Unprocessed Model Arguments: {0}".format(
    json.dumps(modelArgs, indent=2)))

# Pick the device on which to run, preferring the GPU when available.
device = torch.device("cuda") if torch.cuda.is_available() else None

# Trainer object encapsulates the model and helps test/train/evaluate it.
trainer = SupervisedTrainer(appConfig, modelArgs, device)

# Load training and dev dataset.
training_data = Hier2HierDataset(baseFolder=appConfig.train_path,
                                 fields=trainer.fields,
                                 selectPercent=appConfig.input_select_percent)
dev_data = Hier2HierDataset(baseFolder=appConfig.dev_path,
                            fields=trainer.fields,
                            selectPercent=appConfig.input_select_percent)

# Train the model.
trainer.train(training_data, dev_data=dev_data)
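If an earlier run left a checkpoint under runFolder, trainer.load() restores it, mirroring Example #4; calling it before trainer.train() resumes rather than restarts:

trainer.load()  # Restore model/optimizer state saved by a previous run.
trainer.train(training_data, dev_data=dev_data)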
Example #6
def knownSpotlightTest(testNo, appConfig, modelArgs, device):
    """
    In this test, we train over toy1(node.text reversal). We use one hot encoded attention which matches
    directly with source symbol. This should test if rest of the system is working when we use hard coded attention.
    """
    # Instantiate trainer object.
    modelArgs = deepcopy(modelArgs)
    appConfig = deepcopy(appConfig)
    appConfig.train_path = appConfig.inputs_root_dir + "toy1/dev/"
    appConfig.dev_path = appConfig.inputs_root_dir + "toy1/dev/"
    modelArgs.disable_batch_norm = True
    modelArgs.dropout_p = 0
    modelArgs.input_dropout_p = 0

    appConfig.epochs = 1000
    modelArgs.learning_rate = 0.005

    def spotlightByFormula(hier2hierBatch, sampleIndexLimit, outputIndex):
        """
        For the node.text reversal dataset, this function computes the spotlight to use 
        by finding the right input symbol to focus on.
        """
        retval = []
        # Text positions to NDTLP.
        isTail = False
        for tdol in range(sampleIndexLimit):
            targetOutput = hier2hierBatch.targetOutputsByTdol[tdol]
            targetOutputLength = int(
                hier2hierBatch.targetOutputLengthsByTdol[tdol])

            # Get ndfo position of the input node.
            toi = int(hier2hierBatch.tdol2Toi[tdol])
            rootNode = hier2hierBatch.inputs[toi].getroot()
            ndfo = hier2hierBatch.node2Ndfo[rootNode]

            # Compute gndtol of the input XML node.
            # Computed as follows: inputs[toi] -> getroot() -> node2Ndfo -> ndfo2Gni -> gni2Gndtol.
            nodeGni = hier2hierBatch.ndfo2Gni[ndfo]
            nodeGndtol = hier2hierBatch.gni2Gndtol[nodeGni]

            if outputIndex < 1 + len("<toyrev>"):
                # We are in the opening tag portion of the output.
                # Return gndtol of the XML node.
                spotIndex = nodeGndtol
            else:
                # Compute input and output char index within node.text, where the current output pointer is.
                outputStrIndex = outputIndex - len(
                    "<toyrev>") - 1  # Don't forget <sos>.
                outputStrLength = targetOutputLength - len("<toyrev>") - len(
                    "</toyrev>") - 2  # Don't forget <sos> and <eos>.
                inputStrIndex = outputStrLength - outputStrIndex - 1

                if inputStrIndex < 0:
                    # We are currently in the closing tag portion of the output.
                    # Return gndtol of the XML node.
                    spotIndex = nodeGndtol
                else:
                    # Get the raw char stored at input.
                    ch = hier2hierBatch.inputs[toi].getroot(
                    ).text[inputStrIndex]

                    # Make sure that input is encoded correctly.
                    ndtx2 = hier2hierBatch.ndfo2Ndtl2[isTail][ndfo]
                    ndtxp2 = hier2hierBatch.ndtxTuple2Ndtlp2[isTail][(
                        ndtx2, inputStrIndex)]
                    encodedInputSymbol = int(
                        hier2hierBatch.encodedTextByTtDLP.data[ndtxp2])
                    inputVocab = hier2hierBatch.torchBatch.dataset.fields[
                        "src"].vocabs.text
                    assert (ch == inputVocab.itos[encodedInputSymbol])

                    # Compute spot index from gndtol of output character position.
                    charGni = hier2hierBatch.ndttp2Gni[ndtxp2]
                    charGndtol = hier2hierBatch.gni2Gndtol[charGni]
                    spotIndex = charGndtol

                    # Make sure that target output is encoded correctly.
                    outputVocab = hier2hierBatch.torchBatch.dataset.fields[
                        "tgt"].vocab
                    assert (ch == outputVocab.itos[targetOutput[outputIndex]])

            retval.append(spotIndex)
        return torch.LongTensor(retval)

    # Build trainer with a model which will use formula for spotlight.
    trainer = SupervisedTrainer(appConfig, modelArgs, device,
                                spotlightByFormula)

    # Load training and dev dataset.
    training_data = Hier2HierDataset(
        baseFolder=appConfig.train_path,
        fields=trainer.fields,
        selectPercent=appConfig.input_select_percent)
    dev_data = Hier2HierDataset(baseFolder=appConfig.dev_path,
                                fields=trainer.fields,
                                selectPercent=appConfig.input_select_percent)

    # Train the model.
    trainer.train(training_data, dev_data=dev_data)

    # Pause for inspection, then train again with the spotlight formula
    # disabled, so the model's learned attention takes over.
    import pdb
    pdb.set_trace()
    appConfig.epochs = 2000
    trainer.model.outputDecoder.spotlightByFormula = None
    trainer.train(training_data, dev_data=dev_data)
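The index arithmetic inside spotlightByFormula can be verified in isolation. The self-contained check below treats "<sos>" and "<eos>" as single vocabulary symbols (which the +1/-2 offsets above imply) and confirms that inputStrIndex walks the input text backwards, matching the reversed output exactly:

text = "abc"
target = (["<sos>"] + list("<toyrev>") + list(text[::-1])
          + list("</toyrev>") + ["<eos>"])
outputStrLength = len(target) - len("<toyrev>") - len("</toyrev>") - 2
assert outputStrLength == len(text)
for outputIndex in range(1 + len("<toyrev>"),
                         1 + len("<toyrev>") + outputStrLength):
    outputStrIndex = outputIndex - len("<toyrev>") - 1  # Skip <sos> and tag.
    inputStrIndex = outputStrLength - outputStrIndex - 1  # Reversal.
    assert target[outputIndex] == text[inputStrIndex]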
Example #7
def noInteractionTest(testNo, appConfig, modelArgs, device):
    # Instantiate trainer object.
    modelArgs = deepcopy(modelArgs)
    appConfig = deepcopy(appConfig)
    modelArgs.spotlightThreshold = -1000000  # Pick everything.
    modelArgs.disable_batch_norm = True
    modelArgs.dropout_p = 0
    modelArgs.input_dropout_p = 0
    modelArgs.teacher_forcing_ratio = 0
    trainer = SupervisedTrainer(appConfig, modelArgs, device)

    # Create test data.
    generatorArgs = {
        "node_count_range": (5, 10),
        "max_child_count": 4,
        "tag_gen_params": (50, (2, 5)),
        "attr_gen_params": (50, (0, 4)),
        "attr_value_gen_params": (50, (1, 20)),
        "attr_len_range": (2, 5),
        "text_len_range": (-10, 7),
        "tail_len_range": (-20, 10),
    }
    sampleCount = 50
    test_data = GeneratedXmlDataset((sampleCount, generatorArgs),
                                    fields=trainer.fields)
    trainer.load(test_data)
    examples = test_data.examples

    exampleToWatch = examples[0]
    remainingExamples = examples[1:]
    resultsToWatch = []
    prevResult = None
    prevGraphNodeCount = None
    for trial in range(20):
        chosenExamples = random.sample(remainingExamples, 5)
        watchPosition = random.randint(0, len(chosenExamples))
        print("Selected watch position ", watchPosition)
        chosenExamples.insert(watchPosition, exampleToWatch)
        assert (chosenExamples[watchPosition] == exampleToWatch)
        test_data_section = GeneratedXmlDataset(chosenExamples,
                                                fields=trainer.fields)

        batch_iterator = Hier2HierIterator(
            preprocess_batch=trainer.model.preprocess_batch,
            dataset=test_data_section,
            batch_size=len(test_data_section),
            sort=False,
            shuffle=False,
            sort_within_batch=False,
            device=device,
            repeat=False,
        )
        batch_generator = batch_iterator.__iter__(mode=appConfig.mode)
        test_data_batch = list(batch_generator)[0]

        dataDebugStages = []
        dataDebugHook = createDataDebugHook(test_data_batch, dataDebugStages)
        trainer.model(test_data_batch, dataDebugHook=dataDebugHook)

        curGraphNodeCount = len([
            gni for gni, toi in enumerate(test_data_batch.gni2Toi)
            if toi == watchPosition
        ])
        if prevGraphNodeCount is not None:
            assert (curGraphNodeCount == prevGraphNodeCount)
        prevGraphNodeCount = curGraphNodeCount

        curResult = [(stageName, dataDebugStage[watchPosition])
                     for (stageName, dataDebugStage) in dataDebugStages]
        outputLenToWatch = test_data_batch.targetOutputLengthsByToi[
            watchPosition]

        # Remove computation stages which are not relevant for exampleToWatch at watchPosition.
        _curResult = []
        for result in curResult:
            atPos = result[0].find("@")
            if atPos >= 0:
                try:
                    charIndex = int(result[0][0:atPos])
                    if charIndex >= outputLenToWatch:
                        continue
                except ValueError:
                    pass
            _curResult.append(result)
        curResult = _curResult

        if prevResult is not None:
            assert (len(curResult) == len(prevResult))
            stageCount = len(curResult)
            for stage in range(stageCount):
                assert (
                    curResult[stage][1].shape == prevResult[stage][1].shape)
                diff = curResult[stage][1] - prevResult[stage][1]
                diffNorm = torch.norm(diff)
                print("Trial {0}. Stage {1}:{2}. Diff {3}".format(
                    trial,
                    stage,
                    (curResult[stage][0] if curResult[stage][0]
                     == prevResult[stage][0] else curResult[stage][0] + "/" +
                     prevResult[stage][0]),
                    diffNorm,
                ))
                assert (diffNorm < 1e-5)
        prevResult = curResult
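createDataDebugHook is defined elsewhere in the project. From its use above one can only infer the contract: each recorded stage is a (stageName, tensor) pair whose tensor is indexable by batch position. A hypothetical minimal implementation consistent with that contract (the real hook, and its call signature inside the model, may differ):

def createDataDebugHook(batch, stages):
    def dataDebugHook(tensor, stageName):
        # Detach and copy so recorded stages don't pin the autograd graph.
        stages.append((stageName, tensor.detach().clone()))
    return dataDebugHook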