Пример #1
0
 def test_single_input(self):
     """ToGray on one input should return a list whose first item has the expected shape.

     Relies on the ``single_input`` and ``single_output_shape`` fixtures that the
     test case defines outside this view (presumably in ``setUp``).
     """
     gray_op = ToGray(inputs='x', outputs='x')
     result = gray_op.forward(data=self.single_input, state={})
     with self.subTest('Check output type'):
         # Exact-type check: a list subclass would not satisfy this.
         self.assertEqual(type(result), list)
     with self.subTest('Check output image shape'):
         first_image = result[0]
         self.assertEqual(first_image.shape, self.single_output_shape)
Пример #2
0
def pretrain_model(epochs, batch_size, max_train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` with the NT-Xent loss on two augmented views.

    Each training image is padded to 40x40. It is then passed twice,
    independently, through the same random augmentation chain, producing
    ``"x_aug"`` and ``"x_aug2"``. Both views are concatenated into one batch,
    run through ``model_con``, and split back into two halves. ``NTXentOp``
    compares those halves, and its ``"NTXent"`` output drives the update.

    Args:
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        max_train_steps_per_epoch: Forwarded to ``fe.Estimator`` to cap the
            number of training steps per epoch.
        save_dir: Directory handed to ``ModelSaver`` for ``model_con``.

    Returns:
        The pair ``(model_con, model_finetune)`` built by ``fe.build``. Only
        ``model_con`` has an ``UpdateOp`` in this network; ``model_finetune``
        is returned for the caller (presumably for a later fine-tuning stage).
    """
    def _augmentation_ops(image_out):
        """Return the random augmentation chain that reads "x" and writes ``image_out``."""
        return [
            RandomCrop(32, 32, image_in="x", image_out=image_out),
            Sometimes(HorizontalFlip(image_in=image_out, image_out=image_out), prob=0.5),
            Sometimes(
                ColorJitter(inputs=image_out, outputs=image_out, brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
            Sometimes(ToGray(inputs=image_out, outputs=image_out), prob=0.2),
            Sometimes(GaussianBlur(inputs=image_out, outputs=image_out, blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ToFloat(inputs=image_out, outputs=image_out),
        ]

    # step 1: prepare dataset (the test split is not used during pretraining)
    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x"),
            *_augmentation_ops("x_aug"),  # augmentation 1
            *_augmentation_ops("x_aug2"),  # augmentation 2
        ])

    # step 2: prepare network
    model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
    network = fe.Network(ops=[
        # Run both views through the model in a single forward pass ...
        LambdaOp(lambda x, y: tf.concat([x, y], axis=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # ... then split the predictions back into the two views (same order as the concat).
        LambdaOp(lambda x: tf.split(x, 2, axis=0), inputs="y_com", outputs=["y_pred", "y_pred2"]),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])

    # step 3: prepare estimator
    traces = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             monitor_names="contrastive_accuracy")
    estimator.fit()

    return model_con, model_finetune
Пример #3
0
 def test_multi_input(self):
     """ToGray on several inputs should return a list of two correctly shaped outputs.

     Relies on the ``multi_input`` and ``multi_output_shape`` fixtures that the
     test case defines outside this view (presumably in ``setUp``).
     """
     to_gray = ToGray(inputs='x', outputs='x')
     output = to_gray.forward(data=self.multi_input, state={})
     with self.subTest('Check output type'):
         # Exact-type check: a list subclass would not satisfy this.
         self.assertEqual(type(output), list)
     with self.subTest('Check output list length'):
         self.assertEqual(len(output), 2)
     for index, img_output in enumerate(output):
         # Tag each subtest with its index so a failure identifies the offending
         # image instead of yielding several indistinguishable reports.
         with self.subTest('Check output image shape', index=index):
             self.assertEqual(img_output.shape, self.multi_output_shape)
Пример #4
0
def pretrain_model(epochs, batch_size, train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` with the NT-Xent loss on two augmented views (PyTorch).

    Each training image is padded to 40x40 (train mode). It is then passed
    twice, independently, through the same random augmentation chain,
    producing ``"x_aug"`` and ``"x_aug2"``. Both views are concatenated into
    one batch, run through ``model_con``, and chunked back into two halves.
    ``NTXentOp`` compares those halves, and its ``"NTXent"`` output drives the
    update.

    Args:
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        train_steps_per_epoch: Forwarded to ``fe.Estimator`` as the number of
            training steps per epoch.
        save_dir: Directory handed to ``ModelSaver`` for ``model_con``.

    Returns:
        The trained ``model_con`` (a ``ResNet9OneLayerHead(length=128)`` built
        with an Adam optimizer).
    """
    def _augmentation_ops(image_out):
        """Return the random augmentation chain that reads "x" and writes ``image_out``."""
        return [
            RandomCrop(32, 32, image_in="x", image_out=image_out),
            Sometimes(HorizontalFlip(image_in=image_out, image_out=image_out),
                      prob=0.5),
            Sometimes(ColorJitter(inputs=image_out,
                                  outputs=image_out,
                                  brightness=0.8,
                                  contrast=0.8,
                                  saturation=0.8,
                                  hue=0.2),
                      prob=0.8),
            Sometimes(ToGray(inputs=image_out, outputs=image_out), prob=0.2),
            Sometimes(GaussianBlur(inputs=image_out,
                                   outputs=image_out,
                                   blur_limit=(3, 3),
                                   sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            # NOTE(review): presumably moves channels first for the PyTorch model; confirm.
            ChannelTranspose(inputs=image_out, outputs=image_out),
            ToFloat(inputs=image_out, outputs=image_out),
        ]

    # The test split is not used during pretraining.
    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            PadIfNeeded(min_height=40,
                        min_width=40,
                        image_in="x",
                        image_out="x",
                        mode="train"),
            *_augmentation_ops("x_aug"),  # augmentation 1
            *_augmentation_ops("x_aug2"),  # augmentation 2
        ])

    model_con = fe.build(model_fn=lambda: ResNet9OneLayerHead(length=128),
                         optimizer_fn="adam")
    network = fe.Network(ops=[
        # Run both views through the model in a single forward pass ...
        LambdaOp(lambda x, y: torch.cat([x, y], dim=0),
                 inputs=["x_aug", "x_aug2"],
                 outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # ... then chunk the predictions back into the two views (same order as the cat).
        LambdaOp(lambda x: torch.chunk(x, 2, dim=0),
                 inputs="y_com",
                 outputs=["y_pred", "y_pred2"],
                 mode="train"),
        NTXentOp(arg1="y_pred",
                 arg2="y_pred2",
                 outputs=["NTXent", "logit", "label"],
                 mode="train"),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])

    traces = [
        Accuracy(true_key="label",
                 pred_key="logit",
                 mode="train",
                 output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir)
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch)
    estimator.fit()

    return model_con