예제 #1
0
 def test_import_from_export_mixed_headers(self):
     """Round-trip users through export/import with some headers renamed.

     Creates a FacilityUser for each record in the module-level ``users``,
     exports them with demographic data, and rewrites the CSV. In the rewritten
     file the "Facility id" and "Gender" label headers are replaced by the raw
     field names "facility" and "gender", so the header mixes label-style and
     field-style names. The test then deletes all users and re-imports the file.
     Each user must come back with its original ``birth_year`` and an empty
     ``id_number``.
     """
     for user in users:
         FacilityUser.objects.create(facility=self.facility, **user)
     call_command(
         "exportusers", output_file=self.csvpath, overwrite=True, demographic=True
     )
     # Exported label header -> raw field name written in its place.
     cols_to_replace = {"Facility id": "facility", "Gender": "gender"}
     with open(self.csvpath, "r") as source:
         rows = list(csv.DictReader(source))
     with open(self.csvpath, "w") as result:
         writer = csv.DictWriter(
             result,
             tuple(cols_to_replace.get(label, label) for label in labels.values()),
         )
         writer.writeheader()
         for row in rows:
             # Move each value from its label key to its field-name key so the
             # row's keys match the rewritten header.
             for old_label, field_name in cols_to_replace.items():
                 row[field_name] = row.pop(old_label)
             writer.writerow(row)
     FacilityUser.objects.all().delete()
     call_command("importusers", self.csvpath)
     for user in users:
         user_model = FacilityUser.objects.get(username=user["username"])
         self.assertEqual(user_model.birth_year, user["birth_year"])
         self.assertEqual(user_model.id_number, "")
예제 #2
0
 def test_import_from_export_missing_headers(self):
     """Round-trip users through export/import with some columns dropped.

     Creates a FacilityUser for each record in the module-level ``users`` and
     exports them with demographic data. The test then rewrites the CSV without
     the "Facility id" and "Gender" columns, using the ``open_csv_for_reading``
     and ``open_csv_for_writing`` helpers. Finally it deletes all users and
     re-imports the trimmed file. Each user must come back with its original
     ``birth_year`` and an empty ``id_number``.
     """
     for user in users:
         FacilityUser.objects.create(facility=self.facility, **user)
     call_command(
         "exportusers", output_file=self.csvpath, overwrite=True, demographic=True
     )
     cols_to_remove = ["Facility id", "Gender"]
     with open_csv_for_reading(self.csvpath) as source:
         rows = list(csv.DictReader(source))
     with open_csv_for_writing(self.csvpath) as result:
         writer = csv.DictWriter(
             result,
             tuple(
                 label for label in labels.values() if label not in cols_to_remove
             ),
         )
         writer.writeheader()
         for row in rows:
             # DictWriter rejects keys missing from its fieldnames, so the
             # dropped columns must also be removed from every row.
             for col in cols_to_remove:
                 del row[col]
             writer.writerow(row)
     FacilityUser.objects.all().delete()
     call_command("importusers", self.csvpath)
     for user in users:
         user_model = FacilityUser.objects.get(username=user["username"])
         self.assertEqual(user_model.birth_year, user["birth_year"])
         self.assertEqual(user_model.id_number, "")
예제 #3
0
    def handle(self, *args, **options):
        """Create facility users from the CSV file at ``options["filepath"]``.

        Users are attached to the facility whose primary key is
        ``options["facility"]`` when one is given. Otherwise they go to
        ``Facility.get_default_facility()``.

        The first row of the file decides how the rest is parsed:

        * If every cell is a known column name (an entry of ``input_fields`` or
          a value of ``labels``), the row is a header and rows are keyed by it.
          The header must contain "username" or the username label.
        * If some cells are known but not all, the import is refused.
        * If no cell is known, there is no header and rows are keyed
          positionally by ``input_fields``.

        Each row is passed through ``map_input`` and ``create_user`` inside one
        ``transaction.atomic()`` block. The int() values of ``create_user``'s
        results are summed and logged as the number of users created.

        Raises:
            CommandError: if the requested facility does not exist, no default
                facility exists, the file is empty, the header has no username
                column, or the header mixes valid and invalid labels.
        """
        facility_id = options["facility"]
        if facility_id:
            try:
                default_facility = Facility.objects.get(pk=facility_id)
            except Facility.DoesNotExist:
                # Report a bad --facility value as a command error instead of
                # letting a raw model exception escape.
                raise CommandError(
                    "Facility with ID {} does not exist".format(facility_id)
                )
        else:
            default_facility = Facility.get_default_facility()

        if not default_facility:
            raise CommandError(
                "No default facility exists, please make sure to provision this device before running this command"
            )

        # Every column name accepted in a header row: raw field names plus
        # their display labels.
        fieldnames = input_fields + tuple(val for val in labels.values())

        # open using default OS encoding
        # NOTE(review): on Python 3 the csv module docs ask for newline="" when
        # opening CSV files. It is not passed here because builtin open() on
        # Python 2 rejects that argument. Confirm which interpreters this
        # command targets.
        with open(options["filepath"]) as f:
            try:
                header = next(csv.reader(f, strict=True))
            except StopIteration:
                # csv.reader produces no rows at all only for a zero-length file.
                raise CommandError(
                    "The file {} is empty".format(options["filepath"])
                )
            has_header = False
            if all(col in fieldnames for col in header):
                # Every cell in the first row is a known column name, so this
                # is a header row. It must identify the username column.
                if "username" not in header and str(labels["username"]) not in header:
                    raise CommandError(
                        "No usernames specified, this is required for user creation"
                    )
                has_header = True
            elif any(col in fieldnames for col in header):
                raise CommandError(
                    "Mix of valid and invalid header labels found in first row"
                )

        # Reopen so parsing starts at the first line again. When a header is
        # present, DictReader consumes it as its fieldnames.
        # open using default OS encoding
        with open(options["filepath"]) as f:
            if has_header:
                reader = csv.DictReader(f, strict=True)
            else:
                # No header: the first row is data, mapped positionally.
                reader = csv.DictReader(f, fieldnames=input_fields, strict=True)
            # Create all users in a single atomic block, so an error part-way
            # through rolls back the users created before it.
            with transaction.atomic():
                total = 0
                for row in reader:
                    total += int(
                        create_user(map_input(row), default_facility=default_facility)
                    )
                logger.info("{total} users created".format(total=total))