Example #1
    def test_pagination(self, setup_es):
        """Tests the pagination."""
        total_records = 9
        page_size = 2

        ids = sorted((uuid4() for _ in range(total_records)))

        name = 'test record'

        CompanyFactory.create_batch(
            len(ids),
            id=factory.Iterator(ids),
            name=name,
            trading_names=[],
        )

        setup_es.indices.refresh()

        url = reverse('api-v3:search:basic')
        for page in range((len(ids) + page_size - 1) // page_size):
            response = self.api_client.get(
                url,
                data={
                    'term': name,
                    'entity': 'company',
                    'offset': page * page_size,
                    'limit': page_size,
                },
            )

            assert response.status_code == status.HTTP_200_OK

            start = page * page_size
            end = start + page_size
            assert ids[start:end] == [UUID(company['id']) for company in response.data['results']]
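The loop bound `(len(ids) + page_size - 1) // page_size` is integer ceiling division, i.e. the number of pages needed to cover all records. A standalone sketch of the arithmetic (the helper name is ours, not from the test suite):

import math

def num_pages(total_records, page_size):
    # Integer ceiling division: equivalent to math.ceil(total / size)
    # without leaving integer arithmetic.
    return (total_records + page_size - 1) // page_size

assert num_pages(9, 2) == math.ceil(9 / 2) == 5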
Example #2
    def test_pagination(self, public_company_api_client, es_with_collector):
        """Test result pagination."""
        total_records = 9
        page_size = 2
        ids = sorted((uuid4() for _ in range(total_records)))
        name = 'test record'

        CompanyFactory.create_batch(
            len(ids),
            id=factory.Iterator(ids),
            name=name,
            trading_names=[],
        )

        es_with_collector.flush_and_refresh()

        url = reverse('api-v4:search:public-company')

        num_pages = (total_records + page_size - 1) // page_size
        for page in range(num_pages):
            request_data = {
                'original_query': name,
                'offset': page * page_size,
                'limit': page_size,
            }
            response = public_company_api_client.post(url, request_data)
            assert response.status_code == status.HTTP_200_OK

            start = page * page_size
            end = start + page_size
            assert [
                UUID(company['id']) for company in response.data['results']
            ] == ids[start:end]
Example #3
    def test_one_list_account_manager_with_global_headquarters_filter(
            self, es_with_collector):
        """
        Tests that the one list account manager filter matches companies with an
        inherited one list account manager.
        """
        account_manager = AdviserFactory()
        CompanyFactory.create_batch(2)

        global_headquarters = CompanyFactory(
            one_list_account_owner=account_manager,
        )
        target_companies = CompanyFactory.create_batch(
            2,
            global_headquarters=global_headquarters,
        )

        es_with_collector.flush_and_refresh()

        query = {'one_list_group_global_account_manager': account_manager.pk}

        url = reverse('api-v4:search:company')
        response = self.api_client.post(url, query)

        assert response.status_code == status.HTTP_200_OK

        search_results = {
            company['id']
            for company in response.data['results']
        }
        expected_results = {
            str(global_headquarters.id),
            *(str(target_company.id) for target_company in target_companies),
        }
        assert response.data['count'] == 3
        assert len(response.data['results']) == 3
        assert search_results == expected_results
Example #4
    def test_sector_descends_filter(self, hierarchical_sectors, setup_es, sector_level):
        """Test the sector_descends filter."""
        num_sectors = len(hierarchical_sectors)
        sectors_ids = [sector.pk for sector in hierarchical_sectors]

        companies = CompanyFactory.create_batch(
            num_sectors,
            sector_id=factory.Iterator(sectors_ids),
        )
        CompanyFactory.create_batch(
            3,
            sector=factory.LazyFunction(lambda: random_obj_for_queryset(
                Sector.objects.exclude(pk__in=sectors_ids),
            )),
        )

        setup_es.indices.refresh()

        url = reverse('api-v3:search:company')
        body = {
            'sector_descends': hierarchical_sectors[sector_level].pk,
        }
        response = self.api_client.post(url, body)
        assert response.status_code == status.HTTP_200_OK

        response_data = response.json()
        assert response_data['count'] == num_sectors - sector_level

        actual_ids = {UUID(company['id']) for company in response_data['results']}
        expected_ids = {company.pk for company in companies[sector_level:]}
        assert actual_ids == expected_ids
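The `sector_level` argument implies a parametrize decorator that is not part of this excerpt; a plausible sketch, assuming a three-level hierarchy from the `hierarchical_sectors` fixture (both are defined elsewhere in the suite):

import pytest

@pytest.mark.parametrize('sector_level', range(3))  # assumed number of levels
def test_sector_level_is_valid(sector_level):
    assert sector_level in range(3)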
Example #5
def test_sync_app_uses_latest_data(monkeypatch, setup_es):
    """Test that sync_app() picks up updates made to records between batches."""
    CompanyFactory.create_batch(2, name='old name')

    def sync_objects_side_effect(*args, **kwargs):
        nonlocal mock_sync_objects

        ret = sync_objects(*args, **kwargs)

        if mock_sync_objects.call_count == 1:
            Company.objects.update(name='new name')

        return ret

    mock_sync_objects = Mock(side_effect=sync_objects_side_effect)
    monkeypatch.setattr('datahub.search.bulk_sync.sync_objects', mock_sync_objects)
    sync_app(CompanySearchApp, batch_size=1)

    setup_es.indices.refresh()

    company = mock_sync_objects.call_args_list[1][0][1][0]
    fetched_company = setup_es.get(
        index=CompanySearchApp.es_model.get_read_alias(),
        doc_type=CompanySearchApp.name,
        id=company.pk,
    )
    assert fetched_company['_source']['name'] == 'new name'
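The wrapping pattern above, in isolation: a `Mock` whose `side_effect` delegates to the real function keeps the original behaviour while recording calls, which is what lets the test mutate data between batches.

from unittest.mock import Mock

def real_add(a, b):
    return a + b

# The mock runs real_add on every call but still records call metadata.
mock_add = Mock(side_effect=real_add)
assert mock_add(2, 3) == 5
assert mock_add.call_count == 1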
Example #6
    def test_archived_filter(self, setup_es, archived):
        """Tests filtering by archived."""
        matching_companies = CompanyFactory.create_batch(5, archived=archived)
        CompanyFactory.create_batch(2, archived=not archived)

        setup_es.indices.refresh()

        url = reverse('api-v3:search:company')

        response = self.api_client.post(
            url,
            data={
                'archived': archived,
            },
        )

        assert response.status_code == status.HTTP_200_OK
        response_data = response.json()
        assert response_data['count'] == 5

        expected_ids = Counter(str(company.pk) for company in matching_companies)
        actual_ids = Counter(result['id'] for result in response_data['results'])
        assert expected_ids == actual_ids
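The `archived` argument is presumably supplied by a parametrize decorator omitted from the excerpt; a sketch of what it likely looks like:

import pytest

@pytest.mark.parametrize('archived', (True, False))
def test_archived_is_boolean(archived):
    assert isinstance(archived, bool)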
Example #7
    def test_company_search_paging(self, setup_es, sortby):
        """Tests if content placement is consistent between pages."""
        ids = sorted((uuid4() for _ in range(9)))

        name = 'test record'

        CompanyFactory.create_batch(
            len(ids),
            id=factory.Iterator(ids),
            name=name,
            alias='',
        )

        setup_es.indices.refresh()

        page_size = 2

        for page in range((len(ids) + page_size - 1) // page_size):
            url = reverse('api-v3:search:company')
            response = self.api_client.post(
                url,
                data={
                    'original_query': name,
                    'entity': 'company',
                    'offset': page * page_size,
                    'limit': page_size,
                    **sortby,
                },
            )

            assert response.status_code == status.HTTP_200_OK

            start = page * page_size
            end = start + page_size
            assert ids[start:end] == [UUID(company['id']) for company in response.data['results']]
Example #8
def test_run_without_any_matches(caplog, simulate):
    """
    Test that if no matches are found in the database, the command reports zero results.
    """
    CompanyFactory.create_batch(5)

    caplog.set_level('INFO')

    call_command('cleanse_companies_using_worldbase_match', simulate=simulate)

    assert caplog.messages == [
        'Started',
        'Finished - succeeded: 0, failed: 0, archived: 0',
    ]
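For reference, a minimal standalone example of the caplog API this test relies on: `caplog.messages` is the list of rendered message strings captured at or above the configured level.

import logging

def test_caplog_collects_messages(caplog):
    caplog.set_level('INFO')
    logging.getLogger(__name__).info('Started')
    assert caplog.messages == ['Started']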
Example #9
    def test_sector_descends_filter_for_company_interaction(
        self,
        hierarchical_sectors,
        es_with_collector,
        sector_level,
    ):
        """Test the sector_descends filter with company interactions."""
        num_sectors = len(hierarchical_sectors)
        sectors_ids = [sector.pk for sector in hierarchical_sectors]

        companies = CompanyFactory.create_batch(
            num_sectors,
            sector_id=factory.Iterator(sectors_ids),
        )
        company_interactions = CompanyInteractionFactory.create_batch(
            3,
            company=factory.Iterator(companies),
        )

        other_companies = CompanyFactory.create_batch(
            3,
            sector=factory.LazyFunction(lambda: random_obj_for_queryset(
                Sector.objects.exclude(pk__in=sectors_ids),
            )),
        )
        CompanyInteractionFactory.create_batch(
            3,
            company=factory.Iterator(other_companies),
        )

        es_with_collector.flush_and_refresh()

        url = reverse('api-v3:search:interaction')
        body = {
            'sector_descends': hierarchical_sectors[sector_level].pk,
        }
        response = self.api_client.post(url, body)
        assert response.status_code == status.HTTP_200_OK

        response_data = response.json()
        assert response_data['count'] == num_sectors - sector_level

        actual_ids = {
            UUID(interaction['id'])
            for interaction in response_data['results']
        }
        expected_ids = {
            interaction.pk
            for interaction in company_interactions[sector_level:]
        }
        assert actual_ids == expected_ids
Example #10
    def test_export(
        self,
        setup_es,
        request_sortby,
        orm_ordering,
    ):
        """Test export of company search results."""
        CompanyFactory.create_batch(3)
        CompanyFactory.create_batch(2, hq=True)

        setup_es.indices.refresh()

        data = {}
        if request_sortby:
            data['sortby'] = request_sortby

        url = reverse('api-v3:search:company-export')

        with freeze_time('2018-01-01 11:12:13'):
            response = self.api_client.post(url, data=data)

        assert response.status_code == status.HTTP_200_OK
        assert parse_header(response.get('Content-Type')) == ('text/csv', {'charset': 'utf-8'})
        assert parse_header(response.get('Content-Disposition')) == (
            'attachment', {'filename': 'Data Hub - Companies - 2018-01-01-11-12-13.csv'},
        )

        sorted_company = Company.objects.order_by(orm_ordering, 'pk')
        reader = DictReader(StringIO(response.getvalue().decode('utf-8-sig')))

        assert reader.fieldnames == list(SearchCompanyExportAPIView.field_titles.values())

        expected_row_data = [
            {
                'Name': company.name,
                'Link': f'{settings.DATAHUB_FRONTEND_URL_PREFIXES["company"]}/{company.pk}',
                'Sector': get_attr_or_none(company, 'sector.name'),
                'Country': get_attr_or_none(company, 'registered_address_country.name'),
                'UK region': get_attr_or_none(company, 'uk_region.name'),
                'Archived': company.archived,
                'Date created': company.created_on,
                'Number of employees': get_attr_or_none(company, 'employee_range.name'),
                'Annual turnover': get_attr_or_none(company, 'turnover_range.name'),
                'Headquarter type':
                    (get_attr_or_none(company, 'headquarter_type.name') or '').upper(),
            }
            for company in sorted_company
        ]

        assert list(dict(row) for row in reader) == format_csv_data(expected_row_data)
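The `utf-8-sig` decode suggests the export is written with a UTF-8 byte order mark (commonly added so Excel detects the encoding); a small illustration of why plain `utf-8` would not do:

from csv import DictReader
from io import StringIO

payload = '\ufeffName,Archived\r\nAcme,False\r\n'.encode('utf-8')
# 'utf-8-sig' strips a leading BOM; plain 'utf-8' would leave it glued
# to the first field name.
reader = DictReader(StringIO(payload.decode('utf-8-sig')))
assert reader.fieldnames == ['Name', 'Archived']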
Example #11
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    original_datetime = datetime(2017, 1, 1, tzinfo=timezone.utc)

    with freeze_time(original_datetime):
        export_potential_scores = [
            Company.EXPORT_POTENTIAL_SCORES.very_high,
            Company.EXPORT_POTENTIAL_SCORES.medium,
            Company.EXPORT_POTENTIAL_SCORES.low,
            Company.EXPORT_POTENTIAL_SCORES.very_high,
            Company.EXPORT_POTENTIAL_SCORES.high,
        ]
        companies = CompanyFactory.create_batch(
            5,
            export_potential=factory.Iterator(export_potential_scores),
        )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""datahub_company_id,export_propensity
00000000-0000-0000-0000-000000000000,Low
{companies[0].pk},High
{companies[1].pk},Very high
{companies[2].pk},dummy
{companies[3].pk},High
{companies[4].pk},Very high
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    with freeze_time('2018-11-11 00:00:00'):
        call_command('update_company_export_potential', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert "KeyError: \'dummy\'" in caplog.text
    assert len(caplog.records) == 2

    assert [company.export_potential for company in companies] == [
        Company.EXPORT_POTENTIAL_SCORES.high,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.low,
        Company.EXPORT_POTENTIAL_SCORES.high,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
    ]
    assert all(company.modified_on == original_datetime for company in companies)
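The `s3_stubber` fixture is defined elsewhere in the suite; a plausible sketch built on botocore's `Stubber` (an assumption; the real fixture must also point the command under test at this stubbed client):

import boto3
import pytest
from botocore.stub import Stubber

@pytest.fixture
def s3_stubber():
    # The region is arbitrary; the Stubber intercepts calls before any
    # network request is made.
    s3_client = boto3.client('s3', region_name='eu-west-2')
    with Stubber(s3_client) as stubber:
        yield stubber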
Example #12
    def test_limit(
        self,
        automatic_company_archive_feature_flag,
    ):
        """
        Test that we can set a limit to the number of companies
        that are automatically archived.
        """
        gt_3m_ago = timezone.now() - relativedelta(months=3, days=1)
        with freeze_time(gt_3m_ago):
            companies = CompanyFactory.create_batch(3)
        task_result = automatic_company_archive.apply_async(
            kwargs={
                'simulate': False,
                'limit': 2,
            },
        )
        assert task_result.successful()

        archived_companies_count = 0
        for company in companies:
            company.refresh_from_db()
            if company.archived:
                archived_companies_count += 1

        assert archived_companies_count == 2
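`relativedelta(months=3, days=1)` computes "just over three months ago" in calendar terms, which a plain `timedelta` cannot express; a quick check:

from datetime import datetime
from dateutil.relativedelta import relativedelta

now = datetime(2020, 5, 15)
# Three calendar months back, then one more day.
assert now - relativedelta(months=3, days=1) == datetime(2020, 2, 14)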
Example #13
def test_audit_log(s3_stubber):
    """Test that the audit log is being created."""
    investment_project = InvestmentProjectFactory()
    file_companies = CompanyFactory.create_batch(3)

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,investor_company_id,intermediate_company_id,uk_company_id,uk_company_decided
{investment_project.id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
"""
    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(bytes(csv_content, encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_investment_project_company', bucket, object_key)

    investment_project.refresh_from_db()

    assert investment_project.investor_company == file_companies[0]
    assert investment_project.intermediate_company == file_companies[1]
    assert investment_project.uk_company == file_companies[2]
    assert investment_project.uk_company_decided is True

    versions = Version.objects.get_for_object(investment_project)
    assert len(versions) == 1
    assert versions[0].revision.get_comment() == 'Companies data migration.'
Example #14
    def test_with_multiple_user_lists(self):
        """Test that user who owns multiple lists can list all their contents."""
        lists = CompanyListFactory.create_batch(5, adviser=self.user)

        company_list_companies = {}

        # add multiple companies to user's lists
        for company_list in lists:
            companies = CompanyFactory.create_batch(5)
            company_list_companies[company_list.pk] = companies

            CompanyListItemFactory.create_batch(
                len(companies),
                list=company_list,
                company=factory.Iterator(companies),
            )

        # check if contents of each user's list can be listed
        for company_list in lists:
            url = _get_list_item_collection_url(company_list.pk)
            response = self.api_client.get(url)

            assert response.status_code == status.HTTP_200_OK
            response_data = response.json()

            result_company_ids = {
                result['company']['id']
                for result in response_data['results']
            }
            assert result_company_ids == {
                str(company.id)
                for company in company_list_companies[company_list.pk]
            }
Example #15
    def test_company_dbmodels_to_es_documents(self, setup_es):
        """Tests conversion of db models to Elasticsearch documents."""
        companies = CompanyFactory.create_batch(2)

        result = ESCompany.db_objects_to_es_documents(companies)

        assert len(list(result)) == len(companies)
Example #16
def test_simulate(s3_stubber):
    """Test that the command only simulates the actions if --simulate is passed in."""
    companies = CompanyFactory.create_batch(3)
    investment_projects = InvestmentProjectFactory.create_batch(
        2,
        investor_company=companies[0],
        intermediate_company=companies[1],
        uk_company=companies[2],
        uk_company_decided=True,
    )
    file_companies = CompanyFactory.create_batch(3)

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,investor_company_id,intermediate_company_id,uk_company_id,uk_company_decided
{investment_projects[0].id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
{investment_projects[1].id},{file_companies[0].pk},{file_companies[1].pk},{file_companies[2].pk},1
"""
    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(bytes(csv_content, encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_investment_project_company', bucket, object_key, simulate=True)

    for investment_project in investment_projects:
        investment_project.refresh_from_db()

    assert investment_projects[0].investor_company == companies[0]
    assert investment_projects[0].intermediate_company == companies[1]
    assert investment_projects[0].uk_company == companies[2]
    assert investment_projects[0].uk_company_decided is True
    assert investment_projects[1].investor_company == companies[0]
    assert investment_projects[1].intermediate_company == companies[1]
    assert investment_projects[1].uk_company == companies[2]
    assert investment_projects[1].uk_company_decided is True
Example #17
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    companies = CompanyFactory.create_batch(4)
    data = []

    for company in companies:
        record = _get_data_for_company(company.id)
        raw_record = json.dumps(record)
        base64_record = b64encode(raw_record.encode('utf-8')).decode('utf-8')
        data.append(base64_record)

    wrong_json = b64encode('{"what": }'.encode('utf-8')).decode('utf-8')

    # to check that existing match will be overwritten
    company_2_match = DnBMatchingResult.objects.create(
        company_id=companies[2].id,
        data={'hello': 'world', 'confidence': 100},
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,data
00000000-0000-0000-0000-000000000000,NULL
{companies[0].id},{data[0]}
{companies[1].id},{wrong_json}
{companies[2].id},{data[2]}
{companies[3].id},invalidbase64
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_matches', bucket, object_key)

    assert 'Company matching query does not exist' in caplog.text
    assert 'json.decoder.JSONDecodeError' in caplog.text
    assert 'binascii.Error: Invalid base64-encoded string' in caplog.text
    assert len(caplog.records) == 3

    matches = DnBMatchingResult.objects.filter(company__in=companies)
    assert matches.count() == 2

    for match in matches:
        assert match.data == _get_data_for_company(match.company_id)

    company_2_match.refresh_from_db()
    assert company_2_match.data == _get_data_for_company(company_2_match.company_id)
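Each `data` cell in the CSV above is a base64-encoded JSON document; the round trip in isolation:

import json
from base64 import b64decode, b64encode

record = {'hello': 'world', 'confidence': 100}
cell = b64encode(json.dumps(record).encode('utf-8')).decode('utf-8')
assert json.loads(b64decode(cell)) == record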
Example #18
    def test_company_dbmodels_to_documents(self, opensearch):
        """Tests conversion of db models to OpenSearch documents."""
        companies = CompanyFactory.create_batch(2)
        app = get_search_app('company')
        companies_qs = app.queryset.all()

        result = SearchCompany.db_objects_to_documents(companies_qs)

        assert len(list(result)) == len(companies)
Example #19
def test_no_change(s3_stubber, caplog):
    """Test that the command ignores records that haven't changed
    or records with incorrect current values.
    """
    caplog.set_level('WARNING')

    old_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_old', 'sector_2_old', 'sector_3_old']),
    )

    companies = CompanyFactory.create_batch(
        3,
        sector_id=factory.Iterator([sector.pk for sector in old_sectors]),
    )

    new_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_new', 'sector_2_new', 'sector_3_new']),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_sector_id,new_sector_id
{companies[0].pk},{old_sectors[0].pk},{new_sectors[0].pk}
{companies[1].pk},{old_sectors[1].pk},{old_sectors[1].pk}
{companies[2].pk},00000000-0000-0000-0000-000000000000,{new_sectors[2].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_sector_disabled_signals', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert f'Not updating company {companies[1]} as its sector has not changed' in caplog.text
    assert f'Not updating company {companies[2]} as its sector has not changed' in caplog.text
    assert len(caplog.records) == 2

    assert [company.sector for company in companies] == [
        new_sectors[0],
        old_sectors[1],
        old_sectors[2],
    ]
Example #20
def company_names_and_postcodes(opensearch_with_collector):
    """Create companies with known names and postcodes."""
    (names, postcodes) = zip(*(
        ('company_w1', 'w1 2AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_w1a', 'W1A2AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_w11', 'W112AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_ab1_1', 'AB11WC'),  # WC in suffix to ensure not matched in WC tests
        ('company_ab10', 'ab10 1WC'),  # WC in suffix to ensure not matched in WC tests
        # to test the difference between searching for AB1 0 (sector) and AB10 (district)
        ('company_ab1_0', 'AB1 0WC'),
        ('company_wc2b', 'WC2B4AB'),  # AB in suffix to ensure not matched in AB tests
        ('company_wc2n', 'WC2N9ZZ'),
        ('company_wc1x', 'w  C   1 x0aA'),
        ('company_wc1a', 'W C 1 A 1 G A'),
        ('company_se1', 'SE13A J'),
        ('company_se1_3', 'SE13AJ'),
        ('company_se2', 'SE23AJ'),
        ('company_se3', 'SE33AJ'),
    ))

    CompanyFactory.create_batch(
        len(names),
        name=factory.Iterator(names),
        address_country_id=constants.Country.united_kingdom.value.id,
        address_postcode=factory.Iterator(postcodes),
        registered_address_country_id=constants.Country.united_kingdom.value.id,
        registered_address_postcode=factory.Iterator(postcodes),
    )

    CompanyFactory(
        name='non_uk_company_se1',
        address_country_id=constants.Country.united_states.value.id,
        address_postcode='SE13AJ',
    )
    opensearch_with_collector.flush_and_refresh()
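The deliberately mangled postcodes ('w  C   1 x0aA' and similar) suggest the search normalises whitespace and case before matching; a sketch of that normalisation (our assumption, not project code):

def normalise_postcode(postcode):
    # Collapse all whitespace and upper-case the result.
    return ''.join(postcode.split()).upper()

assert normalise_postcode('w  C   1 x0aA') == 'WC1X0AA'
assert normalise_postcode('W C 1 A 1 G A') == 'WC1A1GA'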
Example #21
    def test_one_list_download(self):
        """Test the download of the One List."""
        CompanyFactory.create_batch(
            2,
            headquarter_type_id=constants.HeadquarterType.ghq.value.id,
            classification=random_obj_for_model(CompanyClassification),
            one_list_account_owner=AdviserFactory(),
        )
        # ignored because headquarter_type is None
        CompanyFactory(
            headquarter_type=None,
            classification=random_obj_for_model(CompanyClassification),
            one_list_account_owner=AdviserFactory(),
        )
        # ignored because classification is None
        CompanyFactory(
            headquarter_type_id=constants.HeadquarterType.ghq.value.id,
            classification=None,
            one_list_account_owner=AdviserFactory(),
        )
        # ignored because one_list_account_owner is None
        CompanyFactory(
            headquarter_type_id=constants.HeadquarterType.ghq.value.id,
            classification=random_obj_for_model(CompanyClassification),
            one_list_account_owner=None,
        )

        url = reverse('admin-report:download-report', kwargs={'report_id': 'one-list'})

        user = create_test_user(
            permission_codenames=('view_company',),
            is_staff=True,
            password=self.PASSWORD,
        )

        client = self.create_client(user=user)
        response = client.get(url)
        assert response.status_code == status.HTTP_200_OK
        # 3 = header + the first 2 companies
        assert len(response.getvalue().decode('utf-8').splitlines()) == 3
Example #22
def test_non_existent_company(s3_stubber, caplog):
    """Test that the command logs an error when the company PK does not exist."""
    caplog.set_level('ERROR')

    old_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_old', 'sector_2_old', 'sector_3_old']),
    )

    companies = CompanyFactory.create_batch(
        3,
        sector_id=factory.Iterator([sector.pk for sector in old_sectors]),
    )

    new_sectors = SectorFactory.create_batch(
        3,
        segment=factory.Iterator(['sector_1_new', 'sector_2_new', 'sector_3_new']),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_sector_id,new_sector_id
{companies[0].pk},{old_sectors[0].pk},{new_sectors[0].pk}
{companies[1].pk},{old_sectors[1].pk},{new_sectors[1].pk}
00000000-0000-0000-0000-000000000000,{old_sectors[2].pk},{new_sectors[2].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_sector_disabled_signals', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    assert [company.sector for company in companies] == [
        new_sectors[0],
        new_sectors[1],
        old_sectors[2],
    ]
Example #23
    def test_company_search_paging(self, opensearch_with_collector):
        """
        Tests the pagination.

        The sortby is not passed in so records are ordered by id.
        """
        total_records = 9
        page_size = 2

        ids = sorted((uuid4() for _ in range(total_records)))

        name = 'test record'

        CompanyFactory.create_batch(
            len(ids),
            id=factory.Iterator(ids),
            name=name,
            trading_names=[],
        )

        opensearch_with_collector.flush_and_refresh()

        url = reverse('api-v4:search:company')
        for page in range((len(ids) + page_size - 1) // page_size):
            response = self.api_client.post(
                url,
                data={
                    'original_query': name,
                    'offset': page * page_size,
                    'limit': page_size,
                },
            )

            assert response.status_code == status.HTTP_200_OK

            start = page * page_size
            end = start + page_size
            assert [
                UUID(company['id']) for company in response.data['results']
            ] == ids[start:end]
Example #24
    def test_one_list_account_manager_filter(
        self,
        num_account_managers,
        opensearch_with_collector,
    ):
        """Test one list account manager filter."""
        account_managers = AdviserFactory.create_batch(3)

        selected_account_managers = random.sample(account_managers, num_account_managers)

        CompanyFactory.create_batch(2)
        CompanyFactory.create_batch(
            3,
            one_list_account_owner=factory.Iterator(account_managers),
        )

        opensearch_with_collector.flush_and_refresh()

        query = {
            'one_list_group_global_account_manager': [
                account_manager.id
                for account_manager in selected_account_managers
            ],
        }

        url = reverse('api-v4:search:company')
        response = self.api_client.post(url, query)

        assert response.status_code == status.HTTP_200_OK

        search_results = {
            company['one_list_group_global_account_manager']['id']
            for company in response.data['results']
        }
        expected_results = {
            str(account_manager.id)
            for account_manager in selected_account_managers
        }
        assert response.data['count'] == len(selected_account_managers)
        assert len(response.data['results']) == len(selected_account_managers)
        assert search_results == expected_results
Example #25
def test_simulate(s3_stubber, caplog):
    """Test that the command simulates updates if --simulate is passed in."""
    caplog.set_level('ERROR')

    export_potential_scores = [
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.medium,
        Company.EXPORT_POTENTIAL_SCORES.low,
        Company.EXPORT_POTENTIAL_SCORES.very_high,
        Company.EXPORT_POTENTIAL_SCORES.high,
    ]
    companies = CompanyFactory.create_batch(
        5,
        export_potential=factory.Iterator(export_potential_scores),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""datahub_company_id,export_propensity
00000000-0000-0000-0000-000000000000,Low
{companies[0].pk},High
{companies[1].pk},Very high
{companies[2].pk},Low
{companies[3].pk},High
{companies[4].pk},Very high
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_export_potential', bucket, object_key, simulate=True)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    assert [company.export_potential for company in companies] == export_potential_scores
Example #26
def test_company_subsidiaries_auto_update_to_opensearch(
        opensearch_with_signals):
    """Tests if company subsidiaries get updated in OpenSearch."""
    account_owner = AdviserFactory()
    global_headquarters = CompanyFactory(one_list_account_owner=account_owner)
    subsidiaries = CompanyFactory.create_batch(
        2, global_headquarters=global_headquarters)
    opensearch_with_signals.indices.refresh()

    subsidiary_ids = [subsidiary.id for subsidiary in subsidiaries]

    result = get_documents_by_ids(
        opensearch_with_signals,
        CompanySearchApp,
        subsidiary_ids,
    )

    expected_results = {(str(subsidiary_id), str(account_owner.id))
                        for subsidiary_id in subsidiary_ids}
    search_results = {
        (doc['_id'],
         doc['_source']['one_list_group_global_account_manager']['id'])
        for doc in result['docs']
    }

    assert len(result['docs']) == 2
    assert search_results == expected_results

    new_account_owner = AdviserFactory()
    global_headquarters.one_list_account_owner = new_account_owner
    global_headquarters.save()

    opensearch_with_signals.indices.refresh()

    new_result = get_documents_by_ids(
        opensearch_with_signals,
        CompanySearchApp,
        subsidiary_ids,
    )

    new_expected_results = {(str(subsidiary_id), str(new_account_owner.id))
                            for subsidiary_id in subsidiary_ids}
    new_search_results = {
        (doc['_id'],
         doc['_source']['one_list_group_global_account_manager']['id'])
        for doc in new_result['docs']
    }

    assert len(new_result['docs']) == 2
    assert new_search_results == new_expected_results
Example #27
def test_one_list_report_generation():
    """Test the generation of the One List."""
    companies = CompanyFactory.create_batch(
        2,
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=factory.Iterator(
            CompanyClassification.objects.all(),  # keeps the ordering
        ),
        one_list_account_owner=AdviserFactory(),
    )
    # ignored because headquarter_type is None
    CompanyFactory(
        headquarter_type=None,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=AdviserFactory(),
    )
    # ignored because classification is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=None,
        one_list_account_owner=AdviserFactory(),
    )
    # ignored because one_list_account_owner is None
    CompanyFactory(
        headquarter_type_id=constants.HeadquarterType.ghq.value.id,
        classification=random_obj_for_model(CompanyClassification),
        one_list_account_owner=None,
    )

    report = OneListReport()
    assert list(report.rows()) == [
        {
            'name': company.name,
            'classification__name': company.classification.name,
            'sector__segment': company.sector.segment,
            'primary_contact_name': company.one_list_account_owner.name,
            'one_list_account_owner__telephone_number':
                company.one_list_account_owner.telephone_number,
            'one_list_account_owner__contact_email':
                company.one_list_account_owner.contact_email,
            'registered_address_country__name': company.registered_address_country.name,
            'registered_address_town': company.registered_address_town,
            'url': f'{settings.DATAHUB_FRONTEND_URL_PREFIXES["company"]}/{company.id}',
        }
        for company in companies
    ]
Example #28
    def test_global_headquarters(self, setup_es):
        """Test global headquarters filter."""
        ghq1 = CompanyFactory(headquarter_type_id=constants.HeadquarterType.ghq.value.id)
        ghq2 = CompanyFactory(headquarter_type_id=constants.HeadquarterType.ghq.value.id)
        companies = CompanyFactory.create_batch(5, global_headquarters=ghq1)
        CompanyFactory.create_batch(5, global_headquarters=ghq2)
        CompanyFactory.create_batch(10)

        setup_es.indices.refresh()

        url = reverse('api-v3:search:company')
        response = self.api_client.post(
            url,
            {
                'global_headquarters': ghq1.id,
            },
        )
        assert response.status_code == status.HTTP_200_OK

        assert response.data['count'] == 5
        assert len(response.data['results']) == 5

        search_results = {UUID(company['id']) for company in response.data['results']}
        assert search_results == {company.id for company in companies}
Example #29
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    company_aliases = ('abc', 'def', 'ghi', 'jkl', 'mno')

    companies = CompanyFactory.create_batch(
        5,
        alias=factory.Iterator(company_aliases),
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,old_company_alias,new_company_alias
00000000-0000-0000-0000-000000000000,test,test
{companies[0].pk},{companies[0].alias},xyz100
{companies[1].pk},{companies[1].alias},xyz102
{companies[2].pk},what,xyz103
{companies[3].pk},{companies[3].alias},xyz104
{companies[4].pk},{companies[4].alias},xyz105
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    call_command('update_company_alias', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    assert [company.alias for company in companies] == [
        'xyz100',
        'xyz102',
        'ghi',
        'xyz104',
        'xyz105',
    ]
Example #30
def test_run(s3_stubber, caplog):
    """Test that the command updates the specified records (ignoring ones with errors)."""
    caplog.set_level('ERROR')

    original_datetime = datetime(2017, 1, 1, tzinfo=timezone.utc)

    with freeze_time(original_datetime):
        company_numbers = ['123', '456', '466879', '', None]
        companies = CompanyFactory.create_batch(
            5,
            company_number=factory.Iterator(company_numbers),
        )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,company_number
00000000-0000-0000-0000-000000000000,123456
{companies[0].pk},012345
{companies[1].pk},456
{companies[2].pk},null
{companies[3].pk},087891
{companies[4].pk},087892
"""

    s3_stubber.add_response(
        'get_object',
        {
            'Body': BytesIO(csv_content.encode(encoding='utf-8')),
        },
        expected_params={
            'Bucket': bucket,
            'Key': object_key,
        },
    )

    with freeze_time('2018-11-11 00:00:00'):
        call_command('update_company_company_number', bucket, object_key)

    for company in companies:
        company.refresh_from_db()

    assert 'Company matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    assert [company.company_number for company in companies] == [
        '012345', '456', '', '087891', '087892',
    ]
    assert all(company.modified_on == original_datetime for company in companies)