def build_graph(self):
    """Build a two-tower graph whose loss compares the scores of two inputs.

    Two feature inputs of width ``self.n_features`` are created and stored
    as ``self.m1_features`` and ``self.m2_features``. Both go through a stack
    of ReLU ``Dense`` layers, one per entry of ``self.layer_sizes``. The
    second tower reuses each layer of the first tower via ``shared``. When
    ``self.dropout`` is positive, a ``Dropout`` layer follows each hidden
    layer on both towers.

    Each tower ends in a single-unit linear readout, also shared. Two
    outputs are registered, in this order:
      0. ``Sigmoid(readout) * 4 + 1`` for the first input.
      1. The same expression for the second input.

    The raw readout difference (first minus second) is stored as
    ``self.difference``. The loss is a ``HingeLoss`` between a
    ``(None, 1)`` label and that difference. The hinge term is then
    weighted per example by ``self.my_task_weights`` through
    ``WeightedError``.
    """
    self.m1_features = Feature(shape=(None, self.n_features))
    self.m2_features = Feature(shape=(None, self.n_features))

    left, right = self.m1_features, self.m2_features
    for width in self.layer_sizes:
        left = Dense(
            out_channels=width, in_layers=[left], activation_fn=tf.nn.relu)
        # The second tower reuses the variables of the layer just created.
        right = left.shared([right])
        if self.dropout > 0.0:
            left = Dropout(self.dropout, in_layers=left)
            right = Dropout(self.dropout, in_layers=right)

    left_score = Dense(out_channels=1, in_layers=[left], activation_fn=None)
    right_score = left_score.shared([right])
    # Output order matters to callers: index 0 is the first input,
    # index 1 is the second.
    for raw_score in (left_score, right_score):
        self.add_output(Sigmoid(raw_score) * 4 + 1)

    # The loss uses the unscaled readouts, not the sigmoid-mapped outputs.
    self.difference = left_score - right_score
    label = Label(shape=(None, 1))
    hinge = HingeLoss(in_layers=[label, self.difference])
    self.my_task_weights = Weights(shape=(None, 1))
    self.set_loss(WeightedError(in_layers=[hinge, self.my_task_weights]))
def test_shared_layer(self):
    """Outputs of a ``shared`` Dense layer should match its base layer.

    Setup:
    - A ``Dense`` layer and a second layer built from it with ``shared`` are
      fed the same features.
    - Only the base layer appears in the loss.
    - The model is trained for one epoch on random features whose labels
      are all the one-hot vector ``[0, 1]``.

    Check: the two softmax predictions agree element-wise within
    ``atol=0.01``.
    """
    n_data_points = 20
    n_features = 2

    feature_dataset = NumpyDataset(np.random.rand(n_data_points, n_features))
    # Every row gets the same one-hot target.
    label_dataset = NumpyDataset(np.array([[0, 1]] * n_data_points))

    bag = Databag()
    features = Feature(shape=(None, n_features))
    bag.add_dataset(features, feature_dataset)

    label = Label(shape=(None, 2))
    base_dense = Dense(out_channels=2, in_layers=[features])
    # Expected to reuse base_dense's variables; this is what the test checks.
    twin_dense = base_dense.shared(in_layers=[features])
    softmax_outputs = [
        SoftMax(in_layers=[base_dense]),
        SoftMax(in_layers=[twin_dense]),
    ]
    # The loss only references base_dense, so training alone never touches
    # twin_dense directly.
    cross_entropy = SoftMaxCrossEntropy(in_layers=[label, base_dense])
    bag.add_dataset(label, label_dataset)
    mean_loss = ReduceMean(in_layers=[cross_entropy])

    graph = dc.models.TensorGraph(learning_rate=0.01)
    for out_layer in softmax_outputs:
        graph.add_output(out_layer)
    graph.set_loss(mean_loss)
    graph.fit_generator(
        bag.iterbatches(
            epochs=1, batch_size=graph.batch_size, pad_batches=True))

    predictions = graph.predict_on_generator(bag.iterbatches())
    assert_true(np.allclose(predictions[0], predictions[1], atol=0.01))