Пример #1
0
def feature_detector_blk(max_depth=2):
    """Build the block that evaluates a single convolution patch.

    Input: node dict
    Output: TensorType([hyper.conv_dim, ]) (per the original author's notes)

    The patch elements come from ``collect_node_for_conv_patch_blk`` called
    with ``max_depth``. Each element is turned into tensors, weighted by
    ``weighted_feature_blk``, the results are summed, the ``Bconv`` parameter
    is added and ``tanh`` is applied.
    """
    patch_blk = td.Composition()
    with patch_blk.scope():
        patch_nodes = collect_node_for_conv_patch_blk(
            max_depth=max_depth).reads(patch_blk.input)

        # Turn each python tuple of the patch into tensors: one coded
        # feature followed by four scalars (five fields in total).
        elem_to_tensors = td.Record(
            (coding_blk(), td.Scalar(), td.Scalar(), td.Scalar(),
             td.Scalar()))
        patch_tensors = td.Map(elem_to_tensors).reads(patch_nodes)

        # One weighted feature per patch element.
        weighted_feats = td.Map(weighted_feature_blk()).reads(patch_tensors)

        # Element-wise sum over the whole patch.
        patch_sum = td.Reduce(td.Function(tf.add)).reads(weighted_feats)

        # Add the convolution bias.
        with_bias = td.Function(tf.add).reads(
            patch_sum, td.FromTensor(param.get('Bconv')))

        # Non-linearity.
        activated = td.Function(tf.nn.tanh).reads(with_bias)

        patch_blk.output.reads(activated)
    return patch_blk
Пример #2
0
def composed_embed_blk():
    """Return a block that embeds a node, dispatching on ``node['clen']``.

    Nodes with ``clen == 0`` are embedded by ``direct_embed_blk``. For every
    other node the children's direct embeddings are folded together with
    ``continous_weighted_add_blk``; the ``B`` parameter is then added, the
    result goes through ``clip_by_norm_blk`` and finally through relu or tanh
    depending on ``hyper.use_relu``.
    """
    leaf_blk = direct_embed_blk()
    inner_blk = td.Composition(name='composed_embed_nonleaf')
    with inner_blk.scope():
        kids = td.GetItem('children').reads(inner_blk.input)
        parent_clen = td.Scalar().reads(
            td.GetItem('clen').reads(inner_blk.input))
        kid_clens = td.Map(td.GetItem('clen') >> td.Scalar()).reads(kids)
        kid_feats = td.Map(direct_embed_blk()).reads(kids)

        # Start state of the fold: (zero vector of word_dim, zero scalar).
        fold_init = td.Composition()
        with fold_init.scope():
            fold_init.output.reads(
                td.FromTensor(tf.zeros(hyper.word_dim)),
                td.FromTensor(tf.zeros([])),
            )
        # Pair every child feature with its own clen and the parent's clen.
        fold_input = td.Zip().reads(kid_feats, kid_clens,
                                    td.Broadcast().reads(parent_clen))
        # Only the first element of the fold state is used afterwards.
        folded = td.Fold(continous_weighted_add_blk(),
                         fold_init).reads(fold_input)[0]
        with_bias = td.Function(tf.add, name='add_bias').reads(
            folded, td.FromTensor(param.get('B')))
        clipped = clip_by_norm_blk().reads(with_bias)

        if hyper.use_relu:
            act_fn = tf.nn.relu
        else:
            act_fn = tf.nn.tanh
        activated = td.Function(act_fn).reads(clipped)
        inner_blk.output.reads(activated)

    return td.OneOf(lambda node: node['clen'] == 0, {
        True: leaf_blk,
        False: inner_blk
    })
Пример #3
0
def build_train_graph_for_RVAE(rvae_block, look_behind_length=0):
    """Wrap ``rvae_block`` in a composition that reports training metrics.

    The input sequence is mapped to vectors whose size is obtained from
    ``get_size_of_input_vecotrs(rvae_block)`` and fed to ``rvae_block``.
    Item 0 of the network output is compared, via ``softmax_crossentropy``
    averaged with ``td.Mean``, against the input sequence sliced from
    ``look_behind_length`` onwards. Item 1 is passed to ``kl_divergence``.
    Both values are recorded as metrics (``cross_entropy_loss`` and
    ``kl_loss``); the composition itself outputs nothing (``td.Void``).
    """
    emb_size = get_size_of_input_vecotrs(rvae_block)

    graph = td.Composition()
    with graph.scope():
        padded_seq = td.Map(td.Vector(emb_size)).reads(graph.input)
        net_out = rvae_block.reads(padded_seq)

        token_scores = td.GetItem(0).reads(net_out)
        latent_params = td.GetItem(1).reads(net_out)

        # Targets: the input with the first look_behind_length items dropped.
        target_seq = td.Slice(start=look_behind_length).reads(padded_seq)
        # TODO: metric that output of rnn is the same as input sequence
        xent = td.ZipWith(td.Function(softmax_crossentropy)) >> td.Mean()
        xent.reads(token_scores, target_seq)
        kl = td.Function(kl_divergence)
        kl.reads(latent_params)

        td.Metric('cross_entropy_loss').reads(xent)
        td.Metric('kl_loss').reads(kl)

        graph.output.reads(td.Void())

    return graph
Пример #4
0
    def add_metrics(is_root, is_neutral):
        """Return a block that records loss/hit metrics and outputs the state.

        The block input is ``(labels, (logits, state))``. ``all_*`` metrics
        are always recorded; ``root_*`` metrics only when ``is_root`` is
        true; binary-hit metrics only when ``is_neutral`` is false.
        """
        blk_name = 'predict(is_root=%s, is_neutral=%s)' % (is_root, is_neutral)
        metrics_blk = td.Composition(name=blk_name)
        with metrics_blk.scope():
            # Unpack (labels, (logits, state)).
            labels = metrics_blk.input[0]
            logits = td.GetItem(0).reads(metrics_blk.input[1])
            state = td.GetItem(1).reads(metrics_blk.input[1])

            # Loss.
            node_loss = td.Function(tf_node_loss)
            td.Metric('all_loss').reads(node_loss.reads(logits, labels))
            if is_root:
                td.Metric('root_loss').reads(node_loss)

            # Fine-grained hits.
            fine_hits = td.Function(tf_fine_grained_hits)
            td.Metric('all_hits').reads(fine_hits.reads(logits, labels))
            if is_root:
                td.Metric('root_hits').reads(fine_hits)

            # Binary hits are skipped for neutral labels.
            if not is_neutral:
                bin_hits = td.Function(tf_binary_hits).reads(logits, labels)
                td.Metric('all_binary_hits').reads(bin_hits)
                if is_root:
                    td.Metric('root_binary_hits').reads(bin_hits)

            # The state is passed through unchanged as the block output.
            metrics_blk.output.reads(state)
        return metrics_blk
Пример #5
0
def add_metrics(is_root):
    """Return a block that records per-node metrics and outputs the state.

    The block input is ``(labels, (logits, state))``. Loss, logits,
    predictions and labels are always recorded under ``all_*`` metric names,
    and additionally under ``root_*`` names when ``is_root`` is true.
    """
    metrics_blk = td.Composition(name='predict(is_root=%s)' % (is_root))
    with metrics_blk.scope():
        labels = metrics_blk.input[0]
        logits = td.GetItem(0).reads(metrics_blk.input[1])
        state = td.GetItem(1).reads(metrics_blk.input[1])

        # Loss.
        node_loss = td.Function(tf_node_loss)
        td.Metric('all_loss').reads(node_loss.reads(logits, labels))
        if is_root:
            td.Metric('root_loss').reads(node_loss)

        # Logits as produced by tf_logits.
        logits_out = td.Function(tf_logits)
        td.Metric('all_logits').reads(logits_out.reads(logits))
        if is_root:
            td.Metric('root_logits').reads(logits_out)

        # Keep predictions and labels so they can be inspected later.
        predicted = td.Function(tf_pred)
        td.Metric('all_pred').reads(predicted.reads(logits))
        if is_root:
            td.Metric('root_pred').reads(predicted)
        gold = td.Function(tf_label)
        td.Metric('all_labels').reads(gold.reads(labels))
        if is_root:
            td.Metric('root_label').reads(gold)

        metrics_blk.output.reads(state)
    return metrics_blk
Пример #6
0
def tree_sum_blk(loss_blk):
    """Return a recursive block summing ``loss_blk()`` over a whole tree.

    ``loss_blk`` is a zero-argument factory returning the per-node loss
    block. The returned block adds the node's own loss to the sum of the
    recursive results over ``node['children']``.
    """
    # Forward declaration lets the block refer to itself for the children.
    recurse = td.ForwardDeclaration(td.PyObjectType(), td.TensorType([]))
    summer = td.Composition()
    with summer.scope():
        own_loss = loss_blk().reads(summer.input)
        kids = td.GetItem('children').reads(summer.input)

        kid_totals = td.Map(recurse()).reads(kids)
        kids_total = td.Reduce(td.Function(tf.add)).reads(kid_totals)
        total = td.Function(tf.add).reads(kids_total, own_loss)
        summer.output.reads(total)
    recurse.resolve_to(summer)
    return summer
Пример #7
0
    def set_metrics(self, train=True):
        """Return a block that records root loss and hits metrics.

        The block input is ``(labels, logits)`` and its output is the
        logits, passed through unchanged. ``train`` is accepted but not
        used by this method.
        """
        metrics_blk = td.Composition(name='predict')
        with metrics_blk.scope():
            # Unpack (labels, logits).
            labels = metrics_blk.input[0]
            logits = metrics_blk.input[1]

            # Loss.
            node_loss = td.Function(self.tf_node_loss)
            td.Metric('root_loss').reads(node_loss.reads(logits, labels))

            # Fine-grained hits.
            fine_hits = td.Function(self.tf_fine_grained_hits)
            td.Metric('root_hits').reads(fine_hits.reads(logits, labels))

            metrics_blk.output.reads(logits)
        return metrics_blk
Пример #8
0
def coding_blk():
    """Build the block that codes a node from its two embeddings.

    Input: node dict
    Output: TensorType([1, hyper.word_dim]) (per the original author's notes)

    The direct and the composed embedding of the node are each multiplied
    (``embedding.batch_mul``) by their own combination weight, ``Wcomb1``
    and ``Wcomb2`` respectively, and the two products are added.
    """
    # Fetch each combination weight exactly once; the tensors are wrapped
    # into blocks inside the composition below.
    w_direct = param.get('Wcomb1')
    w_composed = param.get('Wcomb2')

    blk = td.Composition()
    with blk.scope():
        direct = embedding.direct_embed_blk().reads(blk.input)
        composed = embedding.composed_embed_blk().reads(blk.input)
        w_direct_blk = td.FromTensor(w_direct)
        w_composed_blk = td.FromTensor(w_composed)

        direct = td.Function(embedding.batch_mul).reads(direct, w_direct_blk)
        composed = td.Function(embedding.batch_mul).reads(composed,
                                                          w_composed_blk)

        added = td.Function(tf.add).reads(direct, composed)
        blk.output.reads(added)
    return blk
Пример #9
0
def dynamic_pooling_blk():
    """Build the recursive max-pooling block over a tree.

    Input: root node dict
    Output: pooled, TensorType([hyper.conv_dim, ])

    Nodes with ``clen == 0`` yield their ``feature_detector_blk`` output
    directly. Other nodes yield the element-wise maximum of their own
    feature and the pooled results of all their children.
    """
    leaf_blk = feature_detector_blk()

    in_type = td.PyObjectType()
    out_type = td.TensorType([hyper.conv_dim, ])
    # Forward declaration so that children can be pooled recursively.
    recurse = td.ForwardDeclaration(in_type, out_type)

    inner = td.Composition()
    with inner.scope():
        own_feature = feature_detector_blk().reads(inner.input)
        kids = td.GetItem('children').reads(inner.input)

        kid_pooled = td.Map(recurse()).reads(kids)
        kids_max = td.Reduce(td.Function(tf.maximum)).reads(kid_pooled)
        pooled = td.Function(tf.maximum).reads(kids_max, own_feature)
        inner.output.reads(pooled)

    dispatch = td.OneOf(lambda x: x['clen'] == 0,
                        {True: leaf_blk, False: inner})
    recurse.resolve_to(dispatch)
    return dispatch
Пример #10
0
def l2loss_blk():
    """Return a block giving the per-node loss between both embeddings.

    Nodes with ``clen == 0`` produce the constant ``1.``. All other nodes
    produce ``batch_nn_l2loss`` applied to their direct and composed
    embeddings.
    """
    # TODO: rewrite using metric
    constant_blk = td.Composition()
    with constant_blk.scope():
        constant_blk.output.reads(td.FromTensor(tf.constant(1.)))

    compare_blk = td.Composition()
    with compare_blk.scope():
        direct_emb = direct_embed_blk().reads(compare_blk.input)
        composed_emb = composed_embed_blk().reads(compare_blk.input)
        diff_loss = td.Function(batch_nn_l2loss).reads(direct_emb,
                                                       composed_emb)
        compare_blk.output.reads(diff_loss)

    return td.OneOf(lambda node: node['clen'] != 0, {
        False: constant_blk,
        True: compare_blk
    })
Пример #11
0
def continous_weighted_add_blk():
    """Return the step block of the weighted-sum fold over children.

    Input: ``((last, idx), (cur_fea, cur_clen, pclen))``, i.e. the running
    state followed by the current element.
    Output: ``(last + batch_mul(cur_fea, Wi), idx + 1)`` where ``Wi`` comes
    from ``linear_combine_blk`` applied to ``(cur_clen, pclen, idx)``.
    """
    step = td.Composition(name='continous_weighted_add')
    with step.scope():
        state = td.GetItem(0).reads(step.input)
        elem = td.GetItem(1).reads(step.input)

        # Running state: accumulated feature and position counter.
        acc = td.GetItem(0).reads(state)
        position = td.GetItem(1).reads(state)

        # Current element: feature, its clen and the parent's clen.
        elem_fea = td.GetItem(0).reads(elem)
        elem_clen = td.GetItem(1).reads(elem)
        parent_clen = td.GetItem(2).reads(elem)

        weight = linear_combine_blk().reads(elem_clen, parent_clen, position)

        contribution = td.Function(batch_mul).reads(elem_fea, weight)

        new_acc = td.Function(tf.add, name='add_last_weighted_fea').reads(
            acc, contribution)
        # XXX: rewrite using tf.range
        new_position = td.Function(tf.add, name='add_idx_1').reads(
            position, td.FromTensor(tf.constant(1.)))
        step.output.reads(new_acc, new_position)
    return step
Пример #12
0
def weighted_feature_blk():
    """Return a block weighting one patch feature by its position.

    Input: (feature                       , idx   , pclen,  depth,  max_depth)
           (TensorType([hyper.word_dim, ]), Scalar, Scalar, Scalar, Scalar)
    Output: weighted_feature
            TensorType([hyper.conv_dim, ])

    The weight is produced by ``tri_combined_blk`` from the four scalars and
    applied to the feature with ``embedding.batch_mul``.
    """
    weigher = td.Composition()
    with weigher.scope():
        feature = weigher.input[0]
        weight = tri_combined_blk().reads(weigher.input[1], weigher.input[2],
                                          weigher.input[3], weigher.input[4])

        result = td.Function(embedding.batch_mul).reads(feature, weight)

        weigher.output.reads(result)
    return weigher
Пример #13
0
def bidirectional_dynamic_CONV(fw_cell, bw_cell, out_features=64):
    """Return a bidirectional RNN composition built from two cells.

    Input: ``(sequence, labels)``. Every label is recorded under the
    ``labels`` metric. The sequence is run through ``fw_cell`` as given and
    through ``bw_cell`` reversed; the backward outputs are reversed again so
    both directions line up. For every step the two outputs are
    concatenated, reshaped to ``[-1, vsize * out_features]``, passed through
    a single-unit linear layer and recorded under the ``logits`` metric.
    """
    net = td.Composition()
    with net.scope():
        seq_fwd = td.Identity().reads(net.input[0])
        # Record the labels as metrics; the branch itself yields nothing.
        label_sink = (
            td.GetItem(1) >> td.Map(td.Metric("labels")) >> td.Void())
        label_sink.reads(net.input)
        seq_bwd = td.Slice(step=-1).reads(seq_fwd)

        fwd_out = (td.RNN(fw_cell) >> td.GetItem(0)).reads(seq_fwd)
        bwd_out = (td.RNN(bw_cell) >> td.GetItem(0)).reads(seq_bwd)
        # Put the backward outputs back into left-to-right order.
        bwd_aligned = td.Slice(step=-1).reads(bwd_out)

        flatten = td.Function(
            lambda x: tf.reshape(x, [-1, vsize * out_features]))
        to_logit = flatten >> td.FC(1, activation=None)

        per_step = td.ZipWith(
            td.Concat() >> to_logit >> td.Metric('logits'))
        merged = per_step.reads(fwd_out, bwd_aligned)

        net.output.reads(merged)
    return net
Пример #14
0
def linear_combine_blk():
    """Wrap ``linear_combine`` as a Fold function block.

    Output-type inference is switched off; the output type is declared
    explicitly as ``TensorType([hyper.word_dim, hyper.word_dim])``.
    """
    combine = td.Function(linear_combine, infer_output_type=False)
    declared_type = td.TensorType([hyper.word_dim, hyper.word_dim])
    combine.set_output_type(declared_type)
    return combine
Пример #15
0
def tri_combined_blk():
    """Wrap ``tri_combined`` as a Fold function block.

    Output-type inference is switched off; the output type is declared
    explicitly as ``TensorType([hyper.word_dim, hyper.conv_dim])``.
    """
    combine = td.Function(tri_combined, infer_output_type=False)
    declared_type = td.TensorType([hyper.word_dim, hyper.conv_dim])
    combine.set_output_type(declared_type)
    return combine
Пример #16
0
 def CNN_Window3(filters):
     """Return a three-input block applying ``cnn_operation`` with ``filters``.

     The three inputs are passed to ``cnn_operation`` as one list, in order.
     """
     return td.Function(
         lambda first, second, third: cnn_operation(
             [first, second, third], filters))
Пример #17
0
    def __init__(self,
                 image_data_batch,
                 image_mean,
                 text_seq_batch,
                 seq_length_batch,
                 T_decoder,
                 num_vocab_txt,
                 embed_dim_txt,
                 num_vocab_nmn,
                 embed_dim_nmn,
                 lstm_dim,
                 num_layers,
                 assembler,
                 encoder_dropout,
                 decoder_dropout,
                 decoder_sampling,
                 num_choices,
                 use_qpn,
                 qpn_dropout,
                 reduce_visfeat_dim=False,
                 new_visfeat_dim=128,
                 use_gt_layout=None,
                 gt_layout_batch=None,
                 map_dim=1024,
                 scope='neural_module_network',
                 reuse=None):
        """Build the whole model graph under the variable scope ``scope``.

        The constructor runs these stages in order:

        0. Scale and mean-centre the images, then run ``nlvr_convnet``.
        1. Build an ``AttentionSeq2Seq`` layout generator from the text.
        2. Define the module network as TensorFlow Fold blocks and compile
           it; its first output tensor becomes ``self.scores_nmn``.
        3. If ``use_qpn`` is true, add ``question_prior_net`` scores.
        4. Compute the entropy and L2 regularisation terms.

        ``reduce_visfeat_dim`` and ``new_visfeat_dim`` are accepted but are
        not used anywhere in this constructor.
        """

        with tf.variable_scope(scope, reuse=reuse):
            # Part 0: Visual feature from CNN
            with tf.variable_scope('image_feature_cnn'):
                # Scale by 1/255 first, then subtract the mean.
                image_data_batch = image_data_batch / 255.0 - image_mean
                image_feat_grid = nlvr_convnet(image_data_batch)
                self.image_feat_grid = image_feat_grid
            # Part 1: Seq2seq RNN to generate module layout tokens
            with tf.variable_scope('layout_generation'):
                att_seq2seq = AttentionSeq2Seq(
                    text_seq_batch, seq_length_batch, T_decoder, num_vocab_txt,
                    embed_dim_txt, num_vocab_nmn, embed_dim_nmn, lstm_dim,
                    num_layers, assembler, encoder_dropout, decoder_dropout,
                    decoder_sampling, use_gt_layout, gt_layout_batch)
                self.att_seq2seq = att_seq2seq
                predicted_tokens = att_seq2seq.predicted_tokens
                token_probs = att_seq2seq.token_probs
                word_vecs = att_seq2seq.word_vecs
                neg_entropy = att_seq2seq.neg_entropy
                self.atts = att_seq2seq.atts

                # Expose the generator outputs on the model object.
                self.predicted_tokens = predicted_tokens
                self.token_probs = token_probs
                self.word_vecs = word_vecs
                self.neg_entropy = neg_entropy

                # log probability of each generated sequence
                # (sum of token log-probabilities over axis 0)
                self.log_seq_prob = tf.reduce_sum(tf.log(token_probs), axis=0)

            # Part 2: Neural Module Network
            with tf.variable_scope('layout_execution'):
                modules = Modules(image_feat_grid, word_vecs, None,
                                  num_choices, map_dim)
                self.modules = modules
                # Recursion of modules
                # Attention shape: the feature grid's middle dimensions
                # (first and last dropped) followed by a single channel.
                att_shape = image_feat_grid.get_shape().as_list()[1:-1] + [1]
                # Forward declaration of module recursion
                att_expr_decl = td.ForwardDeclaration(td.PyObjectType(),
                                                      td.TensorType(att_shape))
                # _Find
                case_find = td.Record([('time_idx', td.Scalar(dtype='int32')),
                                       ('batch_idx', td.Scalar(dtype='int32'))
                                       ])
                case_find = case_find >> td.Function(modules.FindModule)
                # _Transform
                case_transform = td.Record([('input_0', att_expr_decl()),
                                            ('time_idx', td.Scalar('int32')),
                                            ('batch_idx', td.Scalar('int32'))])
                case_transform = case_transform >> td.Function(
                    modules.TransformModule)
                # _And
                case_and = td.Record([('input_0', att_expr_decl()),
                                      ('input_1', att_expr_decl()),
                                      ('time_idx', td.Scalar('int32')),
                                      ('batch_idx', td.Scalar('int32'))])
                case_and = case_and >> td.Function(modules.AndModule)
                # _Describe
                case_describe = td.Record([('input_0', att_expr_decl()),
                                           ('time_idx', td.Scalar('int32')),
                                           ('batch_idx', td.Scalar('int32'))])
                case_describe = case_describe >> \
                    td.Function(modules.DescribeModule)

                # Only these modules may appear as inputs of other modules;
                # dispatch is on the expression's 'module' entry.
                recursion_cases = td.OneOf(
                    td.GetItem('module'), {
                        '_Find': case_find,
                        '_Transform': case_transform,
                        '_And': case_and
                    })
                att_expr_decl.resolve_to(recursion_cases)

                # For invalid expressions, define a dummy answer
                # so that all answers have the same form
                dummy_scores = td.Void() >> td.FromTensor(
                    np.zeros(num_choices, np.float32))
                # Top-level dispatch: _Describe or the invalid-expression
                # placeholder.
                output_scores = td.OneOf(td.GetItem('module'), {
                    '_Describe': case_describe,
                    INVALID_EXPR: dummy_scores
                })

                # compile and get the output scores
                self.compiler = td.Compiler.create(output_scores)
                self.scores_nmn = self.compiler.output_tensors[0]

            # Add a question prior network if specified
            self.use_qpn = use_qpn
            self.qpn_dropout = qpn_dropout
            if use_qpn:
                self.scores_qpn = question_prior_net(
                    att_seq2seq.encoder_states, num_choices, qpn_dropout)
                # Final scores are the sum of module and prior scores.
                self.scores = self.scores_nmn + self.scores_qpn
                #self.scores = self.scores_nmn
            else:
                self.scores = self.scores_nmn

            # Regularization: Entropy + L2
            self.entropy_reg = tf.reduce_mean(neg_entropy)
            #tf.check_numerics(self.entropy_reg, 'entropy NaN/Inf ')
            #print(self.entropy_reg.eval())
            # L2 applies to trainable variables whose op name contains
            # `scope` (substring match) and ends with 'weights'.
            module_weights = [
                v for v in tf.trainable_variables()
                if (scope in v.op.name and v.op.name.endswith('weights'))
            ]
            self.l2_reg = tf.add_n([tf.nn.l2_loss(v) for v in module_weights])
Пример #18
0
        forward_dir = (td.RNN(fw_cell) >> td.GetItem(0)).reads(fw_seq)
        back_dir = (td.RNN(bw_cell) >> td.GetItem(0)).reads(bw_seq)
        back_to_leftright = td.Slice(step=-1).reads(back_dir)

        output_transform = td.FC(1, activation=None)

        bidir_common = (td.ZipWith(
            td.Concat() >> output_transform >> td.Metric('logits'))).reads(
                forward_dir, back_to_leftright)

        bidir_conv_lstm.output.reads(bidir_common)
    return bidir_conv_lstm


# Input record for the convolutional model: a sequence of vsize-vectors,
# each reshaped to [-1, vsize, 1], paired with a sequence of scalars.
CONV_data = td.Record((td.Map(
    td.Vector(vsize) >> td.Function(lambda x: tf.reshape(x, [-1, vsize, 1]))),
                       td.Map(td.Scalar())))
# Full convolutional pipeline; the model output is discarded (td.Void).
CONV_model = (CONV_data >> bidirectional_dynamic_CONV(
    multi_convLSTM_cell([vsize, vsize, vsize], [100, 100, 100]),
    multi_convLSTM_cell([vsize, vsize, vsize], [100, 100, 100])) >> td.Void())

# Input record for the fully connected model: a sequence of vsize-vectors
# paired with a sequence of scalars.
FC_data = td.Record((td.Map(td.Vector(vsize)), td.Map(td.Scalar())))
# Full fully-connected pipeline; the model output is discarded (td.Void).
FC_model = (FC_data >> bidirectional_dynamic_FC(multi_FC_cell(
    [1000] * 5), multi_FC_cell([1000] * 5), 1000) >> td.Void())

store = data(FLAGS.data_dir + FLAGS.data_type, FLAGS.truncate)

# Both models are constructed above regardless of the flag; this only picks
# which one is used.
# NOTE(review): `model` stays unbound for any other FLAGS.model value.
if FLAGS.model == "lstm":
    model = FC_model
elif FLAGS.model == "convlstm":
    model = CONV_model
Пример #19
0
def resampling_block(z_size):
    """Wrap ``resampling`` as a Fold block with explicit tensor types.

    The block takes a tensor of shape ``(2 * z_size,)`` and is declared to
    return a tensor of shape ``(z_size,)``.
    """
    sampler = td.Function(resampling, name='resampling')
    sampler.set_input_type(td.TensorType((2 * z_size,)))
    sampler.set_output_type(td.TensorType((z_size,)))
    return sampler
Пример #20
0
def direct_embed_blk():
    """Return a pipeline embedding a node by its ``'name'`` entry.

    The ``'name'`` item is read as an int32 scalar, looked up in the ``We``
    parameter with ``tf.nn.embedding_lookup`` and then passed through
    ``clip_by_norm_blk``.
    """
    name_id = td.GetItem('name') >> td.Scalar('int32')
    lookup = td.Function(
        lambda ids: tf.nn.embedding_lookup(param.get('We'), ids))
    return name_id >> lookup >> clip_by_norm_blk()
Пример #21
0
def clip_by_norm_blk(norm=1.0):
    """Return a block applying ``tf.clip_by_norm`` along axis 1.

    ``norm`` is the clipping threshold handed to ``tf.clip_by_norm``.
    """
    return td.Function(
        lambda tensor: tf.clip_by_norm(tensor, norm, axes=[1]))
Пример #22
0
def build_VAE(z_size, token_emb_size):
    """Assemble a sequence VAE as a single Fold composition.

    Input: a sequence of float32 vectors of length ``token_emb_size``.
    Output: ``(un-normalised token scores per step, mus_and_log_sigs)``.

    A GRU encoder with ``2 * z_size`` units reads the sequence; its last
    output is ``mus_and_log_sigs``. ``resampling`` maps that tensor to one of
    size ``z_size``, which is used as the initial state of a GRU decoder with
    ``z_size`` units. The decoder runs over a sequence of empty tensors as
    long as the input, and each of its outputs goes through a relu FC layer
    of width ``token_emb_size``.
    """
    vae = td.Composition()
    vae.set_input_type(
        td.SequenceType(td.TensorType([token_emb_size], 'float32')))
    with vae.scope():
        tokens = vae.input

        # Encoder. TODO: refactor this out
        enc_cell = td.ScopedLayer(
            tf.contrib.rnn.GRUCell(num_units=2 * z_size, activation=tf.tanh),
            'encoder')
        enc_outputs = td.RNN(enc_cell) >> td.GetItem(0)
        # Only the last encoder output is kept.
        mus_and_log_sigs = enc_outputs >> td.GetItem(-1)

        sampler = td.Function(resampling, name='resampling')
        sampler.set_input_type(td.TensorType((2 * z_size,)))
        sampler.set_output_type(td.TensorType((z_size,)))

        # A sequence as long as the input but holding empty tensors; the
        # decoder maps over it.
        placeholders = td.Map(td.Void() >> td.FromTensor(tf.zeros((0,))))

        # Decoder. TODO: refactor this out
        dec_cell = td.ScopedLayer(
            tf.contrib.rnn.GRUCell(num_units=z_size, activation=tf.tanh),
            'decoder')
        dec_outputs = td.RNN(
            dec_cell, initial_state_from_input=True) >> td.GetItem(0)

        projection = td.FC(
            token_emb_size,
            activation=tf.nn.relu,
            initializer=tf.contrib.layers.xavier_initializer())

        token_scores = dec_outputs >> td.Map(projection)

        # Wire everything together.
        mus_and_log_sigs.reads(tokens)
        sampler.reads(mus_and_log_sigs)
        placeholders.reads(tokens)
        token_scores.reads(placeholders, sampler)

        vae.output.reads(token_scores, mus_and_log_sigs)
    return vae
Пример #23
0
def reduce_net_block():
    """Return a block reducing a sequence of scalars with a small network.

    Each element is converted with ``td.Scalar``. The reduction step
    concatenates its two inputs, applies two 20-unit FC layers and a 1-unit
    linear FC layer, then squeezes axis 1.
    """
    squeeze = td.Function(lambda xs: tf.squeeze(xs, axis=1))
    pair_net = (td.Concat()
                >> td.FC(20)
                >> td.FC(20)
                >> td.FC(1, activation=None)
                >> squeeze)
    return td.Map(td.Scalar()) >> td.Reduce(pair_net)
Пример #24
0
def expand_dim_blk(axis):
    """Return a block that inserts a dimension of size 1 at ``axis``."""
    return td.Function(
        lambda tensor: tf.expand_dims(tensor, axis=axis))
Пример #25
0
    def __init__(self, image_feat_grid, text_seq_batch, seq_length_batch,
        T_decoder, num_vocab_txt, embed_dim_txt, num_vocab_nmn,
        embed_dim_nmn, lstm_dim, num_layers, assembler,
        encoder_dropout, decoder_dropout, decoder_sampling,
        num_choices, use_qpn, qpn_dropout, reduce_visfeat_dim=False, new_visfeat_dim=256,
        use_gt_layout=None, gt_layout_batch=None,
        scope='neural_module_network', reuse=None):
        """Build the whole model graph under the variable scope ``scope``.

        The constructor runs these stages in order:

        0. Optionally reduce the visual feature dimension with a 1x1 conv.
        1. Build an ``AttentionSeq2Seq`` layout generator from the text.
        2. Define the module network as TensorFlow Fold blocks and compile
           it; its first output tensor becomes ``self.scores_nmn``.
        3. If ``use_qpn`` is true, add ``question_prior_net`` scores.
        4. Compute the entropy and L2 regularisation terms.
        """

        with tf.variable_scope(scope, reuse=reuse):
            # Part 0: Visual feature from CNN
            self.reduce_visfeat_dim = reduce_visfeat_dim
            if reduce_visfeat_dim:
                # use an extra linear 1x1 conv layer (without ReLU)
                # to reduce the feature dimension
                with tf.variable_scope('reduce_visfeat_dim'):
                    image_feat_grid = conv('conv_reduce_visfeat_dim',
                        image_feat_grid, kernel_size=1, stride=1,
                        output_dim=new_visfeat_dim)
                print('visual feature dimension reduced to %d' % new_visfeat_dim)
            self.image_feat_grid = image_feat_grid

            # Part 1: Seq2seq RNN to generate module layout tokens
            with tf.variable_scope('layout_generation'):
                att_seq2seq = AttentionSeq2Seq(text_seq_batch,
                    seq_length_batch, T_decoder, num_vocab_txt,
                    embed_dim_txt, num_vocab_nmn, embed_dim_nmn, lstm_dim,
                    num_layers, assembler, encoder_dropout, decoder_dropout,
                    decoder_sampling, use_gt_layout, gt_layout_batch)
                self.att_seq2seq = att_seq2seq
                predicted_tokens = att_seq2seq.predicted_tokens
                token_probs = att_seq2seq.token_probs
                word_vecs = att_seq2seq.word_vecs
                neg_entropy = att_seq2seq.neg_entropy
                self.atts = att_seq2seq.atts

                # Expose the generator outputs on the model object.
                self.predicted_tokens = predicted_tokens
                self.token_probs = token_probs
                self.word_vecs = word_vecs
                self.neg_entropy = neg_entropy

                # log probability of each generated sequence
                # (sum of token log-probabilities over axis 0)
                self.log_seq_prob = tf.reduce_sum(tf.log(token_probs), axis=0)

            # Part 2: Neural Module Network
            with tf.variable_scope('layout_execution'):
                modules = Modules(image_feat_grid, word_vecs, None, num_choices)
                self.modules = modules
                # Recursion of modules
                # Attention shape: the feature grid's middle dimensions
                # (first and last dropped) followed by a single channel.
                att_shape = image_feat_grid.get_shape().as_list()[1:-1] + [1]
                # Forward declaration of module recursion
                att_expr_decl = td.ForwardDeclaration(td.PyObjectType(), td.TensorType(att_shape))
                # _Scene
                case_scene = td.Record([('time_idx', td.Scalar(dtype='int32')),
                                       ('batch_idx', td.Scalar(dtype='int32'))])
                case_scene = case_scene >> td.Function(modules.SceneModule)
                # _Find
                case_find = td.Record([('time_idx', td.Scalar(dtype='int32')),
                                       ('batch_idx', td.Scalar(dtype='int32'))])
                case_find = case_find >> td.Function(modules.FindModule)
                # _Filter
                case_filter = td.Record([('input_0', att_expr_decl()),
                                         ('time_idx', td.Scalar(dtype='int32')),
                                         ('batch_idx', td.Scalar(dtype='int32'))])
                case_filter = case_filter >> td.Function(modules.FilterModule)
                # _FindSameProperty
                case_find_same_property = td.Record([('input_0', att_expr_decl()),
                                                     ('time_idx', td.Scalar(dtype='int32')),
                                                     ('batch_idx', td.Scalar(dtype='int32'))])
                case_find_same_property = case_find_same_property >> \
                    td.Function(modules.FindSamePropertyModule)
                # _Transform
                case_transform = td.Record([('input_0', att_expr_decl()),
                                            ('time_idx', td.Scalar('int32')),
                                            ('batch_idx', td.Scalar('int32'))])
                case_transform = case_transform >> td.Function(modules.TransformModule)
                # _And
                case_and = td.Record([('input_0', att_expr_decl()),
                                      ('input_1', att_expr_decl()),
                                      ('time_idx', td.Scalar('int32')),
                                      ('batch_idx', td.Scalar('int32'))])
                case_and = case_and >> td.Function(modules.AndModule)
                # _Or
                case_or = td.Record([('input_0', att_expr_decl()),
                                     ('input_1', att_expr_decl()),
                                     ('time_idx', td.Scalar('int32')),
                                     ('batch_idx', td.Scalar('int32'))])
                case_or = case_or >> td.Function(modules.OrModule)
                # _Exist
                case_exist = td.Record([('input_0', att_expr_decl()),
                                        ('time_idx', td.Scalar('int32')),
                                        ('batch_idx', td.Scalar('int32'))])
                case_exist = case_exist >> td.Function(modules.ExistModule)
                # _Count
                case_count = td.Record([('input_0', att_expr_decl()),
                                        ('time_idx', td.Scalar('int32')),
                                        ('batch_idx', td.Scalar('int32'))])
                case_count = case_count >> td.Function(modules.CountModule)
                # _EqualNum
                case_equal_num = td.Record([('input_0', att_expr_decl()),
                                            ('input_1', att_expr_decl()),
                                            ('time_idx', td.Scalar('int32')),
                                            ('batch_idx', td.Scalar('int32'))])
                case_equal_num = case_equal_num >> td.Function(modules.EqualNumModule)
                # _MoreNum
                case_more_num = td.Record([('input_0', att_expr_decl()),
                                            ('input_1', att_expr_decl()),
                                            ('time_idx', td.Scalar('int32')),
                                            ('batch_idx', td.Scalar('int32'))])
                case_more_num = case_more_num >> td.Function(modules.MoreNumModule)
                # _LessNum
                case_less_num = td.Record([('input_0', att_expr_decl()),
                                            ('input_1', att_expr_decl()),
                                            ('time_idx', td.Scalar('int32')),
                                            ('batch_idx', td.Scalar('int32'))])
                case_less_num = case_less_num >> td.Function(modules.LessNumModule)
                # _SameProperty
                case_same_property = td.Record([('input_0', att_expr_decl()),
                                                ('input_1', att_expr_decl()),
                                                ('time_idx', td.Scalar('int32')),
                                                ('batch_idx', td.Scalar('int32'))])
                case_same_property = case_same_property >> \
                    td.Function(modules.SamePropertyModule)
                # _Describe
                case_describe = td.Record([('input_0', att_expr_decl()),
                                           ('time_idx', td.Scalar('int32')),
                                           ('batch_idx', td.Scalar('int32'))])
                case_describe = case_describe >> \
                    td.Function(modules.DescribeModule)

                # Only these modules may appear as inputs of other modules;
                # dispatch is on the expression's 'module' entry.
                recursion_cases = td.OneOf(td.GetItem('module'), {
                    '_Scene': case_scene,
                    '_Find': case_find,
                    '_Filter': case_filter,
                    '_FindSameProperty': case_find_same_property,
                    '_Transform': case_transform,
                    '_And': case_and,
                    '_Or': case_or})
                att_expr_decl.resolve_to(recursion_cases)

                # For invalid expressions, define a dummy answer
                # so that all answers have the same form
                dummy_scores = td.Void() >> td.FromTensor(np.zeros(num_choices, np.float32))
                # Top-level dispatch: the modules accepted as the root of an
                # expression, plus the invalid-expression placeholder.
                output_scores = td.OneOf(td.GetItem('module'), {
                    '_Exist': case_exist,
                    '_Count': case_count,
                    '_EqualNum': case_equal_num,
                    '_MoreNum': case_more_num,
                    '_LessNum': case_less_num,
                    '_SameProperty': case_same_property,
                    '_Describe': case_describe,
                    INVALID_EXPR: dummy_scores})

                # compile and get the output scores
                self.compiler = td.Compiler.create(output_scores)
                self.scores_nmn = self.compiler.output_tensors[0]

            # Add a question prior network if specified
            self.use_qpn = use_qpn
            self.qpn_dropout = qpn_dropout
            if use_qpn:
                self.scores_qpn = question_prior_net(att_seq2seq.encoder_states,
                                                     num_choices, qpn_dropout)
                # Final scores are the sum of module and prior scores.
                self.scores = self.scores_nmn + self.scores_qpn
            else:
                self.scores = self.scores_nmn

            # Regularization: Entropy + L2
            self.entropy_reg = tf.reduce_mean(neg_entropy)
            # L2 applies to trainable variables whose op name contains
            # `scope` (substring match) and ends with 'weights'.
            module_weights = [v for v in tf.trainable_variables()
                              if (scope in v.op.name and
                                  v.op.name.endswith('weights'))]
            self.l2_reg = tf.add_n([tf.nn.l2_loss(v) for v in module_weights])
Пример #26
0

# Training graph: runs the VAE over the input sequence and records the
# reconstruction and KL terms as metrics. The composition outputs nothing.
c = td.Composition()
with c.scope():
    # Every element of the input sequence is a vector of length 54.
    input_sequence = td.Map(td.Vector(54)).reads(c.input)

    network_output = build_VAE(Z_SIZE, 54).reads(input_sequence)

    # The VAE yields (token scores, mus_and_log_sigs).
    un_normalised_token_probs = td.GetItem(0).reads(network_output)
    mus_and_log_sigs = td.GetItem(1).reads(network_output)

    # Per-step cross entropy against the input itself, averaged over steps.
    cross_entropy_loss = (
        td.ZipWith(td.Function(softmax_crossentropy)) >> td.Mean())
    cross_entropy_loss.reads(un_normalised_token_probs, input_sequence)
    kl_loss = td.Function(kl_divergence)
    kl_loss.reads(mus_and_log_sigs)

    td.Metric('cross_entropy_loss').reads(cross_entropy_loss)
    td.Metric('kl_loss').reads(kl_loss)

    c.output.reads(td.Void())



#  Tokenised version of my code