Example #1

import logging
import pprint
from collections import Counter

import numpy
import theano
from theano import tensor

# AdaDelta is imported on the assumption that config["step_rule"] names it;
# eval() in main() resolves the step rule from this namespace.
from blocks.algorithms import (AdaDelta, CompositeRule, GradientDescent,
                               StepClipping)
from blocks.extensions import FinishAfter, Printing
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.graph import ComputationGraph, apply_dropout
from blocks.initialization import Constant, IsotropicGaussian, Orthogonal
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.select import Selector

# Project-local modules: the import paths below are assumptions; adjust them
# to this repository's layout.
import configurations
from model import PoemBlock
from stream import get_tr_stream, get_attr_rec
from checkpoint import CheckpointNMT, LoadNMT
from sampling import Sampler

logger = logging.getLogger(__name__)


def main():

    # Load the configuration
    config = configurations.get_config_cs2en()
    logger.info("Model options:\n{}".format(pprint.pformat(config)))
    tr_stream = get_tr_stream(**config)

    # Create Theano variables
    logger.info("Creating theano variables")

    # One (sentence, mask) input pair per poem line: lmatrix variables hold
    # int64 word indices, matrix variables hold the float masks.
    source_sentence0 = tensor.lmatrix("source0")
    source_sentence_mask0 = tensor.matrix("source0_mask")
    target_sentence0 = tensor.lmatrix("target0")
    target_sentence_mask0 = tensor.matrix("target0_mask")

    source_sentence1 = tensor.lmatrix("source1")
    source_sentence_mask1 = tensor.matrix("source1_mask")
    target_sentence1 = tensor.lmatrix("target1")
    target_sentence_mask1 = tensor.matrix("target1_mask")

    source_sentence2 = tensor.lmatrix("source2")
    source_sentence_mask2 = tensor.matrix("source2_mask")
    target_sentence2 = tensor.lmatrix("target2")
    target_sentence_mask2 = tensor.matrix("target2_mask")

    sampling_input0 = tensor.lmatrix("input0")
    sampling_input1 = tensor.lmatrix("input1")
    sampling_input2 = tensor.lmatrix("input2")

    sampling_hstates0 = tensor.fmatrix("hstates0")
    sampling_hstates1 = tensor.fmatrix("hstates1")
    sampling_hstates2 = tensor.fmatrix("hstates2")

    sampling_lastrep0 = tensor.tensor3("lastrep0")
    sampling_lastrep1 = tensor.tensor3("lastrep1")

    # Shared initial hidden state fed to the first block (one value per
    # encoder hidden unit).
    hstates = theano.shared(
        value=numpy.zeros((config["enc_nhids"],), dtype=theano.config.floatX),
        name="hstates",
    )

    # Get vocab
    sources = get_attr_rec(tr_stream, "data_stream")
    src_vocab = sources.data_streams[0].dataset.dictionary
    trg_vocab = sources.data_streams[1].dataset.dictionary

    # Construct model
    logger.info("Building PoemModel")

    block0 = PoemBlock(config=config, blockid="block0", name="poemblock0")
    block1 = PoemBlock(config=config, blockid="block1", name="poemblock1")
    block2 = PoemBlock(config=config, blockid="block2", name="poemblock2")
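
    # The blocks are chained: each cost() call receives the running hidden
    # state of the previous block (hstates -> hsta0 -> hsta1) together with
    # the representations computed so far (rep0, rep1), so every poem line is
    # conditioned on the lines before it.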

    cost0, hsta0, rep0 = block0.cost(
        source_sentence0,
        source_sentence_mask0,
        source_sentence_mask1,
        source_sentence_mask0,
        target_sentence0,
        target_sentence_mask0,
        hstates,
        lastrep0=None,
        lastrep1=None,
    )

    cost1, hsta1, rep1 = block1.cost(
        source_sentence1,
        source_sentence_mask0,
        source_sentence_mask1,
        source_sentence_mask1,
        target_sentence1,
        target_sentence_mask1,
        hsta0,
        lastrep0=rep0,
        lastrep1=None,
    )

    cost2, hsta2, rep2 = block2.cost(
        source_sentence2,
        source_sentence_mask0,
        source_sentence_mask1,
        source_sentence_mask2,
        target_sentence2,
        target_sentence_mask2,
        hsta1,
        lastrep0=rep0,
        lastrep1=rep1,
    )

    cost = cost0 + cost1 + cost2
    cost.name = "total_cost"

    logger.info("Creating computational graph")

    cg = ComputationGraph(cost)
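    # cg exposes every parameter and intermediate variable reachable from the
    # summed cost; it is used below for dropout, shape logging, and the
    # gradient computation.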

    # Initialize model
    logger.info("Initializing model")
    for block in (block0, block1, block2):
        block.set_initw(IsotropicGaussian(config["weight_scale"]))
        block.set_initb(Constant(0))
        block.push_initialization_config()
        block.set_specialinit(Orthogonal(), Orthogonal())
        block.initialize()

    # apply dropout for regularization
    if config["dropout"] < 1.0:
        # dropout is applied to the output of the maxout layer (as in the
        # GroundHog-style decoder)
        logger.info("Applying dropout")
        dropout_inputs = [x for x in cg.intermediary_variables if x.name == "maxout_apply_output"]
        cg = apply_dropout(cg, dropout_inputs, config["dropout"])

    # Print shapes

    shapes = [param.get_value().shape for param in cg.parameters]
    logger.info("Parameter shapes: ")
    for shape, count in Counter(shapes).most_common():
        logger.info("    {:15}: {}".format(shape, count))
    logger.info("Total number of parameters: {}".format(len(shapes)))

    # Print parameter names

    # Only block0's parameters are listed; block1 and block2 are structured
    # identically.
    param_dict = Selector(block0).get_parameters()
    logger.info("Parameter names: ")
    for name, value in param_dict.items():
        logger.info("    {!s:15}: {}".format(value.get_value().shape, name))
    logger.info("Total number of parameter arrays in block0: {}".format(len(param_dict)))

    # Set up training model
    logger.info("Building model")
    training_model = Model(cost)
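    # The Model instance wraps the cost graph; the main loop and the
    # checkpointing extensions access the parameters through it.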

    # Debugging aid (disabled): auxiliary variables can be fetched from the
    # graph by name to inspect intermediate shapes, e.g.
    #
    #   weights = next(va for va in cg.auxiliary_variables if va.name ==
    #                  "sequence_generator_block0_cost_matrix_weighted_averages")
    #   weightsize = weights.shape
    #   weightsize.name = "weightsize"
    #
    # The same pattern applies to
    # "sequence_generator_block0_cost_matrix_states" and
    # "poemblock0_cost_block0hstatesRepeat".

    # Set extensions
    logger.info("Initializing extensions")
    extensions = [
        FinishAfter(after_n_batches=config["finish_after"]),
        TrainingDataMonitoring([cost], after_batch=True),
        Printing(after_batch=True),
        CheckpointNMT(config["saveto"], every_n_batches=config["save_freq"]),
    ]

    # Set up training algorithm
    logger.info("Initializing training algorithm")
    algorithm = GradientDescent(
        cost=cost,
        parameters=cg.parameters,
        step_rule=CompositeRule([StepClipping(config["step_clipping"]), eval(config["step_rule"])()]),
    )

    # Reload model if necessary
    if config["reload"]:
        extensions.append(LoadNMT(config["saveto"]))

    # Add sampling

    if config["hook_samples"] >= 1:
        logger.info("Building sampler")

        generated0 = block0.mygenerate(sampling_input0, sampling_hstates0)
        search_model0 = Model(generated0)

        generated1 = block1.mygenerate(sampling_input1, sampling_hstates1, sampling_lastrep0)
        search_model1 = Model(generated1)

        generated2 = block2.mygenerate(sampling_input2, sampling_hstates2, sampling_lastrep0, sampling_lastrep1)
        search_model2 = Model(generated2)

        extensions.append(
            Sampler(
                config=config,
                model0=search_model0,
                model1=search_model1,
                model2=search_model2,
                data_stream=tr_stream,
                hook_samples=config["hook_samples"],
                every_n_batches=config["sampling_freq"],
                src_vocab_size=config["src_vocab_size"],
            )
        )

        logger.info("End of building sampler")

    # Initialize main loop
    logger.info("Initializing main loop")
    main_loop = MainLoop(model=training_model, algorithm=algorithm, data_stream=tr_stream, extensions=extensions)

    # Train!
    main_loop.run()
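

# Minimal driver sketch, assuming this module is run directly as a script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()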