def test_pagination(self, setup_es):
    """Tests the pagination."""
    total_records = 9
    page_size = 2
    ids = sorted((uuid4() for _ in range(total_records)))
    name = 'test record'

    CompanyFactory.create_batch(
        len(ids),
        id=factory.Iterator(ids),
        name=name,
        trading_names=[],
    )
    setup_es.indices.refresh()

    url = reverse('api-v3:search:basic')
    for page in range((len(ids) + page_size - 1) // page_size):
        response = self.api_client.get(
            url,
            data={
                'term': name,
                'entity': 'company',
                'offset': page * page_size,
                'limit': page_size,
            },
        )

        assert response.status_code == status.HTTP_200_OK

        start = page * page_size
        end = start + page_size
        assert ids[start:end] == [UUID(company['id']) for company in response.data['results']]
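
# A note on the loop bound above (an illustrative aside, not part of the test
# suite): `(len(ids) + page_size - 1) // page_size` is the integer
# ceiling-division idiom, equivalent to math.ceil(len(ids) / page_size).
# A minimal, self-contained check of that equivalence:
import math


def num_pages(total_records, page_size):
    """Return the number of pages needed to show total_records records."""
    return (total_records + page_size - 1) // page_size


assert num_pages(9, 2) == math.ceil(9 / 2) == 5  # 4 full pages + 1 partial page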
def test_pagination(self, public_company_api_client, es_with_collector):
    """Test result pagination."""
    total_records = 9
    page_size = 2
    ids = sorted((uuid4() for _ in range(total_records)))
    name = 'test record'

    CompanyFactory.create_batch(
        len(ids),
        id=factory.Iterator(ids),
        name=name,
        trading_names=[],
    )
    es_with_collector.flush_and_refresh()

    url = reverse('api-v4:search:public-company')
    num_pages = (total_records + page_size - 1) // page_size
    for page in range(num_pages):
        request_data = {
            'original_query': name,
            'offset': page * page_size,
            'limit': page_size,
        }
        response = public_company_api_client.post(url, request_data)

        assert response.status_code == status.HTTP_200_OK

        start = page * page_size
        end = start + page_size
        assert [
            UUID(company['id']) for company in response.data['results']
        ] == ids[start:end]
def test_one_list_account_manager_with_global_headquarters_filter(self, es_with_collector):
    """
    Test that the one list account manager filter matches companies that inherit
    their one list account manager from their global headquarters.
    """
    account_manager = AdviserFactory()
    CompanyFactory.create_batch(2)
    global_headquarters = CompanyFactory(
        one_list_account_owner=account_manager,
    )
    target_companies = CompanyFactory.create_batch(
        2,
        global_headquarters=global_headquarters,
    )
    es_with_collector.flush_and_refresh()

    query = {'one_list_group_global_account_manager': account_manager.pk}

    url = reverse('api-v4:search:company')
    response = self.api_client.post(url, query)

    assert response.status_code == status.HTTP_200_OK

    search_results = {
        company['id'] for company in response.data['results']
    }
    expected_results = {
        str(global_headquarters.id),
        *{str(target_company.id) for target_company in target_companies},
    }
    assert response.data['count'] == 3
    assert len(response.data['results']) == 3
    assert search_results == expected_results
def test_sector_descends_filter(self, hierarchical_sectors, setup_es, sector_level):
    """Test the sector_descends filter."""
    num_sectors = len(hierarchical_sectors)
    sectors_ids = [sector.pk for sector in hierarchical_sectors]

    companies = CompanyFactory.create_batch(
        num_sectors,
        sector_id=factory.Iterator(sectors_ids),
    )
    CompanyFactory.create_batch(
        3,
        sector=factory.LazyFunction(lambda: random_obj_for_queryset(
            Sector.objects.exclude(pk__in=sectors_ids),
        )),
    )
    setup_es.indices.refresh()

    url = reverse('api-v3:search:company')
    body = {
        'sector_descends': hierarchical_sectors[sector_level].pk,
    }
    response = self.api_client.post(url, body)
    assert response.status_code == status.HTTP_200_OK

    response_data = response.json()
    assert response_data['count'] == num_sectors - sector_level

    actual_ids = {UUID(company['id']) for company in response_data['results']}
    expected_ids = {company.pk for company in companies[sector_level:]}
    assert actual_ids == expected_ids
def test_sync_app_uses_latest_data(monkeypatch, setup_es):
    """Test that sync_app() picks up updates made to records between batches."""
    CompanyFactory.create_batch(2, name='old name')

    def sync_objects_side_effect(*args, **kwargs):
        nonlocal mock_sync_objects

        ret = sync_objects(*args, **kwargs)

        if mock_sync_objects.call_count == 1:
            Company.objects.update(name='new name')

        return ret

    mock_sync_objects = Mock(side_effect=sync_objects_side_effect)
    monkeypatch.setattr('datahub.search.bulk_sync.sync_objects', mock_sync_objects)
    sync_app(CompanySearchApp, batch_size=1)

    setup_es.indices.refresh()

    company = mock_sync_objects.call_args_list[1][0][1][0]
    fetched_company = setup_es.get(
        index=CompanySearchApp.es_model.get_read_alias(),
        doc_type=CompanySearchApp.name,
        id=company.pk,
    )
    assert fetched_company['_source']['name'] == 'new name'
def test_archived_filter(self, setup_es, archived):
    """Tests filtering by archived."""
    matching_companies = CompanyFactory.create_batch(5, archived=archived)
    CompanyFactory.create_batch(2, archived=not archived)

    setup_es.indices.refresh()

    url = reverse('api-v3:search:company')
    response = self.api_client.post(
        url,
        data={
            'archived': archived,
        },
    )

    assert response.status_code == status.HTTP_200_OK

    response_data = response.json()
    assert response_data['count'] == 5

    expected_ids = Counter(str(company.pk) for company in matching_companies)
    actual_ids = Counter(result['id'] for result in response_data['results'])
    assert expected_ids == actual_ids
def test_company_search_paging(self, setup_es, sortby):
    """Tests if content placement is consistent between pages."""
    ids = sorted((uuid4() for _ in range(9)))
    name = 'test record'

    CompanyFactory.create_batch(
        len(ids),
        id=factory.Iterator(ids),
        name=name,
        alias='',
    )
    setup_es.indices.refresh()

    page_size = 2
    for page in range((len(ids) + page_size - 1) // page_size):
        url = reverse('api-v3:search:company')
        response = self.api_client.post(
            url,
            data={
                'original_query': name,
                'entity': 'company',
                'offset': page * page_size,
                'limit': page_size,
                **sortby,
            },
        )

        assert response.status_code == status.HTTP_200_OK

        start = page * page_size
        end = start + page_size
        assert ids[start:end] == [UUID(company['id']) for company in response.data['results']]
def test_run_without_any_matches(caplog, simulate):
    """
    Test that if no matches are found in the database, the command reports
    zero results.
    """
    CompanyFactory.create_batch(5)

    caplog.set_level('INFO')
    call_command('cleanse_companies_using_worldbase_match', simulate=simulate)

    assert caplog.messages == [
        'Started',
        'Finished - succeeded: 0, failed: 0, archived: 0',
    ]
def test_sector_descends_filter_for_company_interaction(
    self,
    hierarchical_sectors,
    es_with_collector,
    sector_level,
):
    """Test the sector_descends filter with company interactions."""
    num_sectors = len(hierarchical_sectors)
    sectors_ids = [sector.pk for sector in hierarchical_sectors]

    companies = CompanyFactory.create_batch(
        num_sectors,
        sector_id=factory.Iterator(sectors_ids),
    )
    company_interactions = CompanyInteractionFactory.create_batch(
        3,
        company=factory.Iterator(companies),
    )

    other_companies = CompanyFactory.create_batch(
        3,
        sector=factory.LazyFunction(lambda: random_obj_for_queryset(
            Sector.objects.exclude(pk__in=sectors_ids),
        )),
    )
    CompanyInteractionFactory.create_batch(
        3,
        company=factory.Iterator(other_companies),
    )

    es_with_collector.flush_and_refresh()

    url = reverse('api-v3:search:interaction')
    body = {
        'sector_descends': hierarchical_sectors[sector_level].pk,
    }
    response = self.api_client.post(url, body)
    assert response.status_code == status.HTTP_200_OK

    response_data = response.json()
    assert response_data['count'] == num_sectors - sector_level

    actual_ids = {
        UUID(interaction['id']) for interaction in response_data['results']
    }
    expected_ids = {
        interaction.pk for interaction in company_interactions[sector_level:]
    }
    assert actual_ids == expected_ids
def test_export(
    self,
    setup_es,
    request_sortby,
    orm_ordering,
):
    """Test export of company search results."""
    CompanyFactory.create_batch(3)
    CompanyFactory.create_batch(2, hq=True)
    setup_es.indices.refresh()

    data = {}
    if request_sortby:
        data['sortby'] = request_sortby

    url = reverse('api-v3:search:company-export')

    with freeze_time('2018-01-01 11:12:13'):
        response = self.api_client.post(url, data=data)

    assert response.status_code == status.HTTP_200_OK
    assert parse_header(response.get('Content-Type')) == ('text/csv', {'charset': 'utf-8'})
    assert parse_header(response.get('Content-Disposition')) == (
        'attachment',
        {'filename': 'Data Hub - Companies - 2018-01-01-11-12-13.csv'},
    )

    sorted_company = Company.objects.order_by(orm_ordering, 'pk')
    reader = DictReader(StringIO(response.getvalue().decode('utf-8-sig')))

    assert reader.fieldnames == list(SearchCompanyExportAPIView.field_titles.values())

    expected_row_data = [
        {
            'Name': company.name,
            'Link': f'{settings.DATAHUB_FRONTEND_URL_PREFIXES["company"]}/{company.pk}',
            'Sector': get_attr_or_none(company, 'sector.name'),
            'Country': get_attr_or_none(company, 'registered_address_country.name'),
            'UK region': get_attr_or_none(company, 'uk_region.name'),
            'Archived': company.archived,
            'Date created': company.created_on,
            'Number of employees': get_attr_or_none(company, 'employee_range.name'),
            'Annual turnover': get_attr_or_none(company, 'turnover_range.name'),
            'Headquarter type':
                (get_attr_or_none(company, 'headquarter_type.name') or '').upper(),
        }
        for company in sorted_company
    ]

    assert list(dict(row) for row in reader) == format_csv_data(expected_row_data)
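
# Why the test above decodes with 'utf-8-sig' (a hedged aside, assuming the
# export view prepends a UTF-8 byte order mark for spreadsheet compatibility):
# the 'utf-8-sig' codec strips a leading BOM if one is present and is otherwise
# identical to 'utf-8'. A minimal standard-library illustration:
import codecs

payload = codecs.BOM_UTF8 + 'Name,Link\r\n'.encode('utf-8')
assert payload.decode('utf-8-sig') == 'Name,Link\r\n'  # leading BOM is stripped
assert 'Name'.encode('utf-8').decode('utf-8-sig') == 'Name'  # no BOM: unchanged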
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    original_datetime = datetime(2017, 1, 1, tzinfo=timezone.utc)

    with freeze_time(original_datetime):
        export_potential_scores = [
            Company.EXPORT_POTENTIAL_SCORES.very_high,
            Company.EXPORT_POTENTIAL_SCORES.medium,
            Company.EXPORT_POTENTIAL_SCORES.low,
            Company.EXPORT_POTENTIAL_SCORES.very_high,
            Company.EXPORT_POTENTIAL_SCORES.high,
        ]
        companies = CompanyFactory.create_batch(
            5,
            export_potential=factory.Iterator(export_potential_scores),
        )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""datahub_company_id,export_propensity
00000000-0000-0000-0000-000000000000,Low
{companies[0].pk},High
{companies[1].pk},Very high
{companies[2].pk},dummy
{companies[3].pk},High
{companies[4].pk},Very high
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    with freeze_time('2018-11-11 00:00:00'):
        call_command('update_company_export_potential', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert "KeyError: 'dummy'" in caplog.text
    assert len(caplog.records) == 2

    assert [company.export_potential for company in companies] == [
        Company.EXPORT_POTENTIAL_SCORES.high,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.low,
        Company.EXPORT_POTENTIAL_SCORES.high,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
    ]
    assert all(company.modified_on == original_datetime for company in companies)
def test_limit(self, automatic_company_archive_feature_flag):
    """
    Test that we can set a limit to the number of companies that are
    automatically archived.
    """
    gt_3m_ago = timezone.now() - relativedelta(months=3, days=1)
    with freeze_time(gt_3m_ago):
        companies = CompanyFactory.create_batch(3)

    task_result = automatic_company_archive.apply_async(
        kwargs={
            'simulate': False,
            'limit': 2,
        },
    )

    assert task_result.successful()

    archived_companies_count = 0
    for company in companies:
        company.refresh_from_db()
        if company.archived:
            archived_companies_count += 1

    assert archived_companies_count == 2
def test_audit_log(s3_stubber):
    """Test that the audit log is being created."""
    investment_project = InvestmentProjectFactory()
    file_companies = CompanyFactory.create_batch(3)

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,investor_company_id,intermediate_company_id,uk_company_id,uk_company_decided
{investment_project.id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(bytes(csv_content, encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_investment_project_company', bucket, object_key)

    investment_project.refresh_from_db()
    assert investment_project.investor_company == file_companies[0]
    assert investment_project.intermediate_company == file_companies[1]
    assert investment_project.uk_company == file_companies[2]
    assert investment_project.uk_company_decided is True

    versions = Version.objects.get_for_object(investment_project)
    assert len(versions) == 1
    assert versions[0].revision.get_comment() == 'Companies data migration.'
def test_with_multiple_user_lists(self):
    """Test that a user who owns multiple lists can list the contents of all of them."""
    lists = CompanyListFactory.create_batch(5, adviser=self.user)

    company_list_companies = {}

    # add multiple companies to each of the user's lists
    for company_list in lists:
        companies = CompanyFactory.create_batch(5)
        company_list_companies[company_list.pk] = companies

        CompanyListItemFactory.create_batch(
            len(companies),
            list=company_list,
            company=factory.Iterator(companies),
        )

    # check that the contents of each of the user's lists can be listed
    for company_list in lists:
        url = _get_list_item_collection_url(company_list.pk)
        response = self.api_client.get(url)

        assert response.status_code == status.HTTP_200_OK

        response_data = response.json()
        result_company_ids = {
            result['company']['id'] for result in response_data['results']
        }
        assert result_company_ids == {
            str(company.id) for company in company_list_companies[company_list.pk]
        }
def test_company_dbmodels_to_es_documents(self, setup_es):
    """Tests conversion of db models to Elasticsearch documents."""
    companies = CompanyFactory.create_batch(2)

    result = ESCompany.db_objects_to_es_documents(companies)

    assert len(list(result)) == len(companies)
def test_simulate(s3_stubber):
    """Test that the command only simulates the actions if --simulate is passed in."""
    companies = CompanyFactory.create_batch(3)
    investment_projects = InvestmentProjectFactory.create_batch(
        2,
        investor_company=companies[0],
        intermediate_company=companies[1],
        uk_company=companies[2],
        uk_company_decided=True,
    )
    file_companies = CompanyFactory.create_batch(3)

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,investor_company_id,intermediate_company_id,uk_company_id,uk_company_decided
{investment_projects[0].id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
{investment_projects[1].id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(bytes(csv_content, encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_investment_project_company', bucket, object_key, simulate=True)

    for investment_project in investment_projects:
        investment_project.refresh_from_db()

    # the records should be unchanged, as the command ran in simulate mode
    assert investment_projects[0].investor_company == companies[0]
    assert investment_projects[0].intermediate_company == companies[1]
    assert investment_projects[0].uk_company == companies[2]
    assert investment_projects[0].uk_company_decided is True
    assert investment_projects[1].investor_company == companies[0]
    assert investment_projects[1].intermediate_company == companies[1]
    assert investment_projects[1].uk_company == companies[2]
    assert investment_projects[1].uk_company_decided is True
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    companies = CompanyFactory.create_batch(4)

    data = []
    for company in companies:
        record = _get_data_for_company(company.id)
        raw_record = json.dumps(record)
        base64_record = b64encode(raw_record.encode('utf-8')).decode('utf-8')
        data.append(base64_record)

    wrong_json = b64encode('{"what": }'.encode('utf-8')).decode('utf-8')

    # to check that existing match will be overwritten
    company_2_match = DnBMatchingResult.objects.create(
        company_id=companies[2].id,
        data={'hello': 'world', 'confidence': 100},
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,data
00000000-0000-0000-0000-000000000000,NULL
{companies[0].id},{data[0]}
{companies[1].id},{wrong_json}
{companies[2].id},{data[2]}
{companies[3].id},invalidbase64
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_matches', bucket, object_key)

    assert 'Company matching query does not exist' in caplog.text
    assert 'json.decoder.JSONDecodeError' in caplog.text
    assert 'binascii.Error: Invalid base64-encoded string' in caplog.text
    assert len(caplog.records) == 3

    matches = DnBMatchingResult.objects.filter(company__in=companies)
    assert matches.count() == 2
    for match in matches:
        assert match.data == _get_data_for_company(match.company_id)

    company_2_match.refresh_from_db()
    assert company_2_match.data == _get_data_for_company(company_2_match.company_id)
def test_company_dbmodels_to_documents(self, opensearch):
    """Tests conversion of db models to OpenSearch documents."""
    companies = CompanyFactory.create_batch(2)
    app = get_search_app('company')
    companies_qs = app.queryset.all()

    result = SearchCompany.db_objects_to_documents(companies_qs)

    assert len(list(result)) == len(companies)
def test_no_change(s3_stubber, caplog):
    """
    Test that the command ignores records that haven't changed
    or records with incorrect current values.
    """
    caplog.set_level('WARNING')

    old_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_old', 'sector_2_old', 'sector_3_old']),
    )
    companies = CompanyFactory.create_batch(
        3,
        sector_id=factory.Iterator([sector.pk for sector in old_sectors]),
    )
    new_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_new', 'sector_2_new', 'sector_3_new']),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_sector_id,new_sector_id
{companies[0].pk},{old_sectors[0].pk},{new_sectors[0].pk}
{companies[1].pk},{old_sectors[1].pk},{old_sectors[1].pk}
{companies[2].pk},00000000-0000-0000-0000-000000000000,{new_sectors[2].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_sector_disabled_signals', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert f'Not updating company {companies[1]} as its sector has not changed' in caplog.text
    assert f'Not updating company {companies[2]} as its sector has not changed' in caplog.text
    assert len(caplog.records) == 2

    assert [company.sector for company in companies] == [
        new_sectors[0],
        old_sectors[1],
        old_sectors[2],
    ]
def company_names_and_postcodes(opensearch_with_collector):
    """Get companies with postcodes."""
    (names, postcodes) = zip(*(
        ('company_w1', 'w1 2AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_w1a', 'W1A2AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_w11', 'W112AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_ab1_1', 'AB11WC'),  # WC in suffix to ensure not matched in WC tests
        ('company_ab10', 'ab10 1WC'),  # WC in suffix to ensure not matched in WC tests
        # to test the difference between searching for AB1 0 (sector) and AB10 (district)
        ('company_ab1_0', 'AB1 0WC'),
        ('company_wc2b', 'WC2B4AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_wc2n', 'WC2N9ZZ'),
        ('company_wc1x', 'w C 1 x0aA'),
        ('company_wc1a', 'W C 1 A 1 G A'),
        ('company_se1', 'SE13A J'),
        ('company_se1_3', 'SE13AJ'),
        ('company_se2', 'SE23AJ'),
        ('company_se3', 'SE33AJ'),
    ))
    CompanyFactory.create_batch(
        len(names),
        name=factory.Iterator(names),
        address_country_id=constants.Country.united_kingdom.value.id,
        address_postcode=factory.Iterator(postcodes),
        registered_address_country_id=constants.Country.united_kingdom.value.id,
        registered_address_postcode=factory.Iterator(postcodes),
    )
    CompanyFactory(
        name='non_uk_company_se1',
        address_country_id=constants.Country.united_states.value.id,
        address_postcode='SE13AJ',
    )
    opensearch_with_collector.flush_and_refresh()
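
# Background for the fixture above (a hedged sketch, not project code): UK
# postcodes split into an outward code (area + district, e.g. 'AB1') and a
# three-character inward code (sector digit + two-letter unit, e.g. '0WC'),
# which is why 'AB1 0' targets a postcode sector while 'AB10' targets a
# different district. The helper below is illustrative only.
def split_postcode(postcode):
    """Split a UK postcode into (outward, inward); the inward code is always 3 chars."""
    normalised = postcode.replace(' ', '').upper()
    return normalised[:-3], normalised[-3:]


assert split_postcode('AB1 0WC') == ('AB1', '0WC')  # sector 'AB1 0'
assert split_postcode('ab10 1WC') == ('AB10', '1WC')  # district 'AB10'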
def test_one_list_download(self):
    """Test the download of the One List."""
    CompanyFactory.create_batch(
        2,
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because headquarter_type is None
    CompanyFactory(
        headquarter_type=None,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because classification is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=None,
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because one_list_account_owner is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=None,
    )

    url = reverse('admin-report:download-report', kwargs={'report_id': 'one-list'})

    user = create_test_user(
        permission_codenames=('view_company',),
        is_staff=True,
        password=self.PASSWORD,
    )
    client = self.create_client(user=user)
    response = client.get(url)

    assert response.status_code == status.HTTP_200_OK
    # 3 = header + the first 2 companies
    assert len(response.getvalue().decode('utf-8').splitlines()) == 3
def test_non_existent_company(s3_stubber, caplog):
    """Test that the command logs an error when the company PK does not exist."""
    caplog.set_level('ERROR')

    old_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_old', 'sector_2_old', 'sector_3_old']),
    )
    companies = CompanyFactory.create_batch(
        3,
        sector_id=factory.Iterator([sector.pk for sector in old_sectors]),
    )
    new_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_new', 'sector_2_new', 'sector_3_new']),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_sector_id,new_sector_id
{companies[0].pk},{old_sectors[0].pk},{new_sectors[0].pk}
{companies[1].pk},{old_sectors[1].pk},{new_sectors[1].pk}
00000000-0000-0000-0000-000000000000,{old_sectors[2].pk},{new_sectors[2].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_sector_disabled_signals', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    assert [company.sector for company in companies] == [
        new_sectors[0],
        new_sectors[1],
        old_sectors[2],
    ]
def test_company_search_paging(self, opensearch_with_collector):
    """
    Tests the pagination.

    The sortby parameter is not passed in, so records are ordered by id.
    """
    total_records = 9
    page_size = 2
    ids = sorted((uuid4() for _ in range(total_records)))
    name = 'test record'

    CompanyFactory.create_batch(
        len(ids),
        id=factory.Iterator(ids),
        name=name,
        trading_names=[],
    )
    opensearch_with_collector.flush_and_refresh()

    url = reverse('api-v4:search:company')
    for page in range((len(ids) + page_size - 1) // page_size):
        response = self.api_client.post(
            url,
            data={
                'original_query': name,
                'offset': page * page_size,
                'limit': page_size,
            },
        )

        assert response.status_code == status.HTTP_200_OK

        start = page * page_size
        end = start + page_size
        assert [
            UUID(company['id']) for company in response.data['results']
        ] == ids[start:end]
def test_one_list_account_manager_filter(
    self,
    num_account_managers,
    opensearch_with_collector,
):
    """Test one list account manager filter."""
    account_managers = AdviserFactory.create_batch(3)
    selected_account_managers = random.sample(account_managers, num_account_managers)

    CompanyFactory.create_batch(2)
    CompanyFactory.create_batch(
        3,
        one_list_account_owner=factory.Iterator(account_managers),
    )

    opensearch_with_collector.flush_and_refresh()

    query = {
        'one_list_group_global_account_manager': [
            account_manager.id for account_manager in selected_account_managers
        ],
    }

    url = reverse('api-v4:search:company')
    response = self.api_client.post(url, query)

    assert response.status_code == status.HTTP_200_OK

    search_results = {
        company['one_list_group_global_account_manager']['id']
        for company in response.data['results']
    }
    expected_results = {
        str(account_manager.id)
        for account_manager in selected_account_managers
    }
    assert response.data['count'] == len(selected_account_managers)
    assert len(response.data['results']) == len(selected_account_managers)
    assert search_results == expected_results
def test_simulate(s3_stubber, caplog):
    """Test that the command simulates updates if --simulate is passed in."""
    caplog.set_level('ERROR')

    export_potential_scores = [
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.medium,
        Company.EXPORT_POTENTIAL_SCORES.low,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.high,
    ]
    companies = CompanyFactory.create_batch(
        5,
        export_potential=factory.Iterator(export_potential_scores),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""datahub_company_id,export_propensity
00000000-0000-0000-0000-000000000000,Low
{companies[0].pk},High
{companies[1].pk},Very high
{companies[2].pk},Low
{companies[3].pk},High
{companies[4].pk},Very high
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_export_potential', bucket, object_key, simulate=True)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    # the export potentials should be unchanged, as the command ran in simulate mode
    assert [company.export_potential for company in companies] == export_potential_scores
def test_company_subsidiaries_auto_update_to_opensearch(opensearch_with_signals):
    """
    Test that company subsidiaries are updated in OpenSearch when the one list
    account owner of their global headquarters changes.
    """
    account_owner = AdviserFactory()
    global_headquarters = CompanyFactory(one_list_account_owner=account_owner)
    subsidiaries = CompanyFactory.create_batch(2, global_headquarters=global_headquarters)
    opensearch_with_signals.indices.refresh()

    subsidiary_ids = [subsidiary.id for subsidiary in subsidiaries]

    result = get_documents_by_ids(
        opensearch_with_signals,
        CompanySearchApp,
        subsidiary_ids,
    )

    expected_results = {
        (str(subsidiary_id), str(account_owner.id))
        for subsidiary_id in subsidiary_ids
    }
    search_results = {
        (doc['_id'], doc['_source']['one_list_group_global_account_manager']['id'])
        for doc in result['docs']
    }

    assert len(result['docs']) == 2
    assert search_results == expected_results

    new_account_owner = AdviserFactory()
    global_headquarters.one_list_account_owner = new_account_owner
    global_headquarters.save()

    opensearch_with_signals.indices.refresh()

    new_result = get_documents_by_ids(
        opensearch_with_signals,
        CompanySearchApp,
        subsidiary_ids,
    )

    new_expected_results = {
        (str(subsidiary_id), str(new_account_owner.id))
        for subsidiary_id in subsidiary_ids
    }
    new_search_results = {
        (doc['_id'], doc['_source']['one_list_group_global_account_manager']['id'])
        for doc in new_result['docs']
    }

    assert len(new_result['docs']) == 2
    assert new_search_results == new_expected_results
def test_one_list_report_generation():
    """Test the generation of the One List."""
    companies = CompanyFactory.create_batch(
        2,
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=factory.Iterator(
            CompanyClassification.objects.all(),  # keeps the ordering
        ),
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because headquarter_type is None
    CompanyFactory(
        headquarter_type=None,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because classification is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=None,
        one_list_account_owner=AdviserFactory(),
    )

    # ignored because one_list_account_owner is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=None,
    )

    report = OneListReport()
    assert list(report.rows()) == [
        {
            'name': company.name,
            'classification__name': company.classification.name,
            'sector__segment': company.sector.segment,
            'primary_contact_name': company.one_list_account_owner.name,
            'one_list_account_owner__telephone_number':
                company.one_list_account_owner.telephone_number,
            'one_list_account_owner__contact_email':
                company.one_list_account_owner.contact_email,
            'registered_address_country__name':
                company.registered_address_country.name,
            'registered_address_town': company.registered_address_town,
            'url': f'{settings.DATAHUB_FRONTEND_URL_PREFIXES["company"]}/{company.id}',
        }
        for company in companies
    ]
def test_global_headquarters(self, setup_es):
    """Test global headquarters filter."""
    ghq1 = CompanyFactory(headquarter_type_id=constants.HeadquarterType.ghq.value.id)
    ghq2 = CompanyFactory(headquarter_type_id=constants.HeadquarterType.ghq.value.id)
    companies = CompanyFactory.create_batch(5, global_headquarters=ghq1)
    CompanyFactory.create_batch(5, global_headquarters=ghq2)
    CompanyFactory.create_batch(10)

    setup_es.indices.refresh()

    url = reverse('api-v3:search:company')
    response = self.api_client.post(
        url,
        {
            'global_headquarters': ghq1.id,
        },
    )

    assert response.status_code == status.HTTP_200_OK
    assert response.data['count'] == 5
    assert len(response.data['results']) == 5

    search_results = {UUID(company['id']) for company in response.data['results']}
    assert search_results == {company.id for company in companies}
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    company_aliases = ('abc', 'def', 'ghi', 'jkl', 'mno')
    companies = CompanyFactory.create_batch(
        5,
        alias=factory.Iterator(company_aliases),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_company_alias,new_company_alias
00000000-0000-0000-0000-000000000000,test,test
{companies[0].pk},{companies[0].alias},xyz100
{companies[1].pk},{companies[1].alias},xyz102
{companies[2].pk},what,xyz103
{companies[3].pk},{companies[3].alias},xyz104
{companies[4].pk},{companies[4].alias},xyz105
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_alias', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    # companies[2] keeps its alias as the old alias in the CSV ('what') does not match
    assert [company.alias for company in companies] == [
        'xyz100',
        'xyz102',
        'ghi',
        'xyz104',
        'xyz105',
    ]
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    original_datetime = datetime(2017, 1, 1, tzinfo=timezone.utc)

    with freeze_time(original_datetime):
        company_numbers = ['123', '456', '466879', '', None]
        companies = CompanyFactory.create_batch(
            5,
            company_number=factory.Iterator(company_numbers),
        )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,company_number
00000000-0000-0000-0000-000000000000,123456
{companies[0].pk},012345
{companies[1].pk},456
{companies[2].pk},null
{companies[3].pk},087891
{companies[4].pk},087892
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    with freeze_time('2018-11-11 00:00:00'):
        call_command('update_company_company_number', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1
    assert [company.company_number for company in companies] == [
        '012345',
        '456',
        '',
        '087891',
        '087892',
    ]
    assert all(company.modified_on == original_datetime for company in companies)