def test_rescaling_generate_unknown(proposal, x):
    """Test the rescaling method with an unknown method for generate.

    ``AugmentedFlowProposal._augmented_rescale`` is expected to raise a
    ``RuntimeError`` whose message contains 'Unknown method' when
    ``generate_augment`` is not a recognised option (here ``'ones'``).
    """
    # Replace the base rescaling with a stub returning (x, ones) so the test
    # does not depend on its behaviour.
    proposal._base_rescale = MagicMock(return_value=[x, np.ones(x.size)])
    # `match` performs a regex search on str(exception); the pattern has no
    # special characters so it is a plain substring check.
    with pytest.raises(RuntimeError, match='Unknown method'):
        AugmentedFlowProposal._augmented_rescale(
            proposal, x, generate_augment='ones'
        )
def test_rescaling(mock_zeros, mock_randn, proposal, x, generate):
    """Test the rescaling method.

    Verifies that the base rescaling is invoked exactly once with
    ``compute_radius=False`` plus the forwarded ``test=True`` keyword. It also
    verifies that the augment values are produced by the generator matching
    ``generate``: ``mock_zeros`` for ``'zeroes'`` and ``mock_randn`` for any
    other value.

    NOTE(review): ``mock_zeros`` / ``mock_randn`` are presumably supplied by
    ``@patch`` decorators not visible in this chunk — confirm.
    """
    # Stub out the base rescaling; it returns (x, ones) as (x_prime, log_J).
    proposal._base_rescale = MagicMock(return_value=[x, np.ones(x.size)])
    proposal.augment_names = ['e_1']
    proposal.augment_dims = 1

    AugmentedFlowProposal._augmented_rescale(
        proposal, x, generate_augment=generate, test=True
    )

    proposal._base_rescale.assert_called_once_with(
        x, compute_radius=False, test=True
    )
    # Exactly one augment name is configured, so the chosen generator should
    # be called once with the number of samples.
    expected_generator = mock_zeros if generate == 'zeroes' else mock_randn
    expected_generator.assert_called_once_with(x.size)
def test_rescaling_generate_none(
    mock_zeros, mock_randn, proposal, x, compute_radius
):
    """Test the rescaling method with generate_augment=None.

    When ``generate_augment`` is None, the test expects ``_augmented_rescale``
    to behave as follows:

    * When ``compute_radius`` is True, it falls back to zeros (``mock_zeros``).
    * Otherwise, it uses the proposal's own ``generate_augment`` setting,
      which is set to ``'gaussian'`` here so that ``mock_randn`` is expected.

    NOTE(review): ``mock_zeros`` / ``mock_randn`` / ``compute_radius`` are
    presumably supplied by decorators not visible in this chunk — confirm.
    """
    # Stub out the base rescaling; it returns (x, ones) as (x_prime, log_J).
    proposal._base_rescale = MagicMock(return_value=[x, np.ones(x.size)])
    proposal.augment_names = ['e_1']
    proposal.augment_dims = 1
    # The proposal default must differ from 'zeros'. Otherwise the
    # compute_radius=False branch could never call randn, which would
    # contradict the assertion below, and both fallback paths would look
    # identical.
    proposal.generate_augment = 'gaussian'

    AugmentedFlowProposal._augmented_rescale(
        proposal,
        x,
        generate_augment=None,
        test=True,
        compute_radius=compute_radius,
    )

    proposal._base_rescale.assert_called_once_with(
        x, compute_radius=compute_radius, test=True
    )
    if compute_radius:
        mock_zeros.assert_called_once_with(x.size)
    else:
        mock_randn.assert_called_once_with(x.size)