def test_delete_scheme(self):
    """
    Test for delete view: unknown ids give 404, GET gives 405, a scheme
    still linked to a document redirects to delete_related, and an unlinked
    scheme is removed with a redirect back to the scheme index.
    """
    delete_url = '/s/delete/{}/'.format

    # deleting a scheme that was never created must 404
    response = self.client.post(delete_url(random.randint(1, 100)))
    self.assertEqual(response.status_code, 404)

    scheme = ShamirSS(**self.scheme_data)
    scheme.save()
    target = delete_url(scheme.id)

    # the endpoint only accepts POST
    self.assertEqual(self.client.get(target).status_code, 405)

    # while a document references the scheme, deletion is deferred to
    # the delete_related page
    self.document.scheme = scheme
    self.document.save()
    related_url = '/s/delete_related/{}/'.format(scheme.id)
    response = self.client.post(target, follow=True)
    self.assertRedirects(response, expected_url=related_url,
                         status_code=302, target_status_code=200)

    # once the document is detached, deletion succeeds and lands on /s/
    self.document.scheme = None
    self.document.save()
    response = self.client.post(target, follow=True)
    self.assertRedirects(response, expected_url='/s/',
                         status_code=302, target_status_code=200)
    with self.assertRaises(ObjectDoesNotExist):
        ShamirSS.objects.get(pk=scheme.id)
def test_refresh_scheme(self):
    """
    Test for refresh scheme: unknown ids give 404, GET gives 405, a scheme
    still linked to a document redirects to delete_related, and an unlinked
    scheme is re-issued n shares that check_shares reports as different
    from the previous ones.
    """
    refresh_url = '/s/refresh/{}/'.format

    # refreshing a scheme that does not exist must 404
    response = self.client.post(refresh_url(random.randint(1, 100)))
    self.assertEqual(response.status_code, 404)

    scheme = ShamirSS(**self.scheme_data)
    previous_shares = scheme.get_shares()
    scheme.save()
    target = refresh_url(scheme.id)

    # the endpoint rejects GET
    self.assertEqual(self.client.get(target).status_code, 405)

    # while a document is linked, the user is sent to delete_related
    self.document.scheme = scheme
    self.document.save()
    related_url = '/s/delete_related/{}/'.format(scheme.id)
    response = self.client.post(target, follow=True)
    self.assertRedirects(response, expected_url=related_url,
                         status_code=302, target_status_code=200)

    # once unlinked, the refresh succeeds and renders the new shares
    self.document.scheme = None
    self.document.save()
    response = self.client.post(target, follow=True)
    self.assertEqual(response.status_code, 200)
    fresh_shares = response.context['shares']
    # exactly n shares are issued ...
    self.assertEqual(scheme.n, len(fresh_shares))
    # ... and they differ from the ones generated before the refresh
    self.assertTrue(self.check_shares(previous_shares, fresh_shares))
def test_encrypt(self):
    """
    Test encrypt view: unknown document/scheme ids give 404 for GET and
    POST, an existing pair renders the form, posting k shares encrypts the
    document and redirects to its folder, and a second attempt yields the
    'Document already encrypted' form error.
    """
    encrypt_url = '/s/encrypt/{}/{}/'.format

    # neither GET nor POST may succeed for ids that do not exist
    for send in (self.client.get, self.client.post):
        response = send(
            encrypt_url(random.randint(2, 100), random.randint(1, 100)))
        self.assertEqual(response.status_code, 404)

    scheme = ShamirSS(**self.scheme_data)
    shares = scheme.get_shares()
    scheme.save()
    target = encrypt_url(self.document.id, scheme.id)

    # the form renders for an existing document and scheme
    self.assertEqual(self.client.get(target).status_code, 200)

    # submitting any k shares encrypts and redirects to the folder page
    chosen = self._pick_k_random_values(shares, scheme.k)
    post_data = {'share_' + str(share[0]): share[1] for share in chosen}
    post_data['scheme'] = scheme.id
    folder_url = '/folder/{}/'.format(self.document.folder.id)
    response = self.client.post(target, post_data, follow=True)
    self.assertRedirects(response, expected_url=folder_url,
                         status_code=302, target_status_code=200)

    # the stored document now points at the encrypted file
    self.document.refresh_from_db()
    self.assertTrue(os.path.isfile(self.document.file_path()))
    self.assertIsNotNone(self.document.scheme)
    self.assertEqual(self.document.filename(), self.TEST_FILE_NAME + '.enc')

    # encrypting twice is refused with a form error
    response = self.client.post(target, post_data, follow=True)
    self.assertFormError(response, 'form', None, 'Document already encrypted')
def test_delete_related(self):
    """
    Test for delete_related view: 404 for an unknown scheme, 404 for a
    scheme no document uses, and a 200 page listing the one linked document.
    """
    related_url = '/s/delete_related/{}/'.format

    # unknown scheme id
    response = self.client.get(related_url(random.randint(1, 100)))
    self.assertEqual(response.status_code, 404)

    scheme = ShamirSS(**self.scheme_data)
    scheme.save()
    target = related_url(scheme.id)

    # the scheme exists but nothing references it yet
    self.assertEqual(self.client.get(target).status_code, 404)

    # after linking the fixture document, the page lists exactly that one
    self.document.scheme = scheme
    self.document.save()
    response = self.client.get(target)
    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.context['documents']), 1)
def test_index(self):
    """
    Test for index view: the list is empty before any scheme exists, then
    holds one entry whose name, mers_exp, k and n match the saved scheme.
    """
    index_url = '/s/'

    response = self.client.get(index_url)
    self.assertEqual(response.status_code, 200)
    # no scheme has been created yet
    self.assertEqual(len(response.context['schemes']), 0)

    scheme = ShamirSS(**self.scheme_data)
    scheme.save()

    response = self.client.get(index_url)
    self.assertEqual(response.status_code, 200)
    listed = response.context['schemes']
    self.assertEqual(len(listed), 1)
    # the listed scheme mirrors the stored one field by field
    for field in ('name', 'mers_exp', 'k', 'n'):
        self.assertEqual(getattr(listed[0], field), getattr(scheme, field))
def test_decrypt(self):
    """
    Test decrypt view: 404 for an unknown document and for the plaintext
    fixture document (GET and POST), a 200 form once the document is
    encrypted, and a redirect to its folder after posting k shares, after
    which the file exists, the scheme link is cleared and the original
    filename is restored.
    """
    decrypt_url = '/s/decrypt/{}/'.format

    # unknown document id, for both GET and POST
    for send in (self.client.get, self.client.post):
        response = send(decrypt_url(random.randint(2, 100)))
        self.assertEqual(response.status_code, 404)

    # the fixture document is still plaintext, so it cannot be decrypted
    plain_url = decrypt_url(self.document.id)
    for send in (self.client.get, self.client.post):
        self.assertEqual(send(plain_url).status_code, 404)

    # encrypt the fixture file through the scheme and attach the scheme
    scheme = ShamirSS(**self.scheme_data)
    shares = scheme.get_shares()
    scheme.save()
    encrypted_path = scheme.encrypt_file(self.document.file_path(), shares)
    os.remove(self.document.file_path())
    self.document.file.name = encrypted_path
    self.document.scheme = scheme
    self.document.save()

    target = decrypt_url(self.document.id)
    # the decryption form renders for an encrypted document
    self.assertEqual(self.client.get(target).status_code, 200)

    # any k shares decrypt the file and redirect to the folder page
    chosen = self._pick_k_random_values(shares, scheme.k)
    post_data = {'share_' + str(share[0]): share[1] for share in chosen}
    post_data['scheme'] = scheme.id
    folder_url = '/folder/{}/'.format(self.document.folder.id)
    response = self.client.post(target, post_data, follow=True)
    self.assertRedirects(response, expected_url=folder_url,
                         status_code=302, target_status_code=200)

    self.document.refresh_from_db()
    self.assertTrue(os.path.isfile(self.document.file_path()))
    self.assertIsNone(self.document.scheme)
    self.assertEqual(self.document.filename(), self.TEST_FILE_NAME)
def setUp(self):
    """Build the form data and an unsaved ShamirSS scheme created from it."""
    self.form_data = dict(name='test', mers_exp=107, k=4, n=18)
    self.scheme = ShamirSS(**self.form_data)
class ShamirSSTestCase(TestCase):
    """ Test for shared secret Model """

    def setUp(self):
        # rebuilt before every test, so tests may not rely on mutations
        # made to form_data by other tests
        self.form_data = {'name': 'test', 'mers_exp': 107, 'k': 4, 'n': 18}
        self.scheme = ShamirSS(**self.form_data)

    def test_difference(self):
        """ Test difference method works """
        # 14 == n - k for the fixture data (18 - 4); presumably that is
        # what difference() computes -- confirm against the model
        self.assertEqual(self.scheme.difference(), 14)

    def test_scheme_fileds(self):
        """ Test for correct fields validation with SSForm """
        self.assertTrue(SSForm(data=self.form_data).is_valid())

        # each case changes exactly one field of the valid fixture data
        invalid_overrides = {
            'illegal mers_exp': {'mers_exp': 3},
            'k above MAX_N': {'k': self.scheme.MAX_N + 1},
            'n above MAX_N': {'n': self.scheme.MAX_N + 1},
            'n lower than k': {'n': 3},
        }
        for label, override in invalid_overrides.items():
            with self.subTest(label):
                form = SSForm(data={**self.form_data, **override})
                self.assertFalse(form.is_valid())

    def test_scheme_correctness(self):
        """ Test for successful shares generation and secret recovery """
        # check correct base64 encoding-decoding
        random_int = random.randint(10000000000, 100000000000)
        enc_dec = self.scheme.decode_shares(
            self.scheme.encode_shares([(0, random_int)]))
        self.assertEqual(enc_dec[0][1], random_int)

        # check all shares generated correctly
        shares = self.scheme.get_shares()
        encoded_secret = self.scheme.secret
        self.assertEqual(len(shares), self.scheme.n)

        # check hashed secret is correct picking k random shares
        rnd_shares = self._pick_k_random_values(shares, self.scheme.k)
        rec_secret = self.scheme.get_secret(
            self.scheme.decode_shares(rnd_shares))
        self.assertTrue(
            hashers.check_password(str(rec_secret), encoded_secret))

        # check hashed secret is correct using all n shares
        secret_all = self.scheme.get_secret(self.scheme.decode_shares(shares))
        self.assertTrue(
            hashers.check_password(str(secret_all), encoded_secret))

        # check value error if fewer than k shares are provided
        rnd_shares_2 = self._pick_k_random_values(shares, self.scheme.k - 1)
        with self.assertRaises(ValueError):
            self.scheme.get_secret(self.scheme.decode_shares(rnd_shares_2))

        # check for wrong shares: give the first share the value of the
        # second one so the set no longer lies on the original polynomial
        rnd_shares_3 = self._pick_k_random_values(shares, self.scheme.k)
        rnd_shares_3[0] = (rnd_shares_3[0][0], rnd_shares_3[1][1])
        wrong_secret = self.scheme.get_secret(
            self.scheme.decode_shares(rnd_shares_3))
        self.assertFalse(
            hashers.check_password(str(wrong_secret), encoded_secret))

    def test_file_encryption_decryption(self):
        """ Test successful file encryption and decryption """
        # create two test files with same content
        file_name_1 = settings.MEDIA_ROOT + 'test_file_1.txt'
        file_name_2 = settings.MEDIA_ROOT + 'test_file_2.txt'
        content = 'some string just to fill this file up\n\n'
        for file_name in (file_name_1, file_name_2):
            # registered first so the file is removed even if a later
            # assertion fails
            self.addCleanup(self._remove_if_exists, file_name)
            with open(file_name, 'w+') as test_file:
                test_file.write(content)
        original_hash = self.hash_file(file_name_1)

        # create shares for the scheme
        shares = self.scheme.get_shares()

        # identical plaintexts must round-trip to identical files ...
        hash_1 = self._encrypt_decrypt_hash(file_name_1, shares)
        hash_2 = self._encrypt_decrypt_hash(file_name_2, shares)
        self.assertEqual(hash_1, hash_2)
        # ... whose content is exactly the original plaintext
        self.assertEqual(hash_1, original_hash)

        # test encryption with files having different content
        with open(file_name_2, 'a') as test_file:
            test_file.write('this make file 2 different\n')
        hash_1 = self._encrypt_decrypt_hash(file_name_1, shares)
        hash_2 = self._encrypt_decrypt_hash(file_name_2, shares)
        self.assertNotEqual(hash_1, hash_2)

    def _encrypt_decrypt_hash(self, file_name, shares):
        """
        Encrypt and then decrypt file_name with shares, returning the sha1
        hex digest of the decrypted file.

        The encrypted file is deleted before returning; the decrypted file
        is scheduled for removal when the test finishes.
        """
        # encrypt_file/decrypt_file return paths relative to MEDIA_ROOT
        enc_path = settings.MEDIA_ROOT + self.scheme.encrypt_file(
            file_name, shares)
        self.addCleanup(self._remove_if_exists, enc_path)
        dec_path = settings.MEDIA_ROOT + self.scheme.decrypt_file(
            enc_path, shares)
        self.addCleanup(self._remove_if_exists, dec_path)
        digest = self.hash_file(dec_path)
        # drop the ciphertext now so the next round starts from the same
        # on-disk state as the first one
        self._remove_if_exists(enc_path)
        return digest

    @staticmethod
    def _remove_if_exists(path):
        """ delete path, ignoring it if it has already been removed """
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def hash_file(self, file):
        """ return sha1 hash of a file """
        blocksize = 65536
        hasher = hashlib.sha1()
        with open(file, 'rb') as afile:
            # read fixed-size chunks until read() returns b'' at EOF
            for buf in iter(lambda: afile.read(blocksize), b''):
                hasher.update(buf)
        return hasher.hexdigest()

    def _pick_k_random_values(self, l, k):
        """
        select k distinct random values from l

        Raises ValueError if k is negative or larger than len(l), instead
        of sampling forever.
        """
        return random.sample(l, k)