def test_default_input_fn_csv_bad_columns():
    """A CSV payload whose rows have different widths must raise ValueError.

    The first row has three fields and the second has four; decoding that
    with ``default_input_fn`` as ``content_types.CSV`` is expected to fail.
    """
    ragged_rows = [[1, 2, 3], [1, 2, 3, 4]]
    buffer = StringIO()
    writer = csv.writer(buffer, delimiter=',')
    writer.writerows(ragged_rows)
    payload = buffer.getvalue()

    with pytest.raises(ValueError):
        default_input_fn(payload, content_types.CSV)
def test_default_input_fn_npy(tensor):
    """NPY bytes produced by ``np.save`` decode back to an equal tensor.

    The fixture tensor is moved to CPU, serialized with ``np.save`` into an
    in-memory buffer, and decoded as ``content_types.NPY``. The result must
    live on CUDA exactly when CUDA is available, and must equal the input.
    """
    npy_buffer = BytesIO()
    host_array = tensor.cpu().numpy()
    np.save(npy_buffer, host_array)
    payload = npy_buffer.getvalue()

    decoded = default_input_fn(payload, content_types.NPY)

    assert decoded.is_cuda == torch.cuda.is_available()
    assert torch.equal(tensor, decoded)
def test_default_input_fn_csv():
    """A well-formed CSV payload decodes to the matching float tensor.

    Two 3-column rows are written with ``csv.writer`` and decoded as
    ``content_types.CSV``. The expected value is the same data as a
    ``torch.FloatTensor`` moved to ``device``.
    """
    rows = [[1, 2, 3], [4, 5, 6]]
    buffer = StringIO()
    writer = csv.writer(buffer, delimiter=',')
    writer.writerows(rows)
    payload = buffer.getvalue()

    decoded = default_input_fn(payload, content_types.CSV)
    expected = torch.FloatTensor(rows).to(device)

    assert torch.equal(expected, decoded)
    assert decoded.is_cuda == torch.cuda.is_available()
def test_default_input_fn_bad_content_type():
    """An unrecognized content type raises ``encoders.UnsupportedFormatError``.

    An empty payload is passed with a made-up MIME type that
    ``default_input_fn`` is expected to reject.
    """
    unsupported_type = 'application/not_supported'
    empty_payload = ''

    with pytest.raises(encoders.UnsupportedFormatError):
        default_input_fn(empty_payload, unsupported_type)
def test_default_input_fn_json(tensor):
    """A JSON nested-list payload decodes back to an equal tensor.

    The fixture tensor is copied to CPU, converted to nested Python lists,
    and serialized with ``json.dumps``. Decoding as ``content_types.JSON``
    must yield a tensor on CUDA exactly when CUDA is available, and equal to
    the input.
    """
    nested = tensor.cpu().numpy().tolist()
    payload = json.dumps(nested)

    decoded = default_input_fn(payload, content_types.JSON)

    assert decoded.is_cuda == torch.cuda.is_available()
    assert torch.equal(tensor, decoded)