コード例 #1
0
 def test_single_input(self):
     """Run OneOf(Minmax, Binarize) on ``self.single_input``.

     Checks three things:

     * the result is a list;
     * it has one entry per input array;
     * the first entry has ``self.output_shape``.
     """
     minmax = Minmax(inputs='x', outputs='x')
     binarize = Binarize(inputs='x', outputs='x', threshold=1)
     oneof = OneOf(minmax, binarize)
     output = oneof.forward(data=self.single_input, state={})
     with self.subTest('Check output type'):
         self.assertIsInstance(output, list)
     with self.subTest('Check output list length'):
         # Mirror the multi-input test: one output per input array,
         # regardless of which wrapped op OneOf selected.
         self.assertEqual(len(output), len(self.single_input))
     with self.subTest('Check output image shape'):
         self.assertEqual(output[0].shape, self.output_shape)
コード例 #2
0
 def test_multi_input(self):
     """Run OneOf(Minmax, Normalize, Binarize) on ``self.multi_input``.

     Checks three things:

     * the result is a list;
     * it has one entry per input array;
     * every entry has ``self.output_shape``.
     """
     minmax = Minmax(inputs='x', outputs='x')
     normalize = Normalize(inputs='x', outputs='x')
     binarize = Binarize(inputs='x', outputs='x', threshold=1)
     oneof = OneOf(minmax, normalize, binarize)
     output = oneof.forward(data=self.multi_input, state={})
     with self.subTest('Check output type'):
         self.assertIsInstance(output, list)
     with self.subTest('Check output list length'):
         # Derive the expectation from the fixture rather than a magic number
         # so the test stays correct if the fixture size changes.
         self.assertEqual(len(output), len(self.multi_input))
     for index, img_output in enumerate(output):
         # The index parameter makes each subtest distinguishable, so a
         # failure report says which output array had the wrong shape.
         with self.subTest('Check output image shape', index=index):
             self.assertEqual(img_output.shape, self.output_shape)
コード例 #3
0
def get_estimator(batch_size=100,
                  epochs=20,
                  max_train_steps_per_epoch=None,
                  save_dir=None):
    """Build an Estimator that trains a variational encoder/decoder pair.

    The encoder output is split into ``mean``/``logvar`` and reparameterized
    into ``z``. The decoder maps ``z`` to logits. Both models are updated
    from a single ``CVAELoss`` built on the reconstruction cross-entropy.

    Args:
        batch_size: Batch size used by the training pipeline.
        epochs: Number of training epochs.
        max_train_steps_per_epoch: Passed straight through to
            ``fe.Estimator``; ``None`` leaves the per-epoch step count
            uncapped.
        save_dir: Directory where ``BestModelSaver`` writes the best encoder
            and decoder models. If ``None``, a new temporary directory is
            created for this call.

    Returns:
        The configured ``fe.Estimator``.
    """
    if save_dir is None:
        # Created per call instead of as a default argument. A default of
        # ``tempfile.mkdtemp()`` is evaluated once at definition time: it
        # creates a directory on import even if unused, and every call that
        # relies on the default would share it.
        save_dir = tempfile.mkdtemp()

    # Only the training split is used; the second value is discarded.
    train_data, _ = load_data()
    pipeline = fe.Pipeline(train_data=train_data,
                           batch_size=batch_size,
                           ops=[
                               # Add a leading axis (presumably channels) — TODO confirm.
                               ExpandDims(inputs="x", outputs="x", axis=0),
                               Minmax(inputs="x", outputs="x"),
                               Binarize(inputs="x", outputs="x",
                                        threshold=0.5),
                           ])

    encode_model = fe.build(model_fn=EncoderNet,
                            optimizer_fn="adam",
                            model_name="encoder")
    decode_model = fe.build(model_fn=DecoderNet,
                            optimizer_fn="adam",
                            model_name="decoder")

    network = fe.Network(ops=[
        # Encoder emits a combined tensor that is split into mean and logvar.
        ModelOp(model=encode_model, inputs="x", outputs="meanlogvar"),
        SplitOp(inputs="meanlogvar", outputs=("mean", "logvar")),
        ReparameterizepOp(inputs=("mean", "logvar"), outputs="z"),
        ModelOp(model=decode_model, inputs="z", outputs="x_logit"),
        # Reconstruction term: decoder logits against the binarized input.
        CrossEntropy(inputs=("x_logit", "x"), outputs="cross_entropy"),
        CVAELoss(inputs=("cross_entropy", "mean", "logvar", "z"),
                 outputs="loss"),
        # Both models are optimized from the same combined loss.
        UpdateOp(model=encode_model, loss_name="loss"),
        UpdateOp(model=decode_model, loss_name="loss"),
    ])

    traces = [
        BestModelSaver(model=encode_model, save_dir=save_dir),
        BestModelSaver(model=decode_model, save_dir=save_dir)
    ]

    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        traces=traces,
        max_train_steps_per_epoch=max_train_steps_per_epoch)

    return estimator
コード例 #4
0
 def test_multi_input(self):
     """Binarize at threshold 1 should turn ``self.multi_input`` into ``self.multi_output``."""
     binarizer = Binarize(inputs='x', outputs='x', threshold=1)
     result = binarizer.forward(data=self.multi_input, state={})
     matches_expected = is_equal(result, self.multi_output)
     self.assertTrue(matches_expected)
コード例 #5
0
 def test_single_input(self):
     """Binarize at threshold 1 on one small array should yield ``self.single_output``."""
     sample = [np.array([1, 2, 3, 4])]
     binarizer = Binarize(inputs='x', outputs='x', threshold=1)
     result = binarizer.forward(data=sample, state={})
     self.assertTrue(is_equal(result, self.single_output))