def test_multi_batch_data(self, mock_board, _):
    """Data embeddings are gathered across several batches up to ``num_images``.

    Three validation batches of 18 images each (54 total) are fed to a
    projector that writes only raw data (``write_data=True``,
    ``write_features=False``) and requests 45 images. A single ``'data'``
    embedding holding exactly 45 rows is expected.
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_embedding = Mock()
    add_embedding = mock_board.return_value.add_embedding

    state = {
        torchbearer.X: torch.ones(18, 3, 10, 10),
        torchbearer.EPOCH: 0,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
        torchbearer.Y_TRUE: torch.ones(18),
        torchbearer.BATCH: 0,
    }

    tboard = TensorBoardProjector(
        num_images=45,
        avg_data_channels=False,
        write_data=True,
        write_features=False,
    )
    tboard.on_start(state)
    for batch in range(3):
        state[torchbearer.BATCH] = batch
        tboard.on_step_validation(state)

    add_embedding.assert_called_once_with(
        ANY, metadata=ANY, label_img=ANY, tag='data', global_step=-1)
    args, kwargs = add_embedding.call_args
    # Each 3x10x10 input yields a 300-wide embedding row (3 * 10 * 10).
    self.assertEqual(args[0].size(), torch.Size([45, 300]))
    self.assertEqual(kwargs['metadata'].size(), torch.Size([45]))
    self.assertEqual(kwargs['label_img'].size(), torch.Size([45, 3, 10, 10]))

    tboard.on_end({})
def test_simple_case(self, mock_board, _):
    """A single batch produces one ``'features'`` embedding from ``features_key``.

    With ``write_data=False`` and ``features_key=torchbearer.Y_TRUE``, one
    ``'features'`` embedding at ``global_step=0`` is expected. Its matrix is
    shaped like the targets with an added trailing dimension, its metadata has
    one entry per target, and its label images match the input batch's shape.
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_embedding = Mock()
    add_embedding = mock_board.return_value.add_embedding

    state = {
        torchbearer.X: torch.ones(18, 3, 10, 10),
        torchbearer.EPOCH: 0,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
        torchbearer.Y_TRUE: torch.ones(18),
        torchbearer.BATCH: 0,
    }

    tboard = TensorBoardProjector(
        num_images=18,
        avg_data_channels=False,
        write_data=False,
        features_key=torchbearer.Y_TRUE,
    )
    tboard.on_start(state)
    tboard.on_step_validation(state)

    add_embedding.assert_called_once_with(
        ANY, metadata=ANY, label_img=ANY, tag='features', global_step=0)
    args, kwargs = add_embedding.call_args
    targets = state[torchbearer.Y_TRUE]
    self.assertEqual(args[0].size(), targets.unsqueeze(1).size())
    self.assertEqual(kwargs['metadata'].size(), targets.size())
    self.assertEqual(kwargs['label_img'].size(), state[torchbearer.X].size())

    tboard.on_end(state)
def test_multi_epoch(self, mock_board):
    """The feature embedding is written again in a following epoch.

    One validation step writes a ``'features'`` embedding. After
    ``on_end_epoch`` (and a mock reset), a further validation step must
    write exactly one such embedding again.
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_embedding = Mock()
    add_embedding = mock_board.return_value.add_embedding

    state = {
        torchbearer.X: torch.ones(18, 3, 10, 10),
        torchbearer.EPOCH: 0,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
        torchbearer.Y_TRUE: torch.ones(18),
        torchbearer.BATCH: 0,
    }

    tboard = TensorBoardProjector(
        num_images=18,
        avg_data_channels=False,
        write_data=False,
        features_key=torchbearer.Y_TRUE,
    )
    tboard.on_start(state)

    def assert_features_written():
        # Exactly one 'features' embedding, shaped from the targets/inputs.
        add_embedding.assert_called_once_with(
            ANY, metadata=ANY, label_img=ANY, tag='features', global_step=0)
        args, kwargs = add_embedding.call_args
        targets = state[torchbearer.Y_TRUE]
        self.assertEqual(args[0].size(), targets.unsqueeze(1).size())
        self.assertEqual(kwargs['metadata'].size(), targets.size())
        self.assertEqual(kwargs['label_img'].size(), state[torchbearer.X].size())

    # First epoch.
    tboard.on_step_validation(state)
    assert_features_written()

    # Next epoch: reset the mock so assert_called_once_with only sees the new write.
    # NOTE(review): state[EPOCH] is never advanced, so global_step=0 is expected
    # again; consider bumping EPOCH to also pin the per-epoch step — confirm intent.
    tboard.on_end_epoch({})
    add_embedding.reset_mock()
    tboard.on_step_validation(state)
    assert_features_written()

    # Finish like the sibling tests (presumably releases the projector's writer).
    tboard.on_end(state)
def test_no_channels(self, mock_board):
    """Inputs without a channel dimension are handled when writing data.

    The inputs are 18 images of shape 10x10 with no channel axis. A single
    ``'data'`` embedding of 18x100 is expected. The label images are expected
    to gain a singleton channel dimension (18x1x10x10).
    """
    mock_board.return_value = Mock()
    mock_board.return_value.add_embedding = Mock()
    add_embedding = mock_board.return_value.add_embedding

    state = {
        torchbearer.X: torch.ones(18, 10, 10),
        torchbearer.EPOCH: 0,
        torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)),
        torchbearer.Y_TRUE: torch.ones(18),
        torchbearer.BATCH: 0,
    }

    tboard = TensorBoardProjector(
        num_images=18,
        avg_data_channels=False,
        write_data=True,
        write_features=False,
    )
    tboard.on_start(state)
    tboard.on_step_validation(state)

    add_embedding.assert_called_once_with(
        ANY, metadata=ANY, label_img=ANY, tag='data', global_step=-1)
    args, kwargs = add_embedding.call_args
    # Each 10x10 input yields a 100-wide embedding row.
    self.assertEqual(args[0].size(), torch.Size([18, 100]))
    self.assertEqual(kwargs['metadata'].size(), state[torchbearer.Y_TRUE].size())
    self.assertEqual(kwargs['label_img'].size(), torch.Size([18, 1, 10, 10]))

    # Finish like the sibling tests (presumably releases the projector's writer).
    tboard.on_end({})