def test_single_input(self):
    """ToGray on a single-image input should yield a list whose first entry has the expected shape."""
    gray_op = ToGray(inputs='x', outputs='x')
    result = gray_op.forward(data=self.single_input, state={})
    # Each check lives in its own subTest so a failure in one does not stop the other from running.
    with self.subTest('Check output type'):
        self.assertEqual(type(result), list)
    with self.subTest('Check output image shape'):
        # Indexing stays inside the subTest so an empty result is reported here rather than aborting the test.
        self.assertEqual(result[0].shape, self.single_output_shape)
def pretrain_model(epochs, batch_size, max_train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` with the contrastive loss computed by ``NTXentOp`` (TensorFlow backend).

    Every training image is padded to at least 40x40. Two independently augmented 32x32 views
    ("x_aug" and "x_aug2") are then derived from it. Both views are concatenated along the batch
    axis, encoded by ``model_con`` in a single forward pass, and split back into two halves. The
    halves are fed to ``NTXentOp``, whose "NTXent" output drives the update of ``model_con``.

    Args:
        epochs: Number of training epochs handed to ``fe.Estimator``.
        batch_size: Batch size of the data pipeline.
        max_train_steps_per_epoch: Per-epoch training-step cap handed to ``fe.Estimator``.
        save_dir: Directory given to ``ModelSaver`` for ``model_con``.

    Returns:
        The pair ``(model_con, model_finetune)`` produced by ``fe.build``. Only ``model_con``
        is referenced by the network and updated here.
    """

    def _view_ops(view_key):
        # One stochastic augmentation chain: crop the padded "x" into `view_key`,
        # then transform `view_key` in place.
        return [
            RandomCrop(32, 32, image_in="x", image_out=view_key),
            Sometimes(HorizontalFlip(image_in=view_key, image_out=view_key), prob=0.5),
            Sometimes(
                ColorJitter(inputs=view_key, outputs=view_key, brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
            Sometimes(ToGray(inputs=view_key, outputs=view_key), prob=0.2),
            Sometimes(GaussianBlur(inputs=view_key, outputs=view_key, blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ToFloat(inputs=view_key, outputs=view_key),
        ]

    # Step 1: data. The second split returned by load_data() is not used for pretraining.
    train_data, _ = load_data()
    pipeline_ops = ([PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x")]
                    + _view_ops("x_aug")
                    + _view_ops("x_aug2"))
    data_pipeline = fe.Pipeline(train_data=train_data, batch_size=batch_size, ops=pipeline_ops)

    # Step 2: network.
    model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
    net = fe.Network(ops=[
        # Stack both views along axis 0 so one forward pass encodes them together.
        LambdaOp(lambda x, y: tf.concat([x, y], axis=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # Undo the stacking: the first half belongs to "x_aug", the second half to "x_aug2".
        LambdaOp(lambda x: tf.split(x, 2, axis=0), inputs="y_com", outputs=["y_pred", "y_pred2"]),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
        UpdateOp(model=model_con, loss_name="NTXent"),
    ])

    # Step 3: estimator.
    trace_list = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    trainer = fe.Estimator(pipeline=data_pipeline,
                           network=net,
                           epochs=epochs,
                           traces=trace_list,
                           max_train_steps_per_epoch=max_train_steps_per_epoch,
                           monitor_names="contrastive_accuracy")
    trainer.fit()
    return model_con, model_finetune
def test_multi_input(self):
    """ToGray on a multi-image input should return a list of two outputs, each with the expected shape."""
    to_gray = ToGray(inputs='x', outputs='x')
    output = to_gray.forward(data=self.multi_input, state={})
    with self.subTest('Check output type'):
        self.assertEqual(type(output), list)
    with self.subTest('Check output list length'):
        self.assertEqual(len(output), 2)
    for idx, img_output in enumerate(output):
        # The outputs are images, not masks. Tag each subTest with the element index
        # so a failure points at the offending output instead of an ambiguous repeated label.
        with self.subTest('Check output image shape', index=idx):
            self.assertEqual(img_output.shape, self.multi_output_shape)
def pretrain_model(epochs, batch_size, train_steps_per_epoch, save_dir):
    """Pretrain ``model_con`` with the contrastive loss computed by ``NTXentOp`` (PyTorch backend).

    Every training image is padded to at least 40x40 (train mode). Two independently augmented
    32x32 views ("x_aug" and "x_aug2") are then derived from it. Both views are concatenated along
    dim 0, encoded by ``model_con`` in one forward pass, and chunked back into two halves. The
    halves are fed to ``NTXentOp``, whose "NTXent" output drives the update of ``model_con``.

    Args:
        epochs: Number of training epochs handed to ``fe.Estimator``.
        batch_size: Batch size of the data pipeline.
        train_steps_per_epoch: Training steps per epoch handed to ``fe.Estimator``.
        save_dir: Directory given to ``ModelSaver`` for ``model_con``.

    Returns:
        The trained ``model_con``.
    """

    def _view_ops(view_key):
        # One stochastic augmentation chain: crop the padded "x" into `view_key`,
        # then transform `view_key` in place.
        return [
            RandomCrop(32, 32, image_in="x", image_out=view_key),
            Sometimes(HorizontalFlip(image_in=view_key, image_out=view_key), prob=0.5),
            Sometimes(ColorJitter(inputs=view_key, outputs=view_key, brightness=0.8, contrast=0.8, saturation=0.8,
                                  hue=0.2),
                      prob=0.8),
            Sometimes(ToGray(inputs=view_key, outputs=view_key), prob=0.2),
            Sometimes(GaussianBlur(inputs=view_key, outputs=view_key, blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            # NOTE(review): presumably moves the channel axis to the front for PyTorch models — confirm.
            ChannelTranspose(inputs=view_key, outputs=view_key),
            ToFloat(inputs=view_key, outputs=view_key),
        ]

    # Data: the second split returned by load_data() is not used for pretraining.
    train_data, _ = load_data()
    pipeline_ops = ([PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train")]
                    + _view_ops("x_aug")
                    + _view_ops("x_aug2"))
    data_pipeline = fe.Pipeline(train_data=train_data, batch_size=batch_size, ops=pipeline_ops)

    # Network.
    model_con = fe.build(model_fn=lambda: ResNet9OneLayerHead(length=128), optimizer_fn="adam")
    net = fe.Network(ops=[
        # Stack both views along dim 0 so one forward pass encodes them together.
        LambdaOp(lambda x, y: torch.cat([x, y], dim=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # Undo the stacking: the first chunk belongs to "x_aug", the second chunk to "x_aug2".
        LambdaOp(lambda x: torch.chunk(x, 2, dim=0), inputs="y_com", outputs=["y_pred", "y_pred2"], mode="train"),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"], mode="train"),
        UpdateOp(model=model_con, loss_name="NTXent"),
    ])

    # Estimator.
    trace_list = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    trainer = fe.Estimator(pipeline=data_pipeline,
                           network=net,
                           epochs=epochs,
                           traces=trace_list,
                           train_steps_per_epoch=train_steps_per_epoch)
    trainer.fit()
    return model_con