Beispiel #1
0
    def __init__(self, config, vocab_size):
        """Build the symbolic cost, predictions and bricks of the model.

        The context is embedded, encoded by a bidirectional LSTM stack and
        then attended over by a sequence generator that is trained on the
        answer sequence.

        Parameters
        ----------
        config
            Configuration object.  Attributes read here: ``embed_size``,
            ``ctx_lstm_size``, ``ctx_skip_connections``,
            ``generator_lstm_size``, ``feedback_size``, ``batch_size``,
            ``w_noise``, ``dropout``, ``weights_init`` and ``biases_init``.
        vocab_size
            Vocabulary size; used for the embedding table, the readout
            dimension, the feedback lookup and the context bag.

        Attributes set
        --------------
        predictions
            Second output of ``generator.generate`` run for 7 steps.
        sgd_cost
            Cost taken from the graph after optional weight noise / dropout.
        monitor_vars, monitor_vars_valid
            ``[[cost_reg]]`` each (two independent lists).
        """
        def as_floatx(variable):
            # The masks are integer matrices; cast them to Theano's float type.
            return variable.astype(theano.config.floatX)

        # Symbolic integer inputs, each transposed right after creation.
        # NOTE(review): presumably (batch, time) -> (time, batch); confirm
        # against the data stream feeding this model.
        ctx_tokens = tensor.imatrix('context').dimshuffle(1, 0)
        ctx_mask = tensor.imatrix('context_mask').dimshuffle(1, 0)
        ans_tokens = tensor.imatrix('answer').dimshuffle(1, 0)
        ans_mask = tensor.imatrix('answer_mask').dimshuffle(1, 0)

        # Passed to the emitter further down.
        ctx_bag = to_bag(ctx_tokens, vocab_size)

        # Embedding table for the context tokens.
        # (Initialisation from 'embeddings/vocab_embeddings.txt' via
        # init_embedding_table + Constant is currently disabled.)
        embed = LookupTable(vocab_size, config.embed_size, name='embed')
        embed.weights_init = IsotropicGaussian(0.01)

        # Context encoder: bidirectional LSTM stack over the embeddings.
        ctx_embedded = embed.apply(ctx_tokens)
        lstm_bricks, hidden_states = make_bidir_lstm_stack(
            ctx_embedded, config.embed_size, as_floatx(ctx_mask),
            config.ctx_lstm_size, config.ctx_skip_connections, 'ctx')
        bricks = list(lstm_bricks)

        # With skip connections every hidden state is concatenated;
        # otherwise only the last two entries of the list are used.
        # The factor 2 accounts for the forward and backward directions.
        if config.ctx_skip_connections:
            encoder_parts = hidden_states
            cenc_dim = 2 * sum(config.ctx_lstm_size)
        else:
            encoder_parts = hidden_states[-2:]
            cenc_dim = 2 * config.ctx_lstm_size[-1]
        cenc = tensor.concatenate(encoder_parts, axis=2)
        cenc.name = 'cenc'

        # Attention-based sequence generator over the encoded context.
        transition = GatedRecurrent(
            activation=Tanh(),
            dim=config.generator_lstm_size,
            name="transition")
        attention = SequenceContentAttention(
            state_names=transition.apply.states,
            attended_dim=cenc_dim,
            match_dim=config.generator_lstm_size,
            name="attention")
        readout_sources = [
            transition.apply.states[0],
            attention.take_glimpses.outputs[0],
        ]
        emitter = MaskedSoftmaxEmitter(context_bag=ctx_bag, name='emitter')
        feedback = LookupFeedback(vocab_size, config.feedback_size)
        readout = Readout(
            readout_dim=vocab_size,
            source_names=readout_sources,
            emitter=emitter,
            feedback_brick=feedback,
            name="readout")
        generator = SequenceGenerator(
            readout=readout,
            transition=transition,
            attention=attention,
            name="generator")

        # Training cost on the answer sequence.
        cost = generator.cost(
            ans_tokens,
            as_floatx(ans_mask),
            attended=cenc,
            attended_mask=as_floatx(ctx_mask),
            name="cost")

        # Free-running generation; only the second output is kept.
        generated = generator.generate(
            n_steps=7,
            batch_size=config.batch_size,
            attended=cenc,
            attended_mask=as_floatx(ctx_mask),
            iterate=True)
        self.predictions = generated[1]

        # Regularisation: optional weight noise, then optional dropout on
        # the encoder hidden states.
        graph = ComputationGraph([cost])
        if config.w_noise > 0:
            weight_vars = VariableFilter(roles=[WEIGHT])(graph)
            graph = apply_noise(graph, weight_vars, config.w_noise)
        if config.dropout > 0:
            graph = apply_dropout(graph, hidden_states, config.dropout)
        (cost_reg,) = graph.outputs

        # When no regularisation was applied cost_reg is the same variable
        # as cost, so the second assignment decides its final name.
        cost.name = 'cost'
        cost_reg.name = 'cost_reg'

        self.sgd_cost = cost_reg
        self.monitor_vars = [[cost_reg]]
        self.monitor_vars_valid = [[cost_reg]]

        # Generator initialisation: push the defaults down to the children
        # first, then override the transition's weights before initialising.
        generator.weights_init = IsotropicGaussian(0.01)
        generator.biases_init = Constant(0)
        generator.push_allocation_config()
        generator.push_initialization_config()
        transition.weights_init = Orthogonal()
        generator.initialize()

        # Embedding and encoder bricks; the latter use the config's schemes.
        embed.initialize()
        for encoder_brick in bricks:
            encoder_brick.weights_init = config.weights_init
            encoder_brick.biases_init = config.biases_init
            encoder_brick.initialize()