def test_tf_static_superloss_categorical_ce(self):
    """Check SuperLoss wrapped around CrossEntropy on the categorical TF fixtures.

    The op is built for the "tf" framework and evaluated through the
    `self.do_forward` helper; the result is compared to a reference value.
    """
    wrapped_loss = CrossEntropy(inputs=['y_pred', 'y'], outputs='ce')
    super_loss = SuperLoss(wrapped_loss)
    super_loss.build(framework="tf", device=None)

    batch = [self.tf_pred_cat, self.tf_true_cat]
    result = self.do_forward(super_loss, data=batch, state=self.state)

    expected = -0.0026386082
    self.assertTrue(np.allclose(result.numpy(), expected))
def test_tf_static_superloss_hinge(self):
    """Check SuperLoss wrapped around Hinge on a fixed 4x4 TF example.

    Predictions and labels are passed (in that order) through the
    `self.do_forward` helper and the result is compared to a reference value.
    """
    labels = tf.constant([
        [-1, 1, 1, -1],
        [1, 1, 1, 1],
        [-1, -1, 1, -1],
        [1, -1, -1, -1],
    ])
    predictions = tf.constant([
        [0.1, 0.9, 0.05, 0.05],
        [0.1, -0.2, 0.0, -0.7],
        [0.0, 0.15, 0.8, 0.05],
        [1.0, -1.0, -1.0, -1.0],
    ])

    hinge_op = Hinge(inputs=('x1', 'x2'), outputs='x')
    super_loss = SuperLoss(hinge_op)
    super_loss.build('tf', None)

    result = self.do_forward(super_loss, data=[predictions, labels], state=self.state)
    expected = -0.072016776
    self.assertTrue(np.allclose(result.numpy(), expected))
def test_torch_superloss_hinge(self):
    """Check SuperLoss wrapped around Hinge on a fixed 4x4 PyTorch example.

    Tensors and the op are placed on "cuda:0" when CUDA is available,
    otherwise on "cpu"; the result is moved to the CPU before comparison.
    """
    # Resolve the target device once and reuse it for tensors and the op.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    labels = torch.tensor([
        [-1, 1, 1, -1],
        [1, 1, 1, 1],
        [-1, -1, 1, -1],
        [1, -1, -1, -1],
    ]).to(device)
    predictions = torch.tensor([
        [0.1, 0.9, 0.05, 0.05],
        [0.1, -0.2, 0.0, -0.7],
        [0.0, 0.15, 0.8, 0.05],
        [1.0, -1.0, -1.0, -1.0],
    ]).to(device)

    super_loss = SuperLoss(Hinge(inputs=('x1', 'x2'), outputs='x'))
    super_loss.build('torch', device)

    result = super_loss.forward(data=[predictions, labels], state=self.state)
    result_np = result.to("cpu").numpy()
    self.assertTrue(np.allclose(result_np, -0.072016776))
def test_torch_superloss_binary_ce(self):
    """Check SuperLoss wrapped around CrossEntropy on the binary PyTorch fixtures.

    The op is built on "cuda:0" when CUDA is available, otherwise on "cpu";
    the output is detached and moved to the CPU before comparison.
    """
    super_loss = SuperLoss(CrossEntropy(inputs=['y_pred', 'y'], outputs='ce'))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    super_loss.build(framework="torch", device=device)

    batch = [self.torch_pred_binary, self.torch_true_binary]
    result = super_loss.forward(data=batch, state=self.state)

    result_np = result.detach().to("cpu").numpy()
    expected = -0.0026238672
    self.assertTrue(np.allclose(result_np, expected))
def instantiate_system():
    """Return a sample system whose network is a LeNet ModelOp followed by SuperLoss.

    The system object comes from `sample_system_object()`; its `network`
    attribute is replaced with a two-op `fe.Network`.
    """
    system = sample_system_object()

    # NOTE(review): the model is named 'tf' although the architecture is taken
    # from fe.architecture.pytorch - confirm the name is intentional.
    lenet = fe.build(model_fn=fe.architecture.pytorch.LeNet, optimizer_fn='adam', model_name='tf')

    forward_op = ModelOp(model=lenet, inputs="x_out", outputs="y_pred")
    loss_op = SuperLoss(CrossEntropy(inputs=['y_pred', 'y'], outputs='ce'))
    system.network = fe.Network(ops=[forward_op, loss_op])
    return system
def get_estimator(epochs=50,
                  batch_size=128,
                  max_train_steps_per_epoch=None,
                  max_eval_steps_per_epoch=None,
                  save_dir=None):
    """Assemble the pipeline, network and traces into an `fe.Estimator`.

    Args:
        epochs: Forwarded to `fe.Estimator` as `epochs`.
        batch_size: Forwarded to `fe.Pipeline` as `batch_size`.
        max_train_steps_per_epoch: Forwarded to `fe.Estimator` unchanged.
        max_eval_steps_per_epoch: Forwarded to `fe.Estimator` unchanged.
        save_dir: Directory handed to `BestModelSaver` and `ImageSaver`. When
            None, a fresh temporary directory is created for this call.

    Returns:
        The configured `fe.Estimator` instance.
    """
    # Resolve the default lazily: a `tempfile.mkdtemp()` call in the signature
    # would run once at definition time, creating a directory on import and
    # sharing it between every call.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()

    # step 1
    train_data, eval_data = cifair100.load_data()

    # Add label noise to simulate real-world labeling problems
    corrupt_dataset(train_data)

    # Presumably carves the first half of eval_data out into a separate test
    # set - confirm against the dataset's split() semantics.
    test_data = eval_data.split(range(len(eval_data) // 2))
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        test_data=test_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            ChannelTranspose(inputs="x", outputs="x"),
        ])

    # step 2
    model = fe.build(model_fn=big_lenet, optimizer_fn='adam')
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        # SuperLoss additionally writes its confidence output under the
        # "confidence" key, which LabelTracker reads below.
        SuperLoss(CrossEntropy(inputs=("y_pred", "y"), outputs="ce"), output_confidence="confidence"),
        UpdateOp(model=model, loss_name="ce"),
    ])

    # step 3
    traces = [
        MCC(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
        LabelTracker(metric="confidence",
                     label="data_labels",
                     label_mapping={
                         "Normal": 0, "Corrupted": 1
                     },
                     mode="train",
                     outputs="label_confidence"),
        ImageSaver(inputs="label_confidence", save_dir=save_dir, mode="train"),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             max_eval_steps_per_epoch=max_eval_steps_per_epoch)
    return estimator
def test_tf_superloss_sparse_categorical_ce(self):
    """Check SuperLoss wrapped around CrossEntropy on the sparse TF fixtures.

    The op is built for the "tf" framework, `forward` is called directly, and
    the result is compared to a reference value.
    """
    wrapped_loss = CrossEntropy(inputs=['y_pred', 'y'], outputs='ce')
    super_loss = SuperLoss(wrapped_loss)
    super_loss.build(framework="tf", device=None)

    batch = [self.tf_pred_sparse, self.tf_true_sparse]
    result = super_loss.forward(data=batch, state=self.state)

    expected = -0.024740249
    self.assertTrue(np.allclose(result.numpy(), expected))