예제 #1
0
def test_team_member_list_update_add_only():
    """Test that updating an empty team list adds every supplied team member."""
    project = InvestmentProjectFactory()

    incoming_members = [
        {
            'investment_project': project,
            'adviser': AdviserFactory(),
            'role': role,
        }
        for role in ('new role', 'new team member')
    ]

    list_serializer = IProjectTeamMemberListSerializer(
        child=IProjectTeamMemberSerializer(),
    )
    created_members = list_serializer.update([], incoming_members)

    # Each created member must mirror the adviser and role it was built from.
    for index, expected in enumerate(incoming_members):
        created = created_members[index]
        assert created.adviser == expected['adviser']
        assert created.role == expected['role']

    assert project.team_members.count() == 2
예제 #2
0
    def test_restricted_user_can_update_associated_investment_project_interaction(self):
        """
        Check that a user limited to associated investment projects can edit an
        interaction on a project created by an adviser in the same team.
        """
        creator = AdviserFactory()
        interaction = CompanyInteractionFactory(
            subject='I am a subject',
            investment_project=InvestmentProjectFactory(created_by=creator),
        )
        restricted_user = create_test_user(
            permission_codenames=[InteractionPermission.change_associated_investmentproject],
            dit_team=creator.dit_team,
        )
        new_subject = 'I am another subject'

        client = self.create_api_client(user=restricted_user)
        response = client.patch(
            reverse('api-v3:interaction:item', kwargs={'pk': interaction.pk}),
            data={'subject': new_subject},
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data['subject'] == new_subject
예제 #3
0
def test_audit_log(s3_stubber):
    """Check that the created_on update command records exactly one revision."""
    investment_project = InvestmentProjectFactory()

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,createdon
{investment_project.id},2015-09-29 11:03:20.000
"""
    stub_body = BytesIO(csv_content.encode('utf-8'))
    s3_stubber.add_response(
        'get_object',
        {'Body': stub_body},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command('update_investment_project_created_on', bucket, object_key)
    investment_project.refresh_from_db()

    expected_created_on = datetime(2015, 9, 29, 11, 3, 20, tzinfo=utc)
    assert investment_project.created_on == expected_created_on

    versions = Version.objects.get_for_object(investment_project)
    assert len(versions) == 1
    assert versions[0].revision.get_comment() == 'Created On migration.'
예제 #4
0
def test_no_project_code():
    """Check that project_code is None when no project code is set."""
    # Create the project with a CDMS code so that no Data Hub project code is
    # generated, then clear the CDMS code on the in-memory instance.
    project = InvestmentProjectFactory(cdms_project_code='P-79661656')
    project.cdms_project_code = None

    code = project.project_code
    assert code is None
예제 #5
0
    def test_basic_search_no_permissions(self, setup_es):
        """Check that basic search returns no results or aggregations to a permissionless user."""
        user = create_test_user(permission_codenames=[], dit_team=TeamFactory())
        client = self.create_api_client(user=user)

        # Populate several searchable models so that an empty response is meaningful.
        InvestmentProjectFactory(created_by=user)
        for factory in (
            CompanyFactory,
            ContactFactory,
            EventFactory,
            CompanyInteractionFactory,
            OrderFactory,
        ):
            factory()

        setup_es.indices.refresh()

        response = client.get(
            reverse('api-v3:search:basic'),
            data={'term': '', 'entity': 'company'},
        )
        assert response.status_code == status.HTTP_200_OK

        payload = response.json()
        assert payload['count'] == 0
        assert len(payload['aggregations']) == 0
예제 #6
0
def test_client_relationship_manager_team_valid():
    """
    Check that client_relationship_manager_team is set for a project that has a
    client relationship manager.
    """
    project = InvestmentProjectFactory()
    team = project.client_relationship_manager_team
    assert team
예제 #7
0
def test_client_relationship_manager_team_none():
    """
    Check that client_relationship_manager_team is None when the project has no
    client relationship manager.
    """
    project = InvestmentProjectFactory(client_relationship_manager=None)
    team = project.client_relationship_manager_team
    assert team is None
예제 #8
0
def test_investor_company_country_none():
    """
    Tests investor_company_country for a project without an investor company.

    With investor_company set to None, investor_company_country is expected to be None.
    """
    project = InvestmentProjectFactory(investor_company=None)
    assert project.investor_company_country is None
예제 #9
0
def test_validate_project_instance_success():
    """Check that a complete core project section on a model instance passes validation."""
    contact_ids = [ContactFactory().id for _ in range(2)]
    project = InvestmentProjectFactory(client_contacts=contact_ids)

    validation_errors = validate(instance=project, fields=CORE_FIELDS)
    assert not validation_errors
예제 #10
0
def test_investor_company_country_valid():
    """
    Tests investor_company_country for a project with an investor company.

    The factory's default project is expected to yield a truthy investor_company_country.
    """
    project = InvestmentProjectFactory()
    assert project.investor_company_country
예제 #11
0
def test_audit_log(s3_stubber):
    """Check that the company update command records exactly one revision."""
    investment_project = InvestmentProjectFactory()
    investor, intermediate, uk_company = CompanyFactory.create_batch(3)

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,investor_company_id,intermediate_company_id,uk_company_id,uk_company_decided
{investment_project.id},{investor.pk},{intermediate.pk},{uk_company.pk},1
"""
    s3_stubber.add_response(
        'get_object',
        {'Body': BytesIO(csv_content.encode('utf-8'))},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command('update_investment_project_company', bucket, object_key)
    investment_project.refresh_from_db()

    assert investment_project.investor_company == investor
    assert investment_project.intermediate_company == intermediate
    assert investment_project.uk_company == uk_company
    assert investment_project.uk_company_decided is True

    versions = Version.objects.get_for_object(investment_project)
    assert len(versions) == 1
    assert versions[0].revision.get_comment() == 'Companies data migration.'
예제 #12
0
    def test_basic_search_aggregations(self, setup_es, setup_data):
        """Check the per-entity aggregation counts returned by basic search."""
        unique_name = 'very_unique_company'
        company = CompanyFactory(name=unique_name)
        ContactFactory(company=company)
        InvestmentProjectFactory(investor_company=company)

        setup_es.indices.refresh()

        response = self.api_client.get(
            reverse('api-v3:search:basic'),
            data={'term': unique_name, 'entity': 'company'},
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data['count'] == 1
        assert response.data['results'][0]['name'] == unique_name

        # One match is expected for each entity related to the company.
        returned_aggregations = response.data['aggregations']
        for entity in ('company', 'contact', 'investment_project'):
            assert {'count': 1, 'entity': entity} in returned_aggregations
예제 #13
0
def test_audit_log(s3_stubber):
    """Check that the sector update command records exactly one revision."""
    new_sector = SectorFactory()
    investment_project = InvestmentProjectFactory()
    old_sector = investment_project.sector

    bucket, object_key = 'test_bucket', 'test_key'
    csv_content = f"""id,old_sector,new_sector
{investment_project.id},{old_sector.id},{new_sector.id}
"""
    stub_body = BytesIO(csv_content.encode('utf-8'))
    s3_stubber.add_response(
        'get_object',
        {'Body': stub_body},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command('update_investment_project_sector', bucket, object_key)
    investment_project.refresh_from_db()
    assert investment_project.sector == new_sector

    versions = Version.objects.get_for_object(investment_project)
    assert len(versions) == 1
    assert versions[0].revision.get_comment() == 'Sector migration.'
예제 #14
0
def test_run(s3_stubber, caplog):
    """
    Check that the command updates actual UK regions for the listed projects,
    logging an error for rows it cannot process and skipping them.
    """
    caplog.set_level('ERROR')

    regions = list(UKRegion.objects.all())
    starting_regions = ([], [], regions[0:1], regions[1:2], [])
    investment_projects = [
        InvestmentProjectFactory(actual_uk_regions=project_regions)
        for project_regions in starting_regions
    ]

    bucket = 'test_bucket'
    object_key = 'test_key'
    # The first row refers to a project that does not exist.
    csv_content = f"""id,actual_uk_regions
00000000-0000-0000-0000-000000000000,
{investment_projects[0].pk},
{investment_projects[1].pk},{regions[2].pk}
{investment_projects[2].pk},"{regions[3].pk},{regions[4].pk}"
{investment_projects[3].pk},
{investment_projects[4].pk},"{regions[3].pk},{regions[4].pk}"
"""

    s3_stubber.add_response(
        'get_object',
        {'Body': BytesIO(csv_content.encode(encoding='utf-8'))},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command('update_investment_project_actual_uk_regions', bucket, object_key)

    for project in investment_projects:
        project.refresh_from_db()

    assert 'InvestmentProject matching query does not exist' in caplog.text
    assert len(caplog.records) == 1

    # Only projects that started with no regions and received non-blank CSV
    # values are expected to change; the rest keep their original regions.
    actual_regions = [list(project.actual_uk_regions.all()) for project in investment_projects]
    assert actual_regions == [[], regions[2:3], regions[0:1], regions[1:2], regions[3:5]]
예제 #15
0
def test_associated_advisers_specific_roles(field):
    """Check that an adviser assigned via the given project field is an associated adviser."""
    adviser = AdviserFactory()
    project = InvestmentProjectFactory(**{field: adviser})

    associated_advisers = tuple(project.get_associated_advisers())
    assert adviser in associated_advisers
예제 #16
0
def test_validate_reqs_competitor_countries_missing():
    """Check that competitor_countries is flagged when other countries are considered."""
    assign_pm_stage_id = constants.InvestmentProjectStage.assign_pm.value.id
    project = InvestmentProjectFactory(
        stage_id=assign_pm_stage_id,
        client_considering_other_countries=True,
    )

    validation_errors = validate(instance=project, fields=REQUIREMENTS_FIELDS)
    assert 'competitor_countries' in validation_errors
def test_audit_log(s3_stubber):
    """Check that a revision is recorded only for the project whose data changes."""
    regions = list(UKRegion.objects.all())

    project_without_change = InvestmentProjectFactory(
        allow_blank_possible_uk_regions=True,
        uk_region_locations=regions[0:1],
    )
    project_with_change = InvestmentProjectFactory(
        allow_blank_possible_uk_regions=False,
        uk_region_locations=[],
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    csv_content = f"""id,allow_blank_possible_uk_regions,uk_region_locations
{project_without_change.pk},true,{regions[0].pk}
{project_with_change.pk},true,{regions[2].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {'Body': BytesIO(csv_content.encode(encoding='utf-8'))},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command(
        'update_investment_project_possible_uk_regions',
        bucket,
        object_key,
        ignore_old_regions=True,
    )

    assert Version.objects.get_for_object(project_without_change).count() == 0

    changed_versions = Version.objects.get_for_object(project_with_change)
    assert changed_versions.count() == 1
    revision_comment = changed_versions[0].revision.get_comment()
    assert revision_comment == 'Possible UK regions data migration correction.'
예제 #18
0
    def test_search_investment_project_aggregates(self, setup_es):
        """Check the stage aggregations returned alongside filtered project search results."""
        stages = constants.InvestmentProjectStage
        url = reverse('api-v3:search:investment_project')

        for name, stage in (
            ('Pear 1', stages.active),
            ('Pear 2', stages.prospect),
            ('Pear 3', stages.prospect),
            ('Pear 4', stages.won),
        ):
            InvestmentProjectFactory(name=name, stage_id=stage.value.id)

        setup_es.indices.refresh()

        response = self.api_client.post(
            url,
            data={
                'original_query': 'Pear',
                'stage': [stages.prospect.value.id, stages.active.value.id],
            },
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data['count'] == 3
        assert len(response.data['results']) == 3
        assert 'aggregations' in response.data

        # The 'won' project is excluded from the results by the stage filter but
        # is still expected to be counted in the stage aggregation.
        stage_buckets = response.data['aggregations']['stage']
        for stage, doc_count in (
            (stages.prospect, 2),
            (stages.active, 1),
            (stages.won, 1),
        ):
            assert {'key': stage.value.id, 'doc_count': doc_count} in stage_buckets
예제 #19
0
def test_validate_reqs_competitor_countries_present():
    """Check that supplying competitor countries satisfies the conditional requirement."""
    assign_pm_stage_id = constants.InvestmentProjectStage.assign_pm.value.id
    united_states_id = constants.Country.united_states.value.id
    project = InvestmentProjectFactory(
        stage_id=assign_pm_stage_id,
        client_considering_other_countries=True,
        competitor_countries=[united_states_id],
    )

    validation_errors = validate(instance=project, fields=REQUIREMENTS_FIELDS)
    assert 'competitor_countries' not in validation_errors
예제 #20
0
def test_validate_business_activity_other_instance():
    """Check that choosing the 'other' business activity requires other_business_activity."""
    other_activity_id = constants.InvestmentBusinessActivity.other.value.id
    project = InvestmentProjectFactory(business_activities=[other_activity_id])

    validation_errors = validate(instance=project, fields=CORE_FIELDS)

    expected_errors = {'other_business_activity': 'This field is required.'}
    assert validation_errors == expected_errors
예제 #21
0
def test_validate_project_referral_website():
    """Check that a website referral source flags the website field but not event/marketing."""
    project = InvestmentProjectFactory(
        referral_source_activity_id=constants.ReferralSourceActivity.website.value.id,
    )
    validation_errors = validate(instance=project, fields=CORE_FIELDS)

    assert 'referral_source_activity_website' in validation_errors
    unrelated_fields = (
        'referral_source_activity_event',
        'referral_source_activity_marketing',
    )
    for field_name in unrelated_fields:
        assert field_name not in validation_errors
예제 #22
0
def test_validate_value_instance_success():
    """Check that a complete value section on a model instance passes validation."""
    value_section = {
        'client_cannot_provide_total_investment': False,
        'total_investment': 100,
        'number_new_jobs': 0,
    }
    project = InvestmentProjectFactory(
        stage_id=constants.InvestmentProjectStage.assign_pm.value.id,
        **value_section,
    )

    validation_errors = validate(instance=project, fields=VALUE_FIELDS)
    assert not validation_errors
예제 #23
0
def test_validate_possible_uk_regions(allow_blank_possible_uk_regions, is_error):
    """
    Check whether uk_region_locations (possible UK regions) is reported for a
    project with no possible UK regions, depending on allow_blank_possible_uk_regions.
    """
    assign_pm_stage_id = constants.InvestmentProjectStage.assign_pm.value.id
    project = InvestmentProjectFactory(
        stage_id=assign_pm_stage_id,
        allow_blank_possible_uk_regions=allow_blank_possible_uk_regions,
        uk_region_locations=[],
    )

    validation_errors = validate(instance=project, fields=REQUIREMENTS_FIELDS)
    has_region_error = 'uk_region_locations' in validation_errors
    assert has_region_error == is_error
예제 #24
0
def test_validate_team_instance_success():
    """Check that a complete team section on a model instance passes validation."""
    adviser = AdviserFactory()
    # A single adviser fills both the project manager and assurance adviser roles.
    project = InvestmentProjectFactory(
        stage_id=constants.InvestmentProjectStage.active.value.id,
        project_manager=adviser,
        project_assurance_adviser=adviser,
    )

    validation_errors = validate(instance=project, fields=TEAM_FIELDS)
    assert not validation_errors
예제 #25
0
def test_validate_verify_win_instance_cond_validation_failure():
    """
    Check that a verify-win project with a non-FDI R&D budget requires an
    associated non-FDI R&D project.
    """
    project = InvestmentProjectFactory(
        stage_id=constants.InvestmentProjectStage.verify_win.value.id,
        non_fdi_r_and_d_budget=True,
    )
    validation_errors = validate(instance=project)

    field_name = 'associated_non_fdi_r_and_d_project'
    assert field_name in validation_errors
    assert validation_errors[field_name] == 'This field is required.'
예제 #26
0
def test_can_see_spi1_start(spi_report):
    """Check that creating an investment project starts SPI 1 in the report."""
    investment_project = InvestmentProjectFactory()

    rows = list(spi_report.rows())
    assert len(rows) == 1

    row = rows[0]
    assert row['Project created on'] == investment_project.created_on.isoformat()
    assert 'Enquiry processed' not in row
예제 #27
0
def test_validate_project_fail():
    """Check that an FDI project without an FDI type fails core-section validation."""
    project = InvestmentProjectFactory(
        investment_type_id=constants.InvestmentType.fdi.value.id,
        fdi_type_id=None,
    )

    expected_errors = {'fdi_type': 'This field is required.'}
    assert validate(instance=project, fields=CORE_FIELDS) == expected_errors
예제 #28
0
def test_audit_log(s3_stubber):
    """Check that revisions are recorded only for projects whose business activities change."""
    business_activities = list(InvestmentBusinessActivity.objects.all())

    project_without_change = InvestmentProjectFactory(
        business_activities=business_activities[0:1],
    )
    project_with_change = InvestmentProjectFactory(business_activities=[])
    project_already_updated = InvestmentProjectFactory(
        business_activities=business_activities[0:1],
    )

    bucket = 'test_bucket'
    object_key = 'test_key'
    # Row 1's old value does not match the project's current activities;
    # row 3's new value equals what the project already has.
    csv_content = f"""id,old_business_activities,new_business_activities
{project_without_change.pk},{business_activities[1].pk},{business_activities[0].pk}
{project_with_change.pk},null,{business_activities[2].pk}
{project_already_updated.pk},{business_activities[0].pk},{business_activities[0].pk}
"""

    s3_stubber.add_response(
        'get_object',
        {'Body': BytesIO(csv_content.encode(encoding='utf-8'))},
        expected_params={'Bucket': bucket, 'Key': object_key},
    )

    call_command('update_investment_project_business_activities', bucket, object_key)

    for untouched_project in (project_without_change, project_already_updated):
        assert Version.objects.get_for_object(untouched_project).count() == 0

    changed_versions = Version.objects.get_for_object(project_with_change)
    assert changed_versions.count() == 1
    revision_comment = changed_versions[0].revision.get_comment()
    assert revision_comment == 'Business activities data migration correction.'
예제 #29
0
    def test_search_sort_nested_desc(self, setup_es, setup_data):
        """Check sorting investment project search results by stage name, descending."""
        stages = constants.InvestmentProjectStage
        for name, stage in (
            ('Potato 1', stages.active),
            ('Potato 2', stages.prospect),
            ('potato 3', stages.won),
            ('Potato 4', stages.won),
        ):
            InvestmentProjectFactory(name=name, stage_id=stage.value.id)

        setup_es.indices.refresh()

        response = self.api_client.post(
            reverse('api-v3:search:investment_project'),
            data={'original_query': 'Potato', 'sortby': 'stage.name:desc'},
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data['count'] == 4

        returned_stage_names = [
            result['stage']['name'] for result in response.data['results']
        ]
        assert returned_stage_names == ['Won', 'Won', 'Prospect', 'Active']
예제 #30
0
def test_assigning_non_ist_project_manager_doesnt_end_spi2(spi_report):
    """Test that assigning a non-IST project manager does not end SPI 2."""
    investment_project = InvestmentProjectFactory()

    # Assign the project manager in a separate save so that
    # project_manager_first_assigned_on gets updated.
    investment_project.project_manager = AdviserFactory()
    investment_project.save()

    report_rows = list(spi_report.rows())
    assert len(report_rows) == 1

    only_row = report_rows[0]
    assert 'Project manager assigned' not in only_row