def test_tf_model(self):
    """ModelSaver should write a TF .h5 checkpoint whose weights round-trip via load_model."""
    src_model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
    saver = ModelSaver(model=src_model, save_dir=self.save_dir)
    saver.system = sample_system_object()
    saver.on_epoch_end(data=Data())
    # Saved file is named "<model_name>_epoch_<epoch_idx>.h5"
    model_name = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
    tf_model_path = os.path.join(self.save_dir, model_name + '.h5')
    with self.subTest('Check if model is saved'):
        self.assertTrue(os.path.exists(tf_model_path))
    with self.subTest('Validate model weights'):
        # Load the checkpoint into a weight-free clone and compare variables
        reloaded = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
        fe.backend.load_model(reloaded, tf_model_path)
        self.assertTrue(is_equal(reloaded.trainable_variables, src_model.trainable_variables))
def test_torch_model(self):
    """ModelSaver should write a torch .pt checkpoint whose weights round-trip via load_model."""
    src_model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    saver = ModelSaver(model=src_model, save_dir=self.save_dir)
    saver.system = sample_system_object()
    model_name = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
    torch_model_path = os.path.join(self.save_dir, model_name + '.pt')
    # Clear any leftover checkpoint so the existence assertion below is meaningful
    if os.path.exists(torch_model_path):
        os.remove(torch_model_path)
    saver.on_epoch_end(data=Data())
    with self.subTest('Check if model is saved'):
        self.assertTrue(os.path.exists(torch_model_path))
    with self.subTest('Validate model weights'):
        # Load the checkpoint into a weight-free clone and compare parameters
        reloaded = fe.build(model_fn=MultiLayerTorchModelWithoutWeights, optimizer_fn='adam')
        fe.backend.load_model(reloaded, torch_model_path)
        self.assertTrue(is_equal(list(reloaded.parameters()), list(src_model.parameters())))
def get_estimator(epochs=50, batch_size=256, max_train_steps_per_epoch=None, save_dir=None):
    """Build a DCGAN estimator on MNIST.

    Args:
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        save_dir: Directory for generator checkpoints. A fresh temporary
            directory is created per call when omitted.

    Returns:
        A configured `fe.Estimator`.
    """
    # BUG FIX: `save_dir=tempfile.mkdtemp()` as a default argument is evaluated
    # once at function-definition time, so every call shared a single directory.
    # Create the temp directory per call instead (backward compatible).
    if save_dir is None:
        save_dir = tempfile.mkdtemp()
    train_data, _ = mnist.load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            ExpandDims(inputs="x", outputs="x"),
            # Rescale pixel values from [0, 255] into roughly [-1, 1]
            Normalize(inputs="x", outputs="x", mean=1.0, std=1.0, max_pixel_value=127.5),
            # Random noise vector feeding the generator
            LambdaOp(fn=lambda: np.random.normal(size=[100]).astype('float32'), outputs="z")
        ])
    gen_model = fe.build(model_fn=generator, optimizer_fn=lambda: tf.optimizers.Adam(1e-4))
    disc_model = fe.build(model_fn=discriminator, optimizer_fn=lambda: tf.optimizers.Adam(1e-4))
    network = fe.Network(ops=[
        ModelOp(model=gen_model, inputs="z", outputs="x_fake"),
        ModelOp(model=disc_model, inputs="x_fake", outputs="fake_score"),
        GLoss(inputs="fake_score", outputs="gloss"),
        UpdateOp(model=gen_model, loss_name="gloss"),
        ModelOp(inputs="x", model=disc_model, outputs="true_score"),
        DLoss(inputs=("true_score", "fake_score"), outputs="dloss"),
        UpdateOp(model=disc_model, loss_name="dloss")
    ])
    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        # Checkpoint the generator every 5 epochs
        traces=ModelSaver(model=gen_model, save_dir=save_dir, frequency=5),
        max_train_steps_per_epoch=max_train_steps_per_epoch)
    return estimator
def test_max_to_keep_invalid_value(self):
    """ModelSaver must reject a negative max_to_keep with ValueError."""
    net = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    save_dir = tempfile.mkdtemp()
    with self.subTest('Check max_to_keep < 0'):
        with self.assertRaises(ValueError):
            ModelSaver(model=net, save_dir=save_dir, max_to_keep=-2)
def pretrain_model(epochs, batch_size, max_train_steps_per_epoch, save_dir):
    """Contrastive (NT-Xent) pre-training: two augmented views per image are
    pushed through one encoder and the contrastive loss is minimized.

    Args:
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        save_dir: Directory in which ModelSaver writes checkpoints.

    Returns:
        (model_con, model_finetune): the trained contrastive model and the
        second model produced by `fe.build` (untouched by this training loop).
    """
    # step 1: prepare dataset
    train_data, test_data = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            # Pad to 40x40 so the random 32x32 crops below have room to shift
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x"),
            # augmentation 1 -> "x_aug"
            RandomCrop(32, 32, image_in="x", image_out="x_aug"),
            Sometimes(HorizontalFlip(image_in="x_aug", image_out="x_aug"), prob=0.5),
            Sometimes(
                ColorJitter(inputs="x_aug", outputs="x_aug", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
            Sometimes(ToGray(inputs="x_aug", outputs="x_aug"), prob=0.2),
            Sometimes(GaussianBlur(inputs="x_aug", outputs="x_aug", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ToFloat(inputs="x_aug", outputs="x_aug"),
            # augmentation 2 -> "x_aug2" (independent second view of the same image)
            RandomCrop(32, 32, image_in="x", image_out="x_aug2"),
            Sometimes(HorizontalFlip(image_in="x_aug2", image_out="x_aug2"), prob=0.5),
            Sometimes(
                ColorJitter(inputs="x_aug2", outputs="x_aug2", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
            Sometimes(ToGray(inputs="x_aug2", outputs="x_aug2"), prob=0.2),
            Sometimes(GaussianBlur(inputs="x_aug2", outputs="x_aug2", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ToFloat(inputs="x_aug2", outputs="x_aug2")
        ])
    # step 2: prepare network
    model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
    network = fe.Network(ops=[
        # Concatenate both views into one batch so the encoder runs once
        LambdaOp(lambda x, y: tf.concat([x, y], axis=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # Split the joint output back into the two views' embeddings
        LambdaOp(lambda x: tf.split(x, 2, axis=0), inputs="y_com", outputs=["y_pred", "y_pred2"]),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])
    # step 3: prepare estimator
    traces = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             monitor_names="contrastive_accuracy")
    estimator.fit()
    return model_con, model_finetune
def get_estimator(batch_size=4,
                  epochs=2,
                  max_train_steps_per_epoch=None,
                  log_steps=100,
                  style_weight=5.0,
                  content_weight=1.0,
                  tv_weight=1e-4,
                  save_dir=None,
                  style_img_path='Vassily_Kandinsky,_1913_-_Composition_7.jpg',
                  data_dir=None):
    """Build a fast-style-transfer estimator (VGG perceptual losses) on MS-COCO.

    Args:
        batch_size: Pipeline batch size.
        epochs: Number of training epochs.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        log_steps: Logging interval in steps.
        style_weight: Weight of the style loss term.
        content_weight: Weight of the content loss term.
        tv_weight: Weight of the total-variation loss term.
        save_dir: Directory for model checkpoints. A fresh temporary directory
            is created per call when omitted.
        style_img_path: Path of the style reference image.
        data_dir: Optional root directory for the MS-COCO download.

    Returns:
        A configured `fe.Estimator`.
    """
    # BUG FIX: `save_dir=tempfile.mkdtemp()` as a default argument is evaluated
    # once at function-definition time, so every call shared a single directory.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()
    train_data, _ = mscoco.load_data(root_dir=data_dir, load_bboxes=False, load_masks=False, load_captions=False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    style_img = cv2.imread(style_img_path)
    assert style_img is not None, "cannot load the style image, please go to the folder with style image"
    style_img = cv2.resize(style_img, (256, 256))
    # Rescale pixel values from [0, 255] into [-1, 1]
    style_img = (style_img.astype(np.float32) - 127.5) / 127.5
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            ReadImage(inputs="image", outputs="image"),
            Normalize(inputs="image", outputs="image", mean=1.0, std=1.0, max_pixel_value=127.5),
            Resize(height=256, width=256, image_in="image", image_out="image"),
            # Inject the (constant) style image into every batch
            LambdaOp(fn=lambda: style_img, outputs="style_image"),
            # HWC -> CHW for torch models
            ChannelTranspose(inputs=["image", "style_image"], outputs=["image", "style_image"])
        ])
    model = fe.build(model_fn=StyleTransferNet,
                     model_name="style_transfer_net",
                     optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-3))
    network = fe.Network(ops=[
        ModelOp(inputs="image", model=model, outputs="image_out"),
        # VGG features of the style target, the content source, and the prediction
        ExtractVGGFeatures(inputs="style_image", outputs="y_style", device=device),
        ExtractVGGFeatures(inputs="image", outputs="y_content", device=device),
        ExtractVGGFeatures(inputs="image_out", outputs="y_pred", device=device),
        StyleContentLoss(style_weight=style_weight,
                         content_weight=content_weight,
                         tv_weight=tv_weight,
                         inputs=('y_pred', 'y_style', 'y_content', 'image_out'),
                         outputs='loss'),
        UpdateOp(model=model, loss_name="loss")
    ])
    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             traces=ModelSaver(model=model, save_dir=save_dir, frequency=1),
                             epochs=epochs,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             log_steps=log_steps)
    return estimator
def test_max_to_keep_torch(self):
    """With max_to_keep=2, only the two most recent torch checkpoints stay on disk."""
    save_dir = tempfile.mkdtemp()
    net = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    saver = ModelSaver(model=net, save_dir=save_dir, max_to_keep=2)
    saver.system = sample_system_object()

    def _save_current_epoch():
        # Trigger one save and return the expected checkpoint path for this epoch.
        saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
        return os.path.join(save_dir, model_name + '.pt')

    _save_current_epoch()  # first save: will be pruned later
    saver.system.epoch_idx += 1
    torch_model_path1 = _save_current_epoch()
    saver.system.epoch_idx += 1
    torch_model_path2 = _save_current_epoch()
    with self.subTest('Check only two file are kept'):
        self.assertEqual(len(os.listdir(save_dir)), 2)
    with self.subTest('Check two latest model are kept'):
        self.assertTrue(os.path.exists(torch_model_path1))
        self.assertTrue(os.path.exists(torch_model_path2))
def get_estimator(target_size=128, epochs=55, save_dir=tempfile.mkdtemp(), max_train_steps_per_epoch=None, data_dir=None):
    """Build a progressively-growing GAN estimator on the NIH chest x-ray data.

    Training is split into phases scheduled by epoch: at each growth event a
    larger discriminator/generator pair takes over (via EpochScheduler) and a
    fade-in alpha blends low-/high-resolution images.

    Args:
        target_size: Final image resolution; must be a power of 2 (>= 8).
        epochs: Total epochs; must be a multiple of the number of phases.
        save_dir: Directory for checkpoints and generated images.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        data_dir: Optional root directory for the dataset download.

    Returns:
        A configured `fe.Estimator`.
    """
    # NOTE(review): the `tempfile.mkdtemp()` default is evaluated once at
    # definition time, so repeated calls share one directory — consider a
    # `None` sentinel created inside the body.
    # assert growth parameters
    num_grow = np.log2(target_size) - 2
    assert num_grow >= 1 and num_grow % 1 == 0, "need exponential of 2 and greater than 8 as target size"
    # One initial phase plus (stabilize + fade-in) per growth step
    num_phases = int(2 * num_grow + 1)
    assert epochs % num_phases == 0, "epoch must be multiple of {} for size {}".format(
        num_phases, target_size)
    num_grow, phase_length = int(num_grow), int(epochs / num_phases)
    # Epochs at which a new resolution becomes active, and the matching sizes
    event_epoch = [1, 1 + phase_length] + [
        phase_length * (2 * i + 1) + 1 for i in range(1, num_grow)
    ]
    event_size = [4] + [2**(i + 3) for i in range(num_grow)]
    # set up data schedules
    dataset = nih_chestxray.load_data(root_dir=data_dir)
    # Per-phase resize to the active resolution
    resize_map = {
        epoch: Resize(image_in="x", image_out="x", height=size, width=size)
        for (epoch, size) in zip(event_epoch, event_size)
    }
    # Downsample then upsample to produce the blurred low-res blend partner
    resize_low_res_map1 = {
        epoch: Resize(image_in="x", image_out="x_low_res", height=size // 2, width=size // 2)
        for (epoch, size) in zip(event_epoch, event_size)
    }
    resize_low_res_map2 = {
        epoch: Resize(image_in="x_low_res", image_out="x_low_res", height=size, width=size)
        for (epoch, size) in zip(event_epoch, event_size)
    }
    # Smaller batches at higher resolutions; scaled by device count
    batch_size_map = {
        epoch: 512 // size * get_num_devices() if size <= 128 else 4 * get_num_devices()
        for (epoch, size) in zip(event_epoch, event_size)
    }
    batch_scheduler = EpochScheduler(epoch_dict=batch_size_map)
    pipeline = fe.Pipeline(
        batch_size=batch_scheduler,
        train_data=dataset,
        drop_last=True,
        ops=[
            ReadImage(inputs="x", outputs="x", color_flag='gray'),
            EpochScheduler(epoch_dict=resize_map),
            EpochScheduler(epoch_dict=resize_low_res_map1),
            EpochScheduler(epoch_dict=resize_low_res_map2),
            Normalize(inputs=["x", "x_low_res"], outputs=["x", "x_low_res"], mean=1.0, std=1.0, max_pixel_value=127.5),
            # 512-dim latent noise for the generator
            LambdaOp(fn=lambda: np.random.normal(size=[512]).astype('float32'), outputs="z")
        ])
    # now model schedule
    # Shared fade-in coefficient, driven externally by AlphaController
    fade_in_alpha = tf.Variable(initial_value=1.0, dtype='float32', trainable=False)
    # One discriminator per resolution
    d_models = fe.build(
        model_fn=lambda: build_D(fade_in_alpha, target_resolution=int(np.log2(target_size)), num_channels=1),
        optimizer_fn=[
            lambda: Adam(0.001, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
        ] * len(event_size),
        model_name=["d_{}".format(size) for size in event_size])
    # One generator per resolution, plus a final composite "G" with no optimizer
    g_models = fe.build(
        model_fn=lambda: build_G(fade_in_alpha, target_resolution=int(np.log2(target_size)), num_channels=1),
        optimizer_fn=[
            lambda: Adam(0.001, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
        ] * len(event_size) + [None],
        model_name=["g_{}".format(size) for size in event_size] + ["G"])
    # Epoch-scheduled op maps: the active-resolution model handles each role
    fake_img_map = {
        epoch: ModelOp(inputs="z", outputs="x_fake", model=model)
        for (epoch, model) in zip(event_epoch, g_models[:-1])
    }
    fake_score_map = {
        epoch: ModelOp(inputs="x_fake", outputs="fake_score", model=model)
        for (epoch, model) in zip(event_epoch, d_models)
    }
    real_score_map = {
        epoch: ModelOp(inputs="x_blend", outputs="real_score", model=model)
        for (epoch, model) in zip(event_epoch, d_models)
    }
    interp_score_map = {
        epoch: ModelOp(inputs="x_interp", outputs="interp_score", model=model)
        for (epoch, model) in zip(event_epoch, d_models)
    }
    g_update_map = {
        epoch: UpdateOp(loss_name="gloss", model=model)
        for (epoch, model) in zip(event_epoch, g_models[:-1])
    }
    d_update_map = {
        epoch: UpdateOp(loss_name="dloss", model=model)
        for (epoch, model) in zip(event_epoch, d_models)
    }
    network = fe.Network(ops=[
        EpochScheduler(fake_img_map),
        EpochScheduler(fake_score_map),
        # Blend real image with its low-res copy according to fade-in alpha
        ImageBlender(
            alpha=fade_in_alpha, inputs=("x", "x_low_res"), outputs="x_blend"),
        EpochScheduler(real_score_map),
        # Interpolated samples for the WGAN-GP gradient penalty
        Interpolate(inputs=("x_fake", "x"), outputs="x_interp"),
        EpochScheduler(interp_score_map),
        GradientPenalty(inputs=("x_interp", "interp_score"), outputs="gp"),
        GLoss(inputs="fake_score", outputs="gloss"),
        DLoss(inputs=("real_score", "fake_score", "gp"), outputs="dloss"),
        EpochScheduler(g_update_map),
        EpochScheduler(d_update_map)
    ])
    traces = [
        AlphaController(alpha=fade_in_alpha,
                        fade_start_epochs=event_epoch[1:],
                        duration=phase_length,
                        batch_scheduler=batch_scheduler,
                        num_examples=len(dataset)),
        # Save the composite generator at the end of each phase
        ModelSaver(model=g_models[-1], save_dir=save_dir, frequency=phase_length),
        # Sample images from each stage's generator on its final epoch
        ImageSaving(epoch_model_map={
            epoch - 1: model
            for (epoch, model) in zip(event_epoch[1:] + [epochs + 1], g_models[:-1])
        },
                    save_dir=save_dir)
    ]
    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        traces=traces,
        max_train_steps_per_epoch=max_train_steps_per_epoch)
    return estimator
def get_estimator(weight=10.0, epochs=200, batch_size=1, max_train_steps_per_epoch=None, save_dir=tempfile.mkdtemp(), data_dir=None):
    """Build a CycleGAN estimator (torch backend) with two generators and two
    discriminators trained jointly with cycle-consistency losses.

    Args:
        weight: Weight passed to the generator (cycle/identity) losses.
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        save_dir: Directory for generator checkpoints.
        data_dir: Optional root directory for the dataset download.

    Returns:
        A configured `fe.Estimator`.
    """
    # NOTE(review): the `tempfile.mkdtemp()` default is evaluated once at
    # definition time, so repeated calls share one directory — consider a
    # `None` sentinel created inside the body.
    train_data, _ = load_data(batch_size=batch_size, root_dir=data_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = fe.Pipeline(train_data=train_data,
                           ops=[
                               ReadImage(inputs=["A", "B"], outputs=["A", "B"]),
                               # Rescale pixels from [0, 255] into roughly [-1, 1]
                               Normalize(inputs=["A", "B"], outputs=["real_A", "real_B"], mean=1.0, std=1.0, max_pixel_value=127.5),
                               # Train-time jitter: upscale to 286 then random-crop back to 256
                               Resize(height=286, width=286, image_in="real_A", image_out="real_A", mode="train"),
                               RandomCrop(height=256, width=256, image_in="real_A", image_out="real_A", mode="train"),
                               Resize(height=286, width=286, image_in="real_B", image_out="real_B", mode="train"),
                               RandomCrop(height=256, width=256, image_in="real_B", image_out="real_B", mode="train"),
                               Sometimes(
                                   HorizontalFlip(image_in="real_A", image_out="real_A", mode="train")),
                               Sometimes(
                                   HorizontalFlip(image_in="real_B", image_out="real_B", mode="train")),
                               # HWC -> CHW for torch models
                               ChannelTranspose(inputs=["real_A", "real_B"], outputs=["real_A", "real_B"])
                           ])
    g_AtoB = fe.build(model_fn=Generator,
                      model_name="g_AtoB",
                      optimizer_fn=lambda x: torch.optim.Adam(
                          x, lr=2e-4, betas=(0.5, 0.999)))
    g_BtoA = fe.build(model_fn=Generator,
                      model_name="g_BtoA",
                      optimizer_fn=lambda x: torch.optim.Adam(
                          x, lr=2e-4, betas=(0.5, 0.999)))
    d_A = fe.build(model_fn=Discriminator,
                   model_name="d_A",
                   optimizer_fn=lambda x: torch.optim.Adam(
                       x, lr=2e-4, betas=(0.5, 0.999)))
    d_B = fe.build(model_fn=Discriminator,
                   model_name="d_B",
                   optimizer_fn=lambda x: torch.optim.Adam(
                       x, lr=2e-4, betas=(0.5, 0.999)))
    network = fe.Network(ops=[
        ModelOp(inputs="real_A", model=g_AtoB, outputs="fake_B"),
        ModelOp(inputs="real_B", model=g_BtoA, outputs="fake_A"),
        # History buffers of previously generated images for the discriminators
        Buffer(image_in="fake_A", image_out="buffer_fake_A"),
        Buffer(image_in="fake_B", image_out="buffer_fake_B"),
        ModelOp(inputs="real_A", model=d_A, outputs="d_real_A"),
        ModelOp(inputs="fake_A", model=d_A, outputs="d_fake_A"),
        ModelOp(inputs="buffer_fake_A",
                model=d_A,
                outputs="buffer_d_fake_A"),
        ModelOp(inputs="real_B", model=d_B, outputs="d_real_B"),
        ModelOp(inputs="fake_B", model=d_B, outputs="d_fake_B"),
        ModelOp(inputs="buffer_fake_B", model=d_B, outputs="buffer_d_fake_B"),
        # Identity and cycle reconstructions
        ModelOp(inputs="real_A", model=g_BtoA, outputs="same_A"),
        ModelOp(inputs="fake_B", model=g_BtoA, outputs="cycled_A"),
        ModelOp(inputs="real_B", model=g_AtoB, outputs="same_B"),
        ModelOp(inputs="fake_A", model=g_AtoB, outputs="cycled_B"),
        GLoss(inputs=("real_A", "d_fake_B", "cycled_A", "same_A"), weight=weight, device=device, outputs="g_AtoB_loss"),
        GLoss(inputs=("real_B", "d_fake_A", "cycled_B", "same_B"), weight=weight, device=device, outputs="g_BtoA_loss"),
        DLoss(inputs=("d_real_A", "buffer_d_fake_A"), outputs="d_A_loss", device=device),
        DLoss(inputs=("d_real_B", "buffer_d_fake_B"), outputs="d_B_loss", device=device),
        UpdateOp(model=g_AtoB, loss_name="g_AtoB_loss"),
        UpdateOp(model=g_BtoA, loss_name="g_BtoA_loss"),
        UpdateOp(model=d_A, loss_name="d_A_loss"),
        UpdateOp(model=d_B, loss_name="d_B_loss")
    ])
    traces = [
        # Checkpoint generators every 10 epochs; LR schedule applies to all four models
        ModelSaver(model=g_AtoB, save_dir=save_dir, frequency=10),
        ModelSaver(model=g_BtoA, save_dir=save_dir, frequency=10),
        LRScheduler(model=g_AtoB, lr_fn=lr_schedule),
        LRScheduler(model=g_BtoA, lr_fn=lr_schedule),
        LRScheduler(model=d_A, lr_fn=lr_schedule),
        LRScheduler(model=d_B, lr_fn=lr_schedule)
    ]
    estimator = fe.Estimator(
        network=network,
        pipeline=pipeline,
        epochs=epochs,
        traces=traces,
        max_train_steps_per_epoch=max_train_steps_per_epoch)
    return estimator
def get_estimator(weight=10.0, epochs=200, batch_size=1, train_steps_per_epoch=None, save_dir=tempfile.mkdtemp(), data_dir=None):
    """Build a CycleGAN estimator (TF backend). The fake-image history buffers
    are wired via PlaceholderOp keys and refreshed by BufferUpdate traces.

    Args:
        weight: Weight passed to the generator (cycle/identity) losses.
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        train_steps_per_epoch: Optional cap on training steps per epoch.
        save_dir: Directory for generator checkpoints.
        data_dir: Optional root directory for the dataset download.

    Returns:
        A configured `fe.Estimator`.
    """
    # NOTE(review): the `tempfile.mkdtemp()` default is evaluated once at
    # definition time, so repeated calls share one directory — consider a
    # `None` sentinel created inside the body.
    train_data, _ = load_data(batch_size=batch_size, root_dir=data_dir)
    pipeline = fe.Pipeline(
        train_data=train_data,
        ops=[
            ReadImage(inputs=["A", "B"], outputs=["A", "B"]),
            # Rescale pixels from [0, 255] into roughly [-1, 1]
            Normalize(inputs=["A", "B"], outputs=["real_A", "real_B"], mean=1.0, std=1.0, max_pixel_value=127.5),
            # Train-time jitter: upscale to 286 then random-crop back to 256
            Resize(height=286, width=286, image_in="real_A", image_out="real_A", mode="train"),
            RandomCrop(height=256, width=256, image_in="real_A", image_out="real_A", mode="train"),
            Resize(height=286, width=286, image_in="real_B", image_out="real_B", mode="train"),
            RandomCrop(height=256, width=256, image_in="real_B", image_out="real_B", mode="train"),
            Sometimes(HorizontalFlip(image_in="real_A", image_out="real_A", mode="train")),
            Sometimes(HorizontalFlip(image_in="real_B", image_out="real_B", mode="train")),
            # Placeholder keys that BufferUpdate traces fill in each step
            PlaceholderOp(outputs=("index_A", "buffer_A")),
            PlaceholderOp(outputs=("index_B", "buffer_B"))
        ])
    g_AtoB = fe.build(model_fn=build_generator, model_name="g_AtoB", optimizer_fn=lambda: tf.optimizers.Adam(2e-4, 0.5))
    g_BtoA = fe.build(model_fn=build_generator, model_name="g_BtoA", optimizer_fn=lambda: tf.optimizers.Adam(2e-4, 0.5))
    d_A = fe.build(model_fn=build_discriminator, model_name="d_A", optimizer_fn=lambda: tf.optimizers.Adam(2e-4, 0.5))
    d_B = fe.build(model_fn=build_discriminator, model_name="d_B", optimizer_fn=lambda: tf.optimizers.Adam(2e-4, 0.5))
    network = fe.Network(ops=[
        ModelOp(inputs="real_A", model=g_AtoB, outputs="fake_B"),
        ModelOp(inputs="real_B", model=g_BtoA, outputs="fake_A"),
        # History buffers of previously generated images for the discriminators
        Buffer(image_in="fake_A", buffer_in="buffer_A", index_in="index_A", image_out="buffer_fake_A"),
        Buffer(image_in="fake_B", buffer_in="buffer_B", index_in="index_B", image_out="buffer_fake_B"),
        ModelOp(inputs="real_A", model=d_A, outputs="d_real_A"),
        ModelOp(inputs="fake_A", model=d_A, outputs="d_fake_A"),
        ModelOp(inputs="buffer_fake_A", model=d_A, outputs="buffer_d_fake_A"),
        ModelOp(inputs="real_B", model=d_B, outputs="d_real_B"),
        ModelOp(inputs="fake_B", model=d_B, outputs="d_fake_B"),
        ModelOp(inputs="buffer_fake_B", model=d_B, outputs="buffer_d_fake_B"),
        # Identity and cycle reconstructions
        ModelOp(inputs="real_A", model=g_BtoA, outputs="same_A"),
        ModelOp(inputs="fake_B", model=g_BtoA, outputs="cycled_A"),
        ModelOp(inputs="real_B", model=g_AtoB, outputs="same_B"),
        ModelOp(inputs="fake_A", model=g_AtoB, outputs="cycled_B"),
        GLoss(inputs=("real_A", "d_fake_B", "cycled_A", "same_A"), weight=weight, outputs="g_AtoB_loss"),
        GLoss(inputs=("real_B", "d_fake_A", "cycled_B", "same_B"), weight=weight, outputs="g_BtoA_loss"),
        DLoss(inputs=("d_real_A", "buffer_d_fake_A"), outputs="d_A_loss"),
        DLoss(inputs=("d_real_B", "buffer_d_fake_B"), outputs="d_B_loss"),
        UpdateOp(model=g_AtoB, loss_name="g_AtoB_loss"),
        UpdateOp(model=g_BtoA, loss_name="g_BtoA_loss"),
        UpdateOp(model=d_A, loss_name="d_A_loss"),
        UpdateOp(model=d_B, loss_name="d_B_loss")
    ])
    traces = [
        # Maintain the 50-image history buffers feeding the Buffer ops above
        BufferUpdate(input_name="fake_A",
                     buffer_size=50,
                     batch_size=batch_size,
                     mode="train",
                     output_name=["buffer_A", "index_A"]),
        BufferUpdate(input_name="fake_B",
                     buffer_size=50,
                     batch_size=batch_size,
                     mode="train",
                     output_name=["buffer_B", "index_B"]),
        ModelSaver(model=g_AtoB, save_dir=save_dir, frequency=5),
        ModelSaver(model=g_BtoA, save_dir=save_dir, frequency=5),
        LRScheduler(model=g_AtoB, lr_fn=lr_schedule),
        LRScheduler(model=g_BtoA, lr_fn=lr_schedule),
        LRScheduler(model=d_A, lr_fn=lr_schedule),
        LRScheduler(model=d_B, lr_fn=lr_schedule)
    ]
    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch)
    return estimator
def test_torch_architecture_save(self):
    """Requesting save_architecture with a torch model must raise ValueError."""
    net = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
    target_dir = tempfile.mkdtemp()
    with self.assertRaises(ValueError):
        ModelSaver(model=net, save_dir=target_dir, save_architecture=True)
def test_max_to_keep_tf_architecture(self):
    """With save_architecture=True and max_to_keep=2, pruning keeps the two newest
    .h5 files plus their architecture directories (4 entries total)."""
    save_dir = tempfile.mkdtemp()
    net = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
    saver = ModelSaver(model=net, save_dir=save_dir, max_to_keep=2, save_architecture=True)
    saver.system = sample_system_object()

    def _save_current_epoch():
        # Run one save and return (weights_path, architecture_dir) for this epoch.
        saver.on_epoch_end(data=Data())
        model_name = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
        return os.path.join(save_dir, model_name + '.h5'), os.path.join(save_dir, model_name)

    _save_current_epoch()  # will be pruned
    saver.system.epoch_idx += 1
    _save_current_epoch()  # will be pruned
    saver.system.epoch_idx += 1
    tf_model_path1, tf_architecture_path1 = _save_current_epoch()
    saver.system.epoch_idx += 1
    tf_model_path2, tf_architecture_path2 = _save_current_epoch()
    with self.subTest('Check only four files are kept'):
        self.assertEqual(len(os.listdir(save_dir)), 4)
    with self.subTest('Check two latest models are kept'):
        self.assertTrue(os.path.exists(tf_model_path1))
        self.assertTrue(os.path.exists(tf_model_path2))
        self.assertTrue(os.path.exists(tf_architecture_path1))
        self.assertTrue(os.path.isdir(tf_architecture_path1))
        self.assertTrue(os.path.exists(tf_architecture_path2))
        self.assertTrue(os.path.isdir(tf_architecture_path2))
def pretrain_model(epochs, batch_size, train_steps_per_epoch, save_dir):
    """Contrastive (NT-Xent) pre-training, torch backend: two augmented views of
    each image go through one encoder and the contrastive loss is minimized.

    Args:
        epochs: Number of training epochs.
        batch_size: Pipeline batch size.
        train_steps_per_epoch: Optional cap on training steps per epoch.
        save_dir: Directory in which ModelSaver writes checkpoints.

    Returns:
        model_con: The trained contrastive model.
    """
    train_data, test_data = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        ops=[
            # Pad to 40x40 so the random 32x32 crops below have room to shift
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            # augmentation 1 -> "x_aug"
            RandomCrop(32, 32, image_in="x", image_out="x_aug"),
            Sometimes(HorizontalFlip(image_in="x_aug", image_out="x_aug"), prob=0.5),
            Sometimes(ColorJitter(inputs="x_aug", outputs="x_aug", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                      prob=0.8),
            Sometimes(ToGray(inputs="x_aug", outputs="x_aug"), prob=0.2),
            Sometimes(GaussianBlur(inputs="x_aug", outputs="x_aug", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            # HWC -> CHW for torch models
            ChannelTranspose(inputs="x_aug", outputs="x_aug"),
            ToFloat(inputs="x_aug", outputs="x_aug"),
            # augmentation 2 -> "x_aug2" (independent second view of the same image)
            RandomCrop(32, 32, image_in="x", image_out="x_aug2"),
            Sometimes(HorizontalFlip(image_in="x_aug2", image_out="x_aug2"), prob=0.5),
            Sometimes(ColorJitter(inputs="x_aug2", outputs="x_aug2", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                      prob=0.8),
            Sometimes(ToGray(inputs="x_aug2", outputs="x_aug2"), prob=0.2),
            Sometimes(GaussianBlur(inputs="x_aug2", outputs="x_aug2", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                      prob=0.5),
            ChannelTranspose(inputs="x_aug2", outputs="x_aug2"),
            ToFloat(inputs="x_aug2", outputs="x_aug2")
        ])
    model_con = fe.build(model_fn=lambda: ResNet9OneLayerHead(length=128), optimizer_fn="adam")
    network = fe.Network(ops=[
        # Concatenate both views into one batch so the encoder runs once
        LambdaOp(lambda x, y: torch.cat([x, y], dim=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
        ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
        # Split the joint output back into the two views' embeddings
        LambdaOp(lambda x: torch.chunk(x, 2, dim=0), inputs="y_com", outputs=["y_pred", "y_pred2"], mode="train"),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"], mode="train"),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])
    traces = [
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir)
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch)
    estimator.fit()
    return model_con