def architecture_children(self):
    """Build a highway-style gated mix of a transform path and a carry path.

    Output = carry_gate * input + transform_gate * transform, where
    transform_gate is a sigmoid with a negative additive shift (so the
    network initially favors carrying the input through), and
    carry_gate = 1 - transform_gate.
    """
    kids = self.raw_children()
    gate = kids["gate"]
    transform = kids["transform"]

    # Sigmoid transform gate; the constant shift plays the role of the
    # gate's initial bias.
    # add initial value as bias instead
    # TODO parameterize
    t_gate = tn.SequentialNode(
        self.name + "_transformgate",
        [gate,
         tn.AddConstantNode(self.name + "_biastranslation", value=-4),
         tn.SigmoidNode(self.name + "_transformgatesigmoid")])

    # carry gate = 1 - transform gate
    c_gate = tn.SequentialNode(
        self.name + "_carrygate",
        [tn.ReferenceNode(self.name + "_transformgateref",
                          reference=t_gate.name),
         tn.MultiplyConstantNode(self.name + "_invert", value=-1),
         tn.AddConstantNode(self.name + "_add", value=1)])

    # gate each branch, then sum
    gated_transform = tn.ElementwiseProductNode(
        self.name + "_gatedtransform", [t_gate, transform])
    gated_carry = tn.ElementwiseProductNode(
        self.name + "_gatedcarry",
        [c_gate, tn.IdentityNode(self.name + "_carry")])
    return [tn.ElementwiseSumNode(self.name + "_res",
                                  [gated_carry, gated_transform])]
def architecture_children(self):
    """Mix mean-pooling and max-pooling via a learned spatial gate.

    A single-filter convolution followed by a sigmoid produces a gate g;
    the result is g * mean_pool + (1 - g) * max_pool.
    """
    gate = tn.SequentialNode(
        self.name + "_gate_seq",
        [batch_fold.AddAxisNode(self.name + "_add_axis", axis=2),
         batch_fold.FoldUnfoldAxisIntoBatchNode(
             self.name + "_batch_fold",
             # NOTE: using dnn conv, since pooling is normally strided
             # and the normal conv is slow with strides
             tn.DnnConv2DWithBiasNode(self.name + "_conv",
                                      num_filters=1),
             axis=1),
         batch_fold.RemoveAxisNode(self.name + "_remove_axis", axis=2),
         tn.SigmoidNode(self.name + "_gate_sigmoid")])

    # complement gate: 1 - g
    complement_gate = tn.SequentialNode(
        self.name + "_max_gate",
        [tn.ReferenceNode(self.name + "_gate_ref",
                          reference=gate.name),
         # NOTE(review): the node-name suffix "_" looks truncated
         # (siblings use suffixes like "_invert") — verify intent
         tn.MultiplyConstantNode(self.name + "_", value=-1),
         tn.AddConstantNode(self.name + "_add1", value=1)])

    gated_mean = tn.ElementwiseProductNode(
        self.name + "_mean_product",
        [tn.MeanPool2DNode(self.name + "_mean_pool"), gate])
    gated_max = tn.ElementwiseProductNode(
        self.name + "_max_product",
        [tn.MaxPool2DNode(self.name + "_max_pool"), complement_gate])
    return [tn.ElementwiseSumNode(self.name + "_sum",
                                  [gated_mean, gated_max])]
def irregular_length_attention_node(name,
                                    lengths_reference,
                                    num_units,
                                    output_units=None):
    """
    Attention over irregular-length groups: softmax weights are computed
    per group and used to take a weighted sum over axis 1.

    NOTE: if output_units is not None, this should be the number of input
    units
    """
    values = UngroupIrregularLengthTensorsNode(
        name + "_ungroup_values",
        lengths_reference=lengths_reference)

    # one attention channel unless an explicit output width is requested
    attention_units = output_units if output_units is not None else 1
    weight_nodes = [
        tn.DenseNode(name + "_fc1", num_units=num_units),
        tn.ScaledTanhNode(name + "_tanh"),
        tn.DenseNode(name + "_fc2", num_units=attention_units),
        UngroupIrregularLengthTensorsNode(
            name + "_ungroup_attention",
            lengths_reference=lengths_reference),
        _IrregularLengthAttentionSoftmaxNode(
            name + "_softmax",
            lengths_reference=lengths_reference),
    ]
    if output_units is None:
        # broadcast the single attention channel across the feature axis
        weight_nodes.append(tn.AddBroadcastNode(name + "_bcast", axes=(2,)))
    weights = tn.SequentialNode(name + "_attention", weight_nodes)

    return tn.SequentialNode(
        name,
        [tn.ElementwiseProductNode(name + "_prod", [values, weights]),
         tn.SumNode(name + "_sum", axis=1)])
def forget_gate_conv_2d_node(name,
                             num_filters,
                             filter_size=(3, 3),
                             initial_bias=0):
    """Elementwise-gate the input by a sigmoid "forget" convolution.

    A same-padded, stride-1 conv (plus an additive initial bias and a
    sigmoid) produces the gate, which multiplies the identity branch.
    """
    forget_branch = tn.SequentialNode(
        name + "_forget",
        [tn.Conv2DWithBiasNode(name + "_conv",
                               num_filters=num_filters,
                               filter_size=filter_size,
                               stride=(1, 1),
                               pad="same"),
         tn.AddConstantNode(name + "_initial_bias", value=initial_bias),
         tn.SigmoidNode(name + "_sigmoid")])
    return tn.ElementwiseProductNode(
        name,
        [tn.IdentityNode(name + "_identity"), forget_branch])
def test_elementwise_product_node():
    # check the elementwise product of three inputs for both a scalar
    # shape and a multi-dimensional shape
    for shape in [(), (3, 4, 5)]:
        network = tn.ElementwiseProductNode(
            "es",
            [tn.InputNode("i1", shape=shape),
             tn.InputNode("i2", shape=shape),
             tn.InputNode("i3", shape=shape)],
        ).network()
        fn = network.function(["i1", "i2", "i3"], ["es"])
        inputs = [np.array(np.random.rand(*shape), dtype=fX)
                  for _ in range(3)]
        expected = inputs[0] * inputs[1] * inputs[2]
        np.testing.assert_allclose(expected,
                                   fn(*inputs)[0],
                                   rtol=1e-5)
def standard_tanh_spatial_attention_2d_node(name,
                                            num_filters,
                                            output_channels=None):
    """
    Spatial attention: a 1x1-conv / tanh / 1x1-conv / spatial-softmax
    branch weights the input, which is then summed over the spatial
    dimensions.

    NOTE: if output_channels is not None, this should be the number of
    input channels
    """
    # one attention map unless an explicit channel count is requested
    final_filters = output_channels if output_channels is not None else 1
    attention_nodes = [
        tn.Conv2DWithBiasNode(name + "_conv1",
                              filter_size=(1, 1),
                              num_filters=num_filters),
        tn.ScaledTanhNode(name + "_tanh"),
        tn.Conv2DWithBiasNode(name + "_conv2",
                              filter_size=(1, 1),
                              num_filters=final_filters),
        tn.SpatialSoftmaxNode(name + "_softmax"),
    ]
    if output_channels is None:
        # broadcast the single attention map across the channel axis
        attention_nodes.append(tn.AddBroadcastNode(name + "_bcast",
                                                   axes=(1,)))

    # multiply input by attention weights and sum across spatial dimensions
    weighted = tn.ElementwiseProductNode(
        name + "_combine",
        [tn.IdentityNode(name + "_input"),
         tn.SequentialNode(name + "_attention", attention_nodes)])
    return tn.SequentialNode(
        name,
        [weighted,
         tn.FlattenNode(name + "_flatten", outdim=3),
         tn.SumNode(name + "_sum", axis=2)])
def test_elementwise_product_node_serialization():
    # round-trip serialization of an empty node and a nested one
    tn.check_serialization(tn.ElementwiseProductNode("a", []))
    nested = tn.ElementwiseProductNode(
        "a",
        [tn.ElementwiseProductNode("b", []),
         tn.ElementwiseProductNode("c", [])])
    tn.check_serialization(nested)