import tensorflow as tf
import tensorflow_fold as td

# hyper, param, coding_blk, collect_node_for_conv_patch_blk and
# weighted_feature_blk are assumed to be defined elsewhere in this module.


def feature_detector_blk(max_depth=2):
    """Input: node dict
    Output: TensorType([hyper.conv_dim, ])
    Single patch of the conv. Depth is max_depth.
    """
    blk = td.Composition()
    with blk.scope():
        nodes_in_patch = collect_node_for_conv_patch_blk(
            max_depth=max_depth).reads(blk.input)

        # map from python object to tensors
        mapped = td.Map(
            td.Record((coding_blk(), td.Scalar(), td.Scalar(),
                       td.Scalar(), td.Scalar()))).reads(nodes_in_patch)
        # mapped = [(feature, idx, depth, max_depth), (...)]

        # compute weighted feature for each elem
        weighted = td.Map(weighted_feature_blk()).reads(mapped)
        # weighted = [fea, fea, fea, ...]

        # add together
        added = td.Reduce(td.Function(tf.add)).reads(weighted)
        # added = TensorType([hyper.conv_dim, ])

        # add bias
        biased = td.Function(tf.add).reads(
            added, td.FromTensor(param.get('Bconv')))
        # biased = TensorType([hyper.conv_dim, ])

        # tanh
        tanh = td.Function(tf.nn.tanh).reads(biased)
        # tanh = TensorType([hyper.conv_dim, ])

        blk.output.reads(tanh)
    return blk
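
# The weighted_feature_blk used above is not shown here. Below is a minimal
# sketch of what such a block could look like, following the TBCNN-style
# "continuous binary tree" weighting; it is an assumption, not the original
# implementation. The field layout (feature, idx, pclen, depth, max_depth)
# and the parameter names 'Wconv_t', 'Wconv_l', 'Wconv_r' are guesses (the
# Record above has five fields while its comment names only four).
def weighted_feature_blk_sketch():
    def weight(fea, idx, pclen, depth, max_depth):
        # weight towards the 'top' matrix near the root of the patch
        eta_t = (max_depth - depth) / tf.maximum(max_depth - 1.0, 1.0)
        # weight towards the 'right' matrix for right-most siblings
        eta_r = (1.0 - eta_t) * (idx - 1.0) / tf.maximum(pclen - 1.0, 1.0)
        # remaining weight goes to the 'left' matrix
        eta_l = (1.0 - eta_t) - eta_r
        return (eta_t[:, None] * tf.matmul(fea, param.get('Wconv_t')) +
                eta_l[:, None] * tf.matmul(fea, param.get('Wconv_l')) +
                eta_r[:, None] * tf.matmul(fea, param.get('Wconv_r')))
    return td.Function(weight)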

def tree_sum_blk(loss_blk):
    # traverse the tree to sum up the loss
    # the forward declaration lets the block recurse into each child subtree
    tree_sum_fwd = td.ForwardDeclaration(td.PyObjectType(), td.TensorType([]))
    tree_sum = td.Composition()
    with tree_sum.scope():
        myloss = loss_blk().reads(tree_sum.input)
        children = td.GetItem('children').reads(tree_sum.input)

        # sum the losses of all child subtrees
        mapped = td.Map(tree_sum_fwd()).reads(children)
        summed = td.Reduce(td.Function(tf.add)).reads(mapped)
        # then add this node's own loss
        summed = td.Function(tf.add).reads(summed, myloss)
        tree_sum.output.reads(summed)
    tree_sum_fwd.resolve_to(tree_sum)
    return tree_sum
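
# Usage sketch (hypothetical, not part of the original code): any factory
# returning a block that maps a node dict to a scalar TensorType([]) can be
# passed as loss_blk. Here we assume each node dict carries a precomputed
# value under a 'loss' key.
def node_loss_blk():
    return td.GetItem('loss') >> td.Scalar()

total_loss_blk = tree_sum_blk(node_loss_blk)  # whole-tree loss, TensorType([])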

def dynamic_pooling_blk():
    """Input: root node dict
    Output: pooled, TensorType([hyper.conv_dim, ])
    """
    leaf_case = feature_detector_blk()

    pool_fwd = td.ForwardDeclaration(td.PyObjectType(),
                                     td.TensorType([hyper.conv_dim, ]))
    pool = td.Composition()
    with pool.scope():
        cur_fea = feature_detector_blk().reads(pool.input)
        children = td.GetItem('children').reads(pool.input)

        # recursively pool each child subtree, then take the element-wise max
        mapped = td.Map(pool_fwd()).reads(children)
        summed = td.Reduce(td.Function(tf.maximum)).reads(mapped)
        # fold in the current node's own feature
        summed = td.Function(tf.maximum).reads(summed, cur_fea)
        pool.output.reads(summed)

    # leaves (clen == 0) have no children to pool over
    pool = td.OneOf(lambda x: x['clen'] == 0,
                    {True: leaf_case, False: pool})

    pool_fwd.resolve_to(pool)
    return pool
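
# Usage sketch (assumptions: node dicts expose the 'children' and 'clen' keys
# read above, the helper blocks and parameters are defined, and 'root_node'
# is a placeholder for one such tree dict).
pool_blk = dynamic_pooling_blk()
pool_compiler = td.Compiler.create(pool_blk)
(pooled,) = pool_compiler.output_tensors  # shape: [batch_size, hyper.conv_dim]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    fd = pool_compiler.build_feed_dict([root_node])
    print(sess.run(pooled, feed_dict=fd))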

def reduce_net_block():
    # learned binary op: takes a pair of scalars and produces one scalar
    net_block = (td.Concat()
                 >> td.FC(20)
                 >> td.FC(20)
                 >> td.FC(1, activation=None)
                 >> td.Function(lambda xs: tf.squeeze(xs, axis=1)))
    # convert a python list of numbers to tensors, then fold it down pairwise
    return td.Map(td.Scalar()) >> td.Reduce(net_block)
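
# Usage sketch, following the TensorFlow Fold tutorial pattern (the label,
# loss and data below are illustrative): pair the block with a scalar target
# and train it to approximate, e.g., the sum of the input list.
net_block = reduce_net_block()
reduce_compiler = td.Compiler.create((net_block, td.Scalar()))
y, y_ = reduce_compiler.output_tensors
loss = tf.nn.l2_loss(y - y_)
train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    examples = [([1.0, 2.0, 3.0], 6.0), ([4.0, 5.0], 9.0)]
    feed = reduce_compiler.build_feed_dict(examples)
    for _ in range(200):
        sess.run(train_op, feed_dict=feed)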