예제 #1
0
def pretrain_model(epochs, batch_size, max_train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` with the loss from ``NTXentOp`` on two augmented views.

    Every training image in key "x" is padded (PadIfNeeded, min 40x40), then two
    independently augmented 32x32 views, "x_aug" and "x_aug2", are produced.
    Both views are concatenated along axis 0 and passed through ``model_con`` in
    a single forward pass. The output is split back into two halves, and the
    "NTXent" loss produced by ``NTXentOp`` updates ``model_con``.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size used by the pipeline.
        max_train_steps_per_epoch: Cap on training steps per epoch, forwarded
            to ``fe.Estimator``.
        save_dir: Directory handed to ``ModelSaver`` for ``model_con``.

    Returns:
        ``(model_con, model_finetune)`` as returned by ``fe.build``. Only
        ``model_con`` has an ``UpdateOp`` here, so ``model_finetune`` is
        returned untrained.
    """

    def view_ops(key):
        """Return the ops that build one augmented float view ``key`` from "x"."""
        return [
            RandomCrop(32, 32, image_in="x", image_out=key),
            Sometimes(HorizontalFlip(image_in=key, image_out=key), prob=0.5),
            Sometimes(
                ColorJitter(inputs=key, outputs=key, brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
            Sometimes(ToGray(inputs=key, outputs=key), prob=0.2),
            Sometimes(GaussianBlur(inputs=key, outputs=key, blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ToFloat(inputs=key, outputs=key),
        ]

    # step 1: prepare dataset (the test split is not used for pretraining)
    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x")]
        + view_ops("x_aug")  # augmentation 1
        + view_ops("x_aug2"))  # augmentation 2

    # step 2: prepare network
    model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
    network = fe.Network(ops=[
        # Run both views through the model in one batch, then split the result back.
        LambdaOp(lambda x, y: tf.concat([x, y], axis=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        LambdaOp(lambda x: tf.split(x, 2, axis=0), inputs="y_com", outputs=["y_pred", "y_pred2"]),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])

    # step 3: prepare estimator
    traces = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             monitor_names="contrastive_accuracy")
    estimator.fit()

    return model_con, model_finetune
예제 #2
0
 def test_single_input_tf_static(self):
     """OneOf over two single-key LambdaOps, built for 'tf', returns a tensor of self.output_shape."""
     plus_one = LambdaOp(inputs='x', outputs='x', fn=lambda x: x + 1)
     plus_five = LambdaOp(inputs='x', outputs='x', fn=lambda x: x + 5)
     chooser = OneOf(plus_one, plus_five)
     chooser.build('tf')
     result = chooser.forward(data=self.single_input_tf, state={})

     with self.subTest('Check output type'):
         self.assertTrue(tf.is_tensor(result))
     with self.subTest('Check output image shape'):
         self.assertEqual(result.shape, self.output_shape)
예제 #3
0
 def test_single_input_torch(self):
     """OneOf over two single-key LambdaOps, built for 'torch', returns a torch.Tensor of self.output_shape."""
     candidates = (
         LambdaOp(inputs='x', outputs='x', fn=lambda x: x + 1),
         LambdaOp(inputs='x', outputs='x', fn=lambda x: x + 5),
     )
     chooser = OneOf(*candidates)
     chooser.build('torch')
     result = chooser.forward(data=self.single_input_torch, state={})

     with self.subTest('Check output type'):
         self.assertEqual(type(result), torch.Tensor)
     with self.subTest('Check output image shape'):
         self.assertEqual(result.shape, self.output_shape)
예제 #4
0
 def test_single_input_tf_static(self):
     """Fuse of two chained LambdaOps exposes merged inputs/outputs/mode and returns a list.

     NOTE(review): despite its name, this test exercises Fuse on
     ``self.multi_input_tf``. Consider renaming it (a name collision with other
     tests is possible).
     """
     increment = LambdaOp(inputs='x', outputs='y', fn=lambda x: x + 1, mode='test')
     combine = LambdaOp(inputs=['y', 'z'], outputs='w', fn=lambda x, y: x + y, mode='test')
     fused = Fuse([increment, combine])

     # 'y' is produced by the first op, so only 'x' and 'z' are external inputs.
     with self.subTest('Check op inputs'):
         self.assertListEqual(fused.inputs, ['x', 'z'])
     with self.subTest('Check op outputs'):
         self.assertListEqual(fused.outputs, ['y', 'w'])
     with self.subTest('Check op mode'):
         self.assertSetEqual(fused.mode, {'test'})

     result = fused.forward(data=self.multi_input_tf, state={"mode": "test", "deferred": {}})
     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output image shape'):
         self.assertEqual(result[0].shape, self.output_shape)
예제 #5
0
 def test_repeat_fn_exterior_value_static(self):
     """Repeat whose condition also reads exterior key 'z', executed inside tf.function.

     Starting from x=1 and z=11 and repeating while ``y + z < 25``, the test
     expects x=5 and y=16.
     """
     add_op = LambdaOp(inputs='x',
                       outputs=('x', 'y'),
                       fn=lambda w: (w + 1, w * w),
                       mode='eval')
     # The lambda's argument names matter: 'z' becomes an extra input of Repeat.
     repeat_op = Repeat(add_op, repeat=lambda y, z: y + z < 25)
     repeat_op.build('tf')

     with self.subTest('Check op inputs'):
         self.assertListEqual(repeat_op.inputs, ['x', 'z'])
     with self.subTest('Check op outputs'):
         self.assertListEqual(repeat_op.outputs, ['x', 'y'])
     with self.subTest('Check op mode'):
         self.assertSetEqual(repeat_op.mode, {'eval'})

     def run(batch):
         return repeat_op.forward(data=batch, state={"deferred": {}, "mode": "eval"})

     start = [tf.ones([1]), 10 + tf.ones([1])]
     graph_run = tf.function(run)
     graph_run(start)  # first call traces the graph
     result = graph_run(start)

     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output value (x)'):
         self.assertEqual(5, result[0])
     with self.subTest('Check output value (y)'):
         self.assertEqual(16, result[1])
예제 #6
0
 def test_single_repeat_fn_interior_value_static(self):
     """Repeat whose condition only reads its own output 'y', executed inside tf.function.

     Starting from x=1, the condition ``y < 1`` is already false after one pass,
     so the test expects x=2 and y=1.
     """
     add_op = LambdaOp(inputs='x',
                       outputs=('x', 'y'),
                       fn=lambda z: (z + 1, z * z),
                       mode='eval')
     # Argument names of the repeat lambda determine Repeat's inputs.
     repeat_op = Repeat(add_op, repeat=lambda y: y < 1)
     repeat_op.build('tf')

     with self.subTest('Check op inputs'):
         self.assertListEqual(repeat_op.inputs, ['x'])
     with self.subTest('Check op outputs'):
         self.assertListEqual(repeat_op.outputs, ['x', 'y'])
     with self.subTest('Check op mode'):
         self.assertSetEqual(repeat_op.mode, {'eval'})

     def run(batch):
         return repeat_op.forward(data=batch, state={"deferred": {}, "mode": "eval"})

     start = [tf.ones([1])]
     graph_run = tf.function(run)
     graph_run(start)  # build the graph
     result = graph_run(start)

     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output value (x)'):
         self.assertEqual(2, result[0])
     with self.subTest('Check output value (y)'):
         self.assertEqual(1, result[1])
예제 #7
0
 def test_repeat_fn_exterior_value_tf(self):
     """Repeat reading exterior key 'z', run eagerly with a persistent GradientTape in state.

     From x=1 and z=11, repeating while ``y + z < 25``, the test expects x=5 and y=16.
     """
     add_op = LambdaOp(inputs='x',
                       outputs=('x', 'y'),
                       fn=lambda x: (x + 1, x * x),
                       mode='eval')
     # 'y' and 'z' are taken from the lambda's parameter names by Repeat.
     repeat_op = Repeat(add_op, repeat=lambda y, z: y + z < 25)
     repeat_op.build('tf')

     with self.subTest('Check op inputs'):
         self.assertListEqual(repeat_op.inputs, ['x', 'z'])
     with self.subTest('Check op outputs'):
         self.assertListEqual(repeat_op.outputs, ['x', 'y'])
     with self.subTest('Check op mode'):
         self.assertSetEqual(repeat_op.mode, {'eval'})

     with tf.GradientTape(persistent=True) as tape:
         start = [tf.ones([1]), 10 + tf.ones([1])]
         result = repeat_op.forward(data=start,
                                    state={"deferred": {}, "mode": "eval", "tape": tape})

     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output value (x)'):
         self.assertEqual(5, result[0])
     with self.subTest('Check output value (y)'):
         self.assertEqual(16, result[1])
예제 #8
0
def get_estimator(epochs=2, batch_size=32):
    """Build (but do not run) an Estimator training LeNet on MNIST.

    The training objective "total_loss" is the sum of a cross-entropy loss "ce"
    and a "feature_loss" computed by CustomLoss from the model's 'dense'
    intermediate features ("feature_vector") and "feature_selected".

    NOTE(review): "feature_selected" is produced by the MemoryBank trace, not by
    a network op. Confirm it is present in the batch when CustomLoss runs.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size used by the pipeline.

    Returns:
        The configured ``fe.Estimator``.
    """
    train_data, eval_data = mnist.load_data()
    preprocess = [
        ExpandDims(inputs="x", outputs="x"),
        Minmax(inputs="x", outputs="x"),
    ]
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           batch_size=batch_size,
                           ops=preprocess)

    model = fe.build(model_fn=LeNet, optimizer_fn="adam")
    forward = ModelOp(model=model,
                      inputs="x",
                      outputs=["y_pred", "feature_vector"],
                      intermediate_layers='dense')
    ce_loss = CrossEntropy(inputs=("y_pred", "y"), outputs="ce")
    feature_loss = CustomLoss(inputs=("feature_vector", "feature_selected"),
                              outputs="feature_loss")
    total_loss = LambdaOp(fn=lambda x, y: x + y,
                          inputs=("ce", "feature_loss"),
                          outputs="total_loss")
    update = UpdateOp(model=model, loss_name="total_loss")
    network = fe.Network(ops=[forward, ce_loss, feature_loss, total_loss, update])

    memory_bank = MemoryBank(inputs=("feature_vector", "y"), outputs="feature_selected")
    return fe.Estimator(pipeline=pipeline,
                        network=network,
                        epochs=epochs,
                        traces=[memory_bank])
예제 #9
0
 def test_single_input_torch(self):
     """Sometimes(prob=0.75) built for 'torch' returns a list whose first element has self.output_shape."""
     maybe_add = Sometimes(LambdaOp(inputs='x', outputs='x', fn=lambda x: x + 1), prob=0.75)
     maybe_add.build('torch')
     result = maybe_add.forward(data=self.single_input_torch, state={})

     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output image shape'):
         self.assertEqual(result[0].shape, self.output_shape)
예제 #10
0
 def test_multi_input_tf_static(self):
     """OneOf over three two-key LambdaOps returns a list of two tensors of self.output_shape."""
     a = LambdaOp(inputs=['x', 'y'],
                  outputs=['y', 'z'],
                  fn=lambda x, y: [x + y, x - y])
     b = LambdaOp(inputs=['x', 'y'],
                  outputs=['y', 'z'],
                  fn=lambda x, y: [x * y, x + y])
     c = LambdaOp(inputs=['x', 'y'],
                  outputs=['y', 'z'],
                  fn=lambda x, y: [y, x])
     oneof = OneOf(a, b, c)
     oneof.build('tf')
     output = oneof.forward(data=self.multi_input_tf, state={})
     with self.subTest('Check output type'):
         self.assertEqual(type(output), list)
     with self.subTest('Check output list length'):
         self.assertEqual(len(output), 2)
     for index, img_output in enumerate(output):
         # Tag each subtest with its index so a failure identifies which output is wrong.
         with self.subTest('Check output image shape', index=index):
             self.assertEqual(img_output.shape, self.output_shape)
예제 #11
0
 def test_multi_input_tf_static(self):
     """Sometimes (default prob) over a two-key LambdaOp returns a list of two tensors of self.output_shape."""
     a = LambdaOp(inputs=['x', 'y'],
                  outputs=['y', 'x'],
                  fn=lambda x, y: [x + y, x - y])
     sometimes = Sometimes(a)
     sometimes.build('tf')
     output = sometimes.forward(data=self.multi_input_tf, state={})
     with self.subTest('Check output type'):
         self.assertEqual(type(output), list)
     with self.subTest('Check output list length'):
         self.assertEqual(len(output), 2)
     for index, img_output in enumerate(output):
         # Tag each subtest with its index so a failure identifies which output is wrong.
         with self.subTest('Check output image shape', index=index):
             self.assertEqual(img_output.shape, self.output_shape)
예제 #12
0
 def test_multi_repeat_fn_interior_value_tf(self):
     """Repeat while ``y < 25`` from x=1, run eagerly; the test expects x=6 and y=25."""
     add_op = LambdaOp(inputs='x',
                       outputs=('x', 'y'),
                       fn=lambda x: (x + 1, x * x),
                       mode='eval')
     # The repeat lambda's parameter name 'y' defines what Repeat reads.
     repeat_op = Repeat(add_op, repeat=lambda y: y < 25)
     repeat_op.build('tf')

     with self.subTest('Check op inputs'):
         self.assertListEqual(repeat_op.inputs, ['x'])
     with self.subTest('Check op outputs'):
         self.assertListEqual(repeat_op.outputs, ['x', 'y'])
     with self.subTest('Check op mode'):
         self.assertSetEqual(repeat_op.mode, {'eval'})

     start = [tf.ones([1])]
     result = repeat_op.forward(data=start, state={"deferred": {}, "mode": "eval"})

     with self.subTest('Check output type'):
         self.assertEqual(type(result), list)
     with self.subTest('Check output value (x)'):
         self.assertEqual(6, result[0])
     with self.subTest('Check output value (y)'):
         self.assertEqual(25, result[1])
예제 #13
0
def pretrain_model(epochs, batch_size, train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` (PyTorch) with the loss from ``NTXentOp`` on two augmented views.

    Every training image in key "x" is padded (PadIfNeeded, min 40x40, train
    mode only), then two independently augmented 32x32 views, "x_aug" and
    "x_aug2", are produced. Both views are concatenated along dim 0, passed
    through ``model_con`` (``ResNet9OneLayerHead(length=128)``) in one forward
    pass, and chunked back into two halves. The "NTXent" loss from ``NTXentOp``
    then updates ``model_con``.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size used by the pipeline.
        train_steps_per_epoch: Training steps per epoch, forwarded to
            ``fe.Estimator``.
        save_dir: Directory handed to ``ModelSaver`` for ``model_con``.

    Returns:
        The trained ``model_con``.
    """

    def view_ops(key):
        """Return the ops that build one augmented float view ``key`` from "x"."""
        return [
            RandomCrop(32, 32, image_in="x", image_out=key),
            Sometimes(HorizontalFlip(image_in=key, image_out=key),
                      prob=0.5),
            Sometimes(ColorJitter(inputs=key,
                                  outputs=key,
                                  brightness=0.8,
                                  contrast=0.8,
                                  saturation=0.8,
                                  hue=0.2),
                      prob=0.8),
            Sometimes(ToGray(inputs=key, outputs=key), prob=0.2),
            Sometimes(GaussianBlur(inputs=key,
                                   outputs=key,
                                   blur_limit=(3, 3),
                                   sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            # presumably HWC -> CHW for the torch model; confirm against ChannelTranspose
            ChannelTranspose(inputs=key, outputs=key),
            ToFloat(inputs=key, outputs=key),
        ]

    # The test split is not used during pretraining.
    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            PadIfNeeded(min_height=40,
                        min_width=40,
                        image_in="x",
                        image_out="x",
                        mode="train")
        ] + view_ops("x_aug")  # augmentation 1
        + view_ops("x_aug2"))  # augmentation 2

    model_con = fe.build(model_fn=lambda: ResNet9OneLayerHead(length=128),
                         optimizer_fn="adam")
    network = fe.Network(ops=[
        # Run both views through the model in one batch, then split the result back.
        LambdaOp(lambda x, y: torch.cat([x, y], dim=0),
                 inputs=["x_aug", "x_aug2"],
                 outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        LambdaOp(lambda x: torch.chunk(x, 2, dim=0),
                 inputs="y_com",
                 outputs=["y_pred", "y_pred2"],
                 mode="train"),
        NTXentOp(arg1="y_pred",
                 arg2="y_pred2",
                 outputs=["NTXent", "logit", "label"],
                 mode="train"),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])

    traces = [
        Accuracy(true_key="label",
                 pred_key="logit",
                 mode="train",
                 output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir)
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch)
    estimator.fit()

    return model_con
예제 #14
0
 def test_multi_input(self):
     """LambdaOp(fn=tf.reshape) given [tensor, (2, 2)] reshapes [1, 2, 3, 4] into a 2x2 tensor."""
     reshape_op = LambdaOp(fn=tf.reshape)
     flat = tf.convert_to_tensor([1, 2, 3, 4])
     reshaped = reshape_op.forward(data=[flat, (2, 2)], state={})
     expected = tf.convert_to_tensor([[1, 2], [3, 4]])
     self.assertTrue(is_equal(reshaped, expected))
예제 #15
0
 def test_single_input(self):
     """LambdaOp(fn=tf.reduce_sum) collapses [[1, 2, 3]] to 6."""
     total = LambdaOp(fn=tf.reduce_sum).forward(data=tf.convert_to_tensor([[1, 2, 3]]), state={})
     self.assertEqual(total, 6)