def test_single_input(self):
    """Check that OneOf(Minmax, Binarize) returns a list whose first image has the expected shape."""
    candidate_ops = (
        Minmax(inputs='x', outputs='x'),
        Binarize(inputs='x', outputs='x', threshold=1),
    )
    selector = OneOf(*candidate_ops)
    result = selector.forward(data=self.single_input, state={})

    with self.subTest('Check output type'):
        self.assertEqual(type(result), list)

    with self.subTest('Check output image shape'):
        first_image = result[0]
        self.assertEqual(first_image.shape, self.output_shape)
def test_multi_input(self):
    """Check that OneOf over three ops returns a two-element list of images with the expected shape."""
    candidate_ops = (
        Minmax(inputs='x', outputs='x'),
        Normalize(inputs='x', outputs='x'),
        Binarize(inputs='x', outputs='x', threshold=1),
    )
    selector = OneOf(*candidate_ops)
    result = selector.forward(data=self.multi_input, state={})

    with self.subTest('Check output type'):
        self.assertEqual(type(result), list)

    with self.subTest('Check output list length'):
        self.assertEqual(len(result), 2)

    # Every returned image must keep the expected shape.
    for image in result:
        with self.subTest('Check output image shape'):
            self.assertEqual(image.shape, self.output_shape)
def get_estimator(batch_size=100, epochs=20, max_train_steps_per_epoch=None, save_dir=None):
    """Build an ``fe.Estimator`` that trains an encoder/decoder pair with ``CVAELoss``.

    Args:
        batch_size: Number of samples per training batch (passed to ``fe.Pipeline``).
        epochs: Number of training epochs (passed to ``fe.Estimator``).
        max_train_steps_per_epoch: Passed through unchanged to ``fe.Estimator``.
        save_dir: Directory where ``BestModelSaver`` writes the encoder and decoder.
            When ``None`` (the default), a fresh temporary directory is created
            for this call.

    Returns:
        The configured ``fe.Estimator``.
    """
    # Previously the default was ``save_dir=tempfile.mkdtemp()``. That default is
    # evaluated once, at definition (import) time. Importing the module therefore
    # created a temp dir as a side effect, and every call relying on the default
    # wrote its checkpoints into the same shared directory.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()

    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            # Insert a new axis at position 0 of each sample.
            # Presumably this adds a channel dim; TODO confirm.
            ExpandDims(inputs="x", outputs="x", axis=0),
            Minmax(inputs="x", outputs="x"),
            # Binarize after Minmax; the 0.5 threshold assumes Minmax maps to [0, 1].
            # Verify against Minmax.
            Binarize(inputs="x", outputs="x", threshold=0.5),
        ])

    encode_model = fe.build(model_fn=EncoderNet, optimizer_fn="adam", model_name="encoder")
    decode_model = fe.build(model_fn=DecoderNet, optimizer_fn="adam", model_name="decoder")

    network = fe.Network(ops=[
        # The encoder emits one tensor that is split into mean and log-variance.
        ModelOp(model=encode_model, inputs="x", outputs="meanlogvar"),
        SplitOp(inputs="meanlogvar", outputs=("mean", "logvar")),
        ReparameterizepOp(inputs=("mean", "logvar"), outputs="z"),
        ModelOp(model=decode_model, inputs="z", outputs="x_logit"),
        CrossEntropy(inputs=("x_logit", "x"), outputs="cross_entropy"),
        CVAELoss(inputs=("cross_entropy", "mean", "logvar", "z"), outputs="loss"),
        # Both models are updated from the same combined loss.
        UpdateOp(model=encode_model, loss_name="loss"),
        UpdateOp(model=decode_model, loss_name="loss"),
    ])

    traces = [
        BestModelSaver(model=encode_model, save_dir=save_dir),
        BestModelSaver(model=decode_model, save_dir=save_dir),
    ]

    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        traces=traces,
        max_train_steps_per_epoch=max_train_steps_per_epoch)
    return estimator
def test_multi_input(self):
    """Check that Binarize(threshold=1) maps the multi-input fixture onto the multi-output fixture."""
    binarizer = Binarize(threshold=1, inputs='x', outputs='x')
    result = binarizer.forward(data=self.multi_input, state={})
    matches_expected = is_equal(result, self.multi_output)
    self.assertTrue(matches_expected)
def test_single_input(self):
    """Check that Binarize(threshold=1) applied to [1, 2, 3, 4] matches the single-output fixture."""
    binarizer = Binarize(threshold=1, inputs='x', outputs='x')
    batch = [np.array([1, 2, 3, 4])]
    result = binarizer.forward(data=batch, state={})
    matches_expected = is_equal(result, self.single_output)
    self.assertTrue(matches_expected)