class Test_archivewal_parse_args_returns_parser_options_args(TestCase):
    def setUp(self):
        self.tempdir = TempDirectory()
        self.config_dict = {
            'General': {
                'pgsql_data_directory': self.tempdir.path,
            },
        }
        self.config_file = os.path.join(self.tempdir.path, 'config_file')
        write_config_to_filename(self.config_dict, self.config_file)

    def tearDown(self):
        self.tempdir.cleanup()

    def test_archivewal_parse_args_returns_three_items(self):
        item1, item2, item3 = archivewal_parse_args(args=['walfilename'])
        self.assertIsNotNone(item1)
        self.assertIsNotNone(item2)
        self.assertIsNotNone(item3)

    def test_archivewal_parse_args_returns_parser(self):
        parser, item2, item3 = archivewal_parse_args(args=['walfilename'])
        self.assertIsInstance(parser, OptionParser)

    def test_archivewal_parse_args_returns_options(self):
        item1, options, item3 = archivewal_parse_args(args=['walfilename'])
        self.assertIsInstance(options, object)

    def test_archivewal_parse_args_returns_args(self):
        item1, item2, args = archivewal_parse_args(args=['walfilename'])
        self.assertIsInstance(args, list)
Example #2
    def test_orders_stop(self, name, order_data, event_data, expected):
        tempdir = TempDirectory()
        try:
            data = order_data
            data['sid'] = self.ASSET133

            order = Order(**data)

            assets = {
                133: pd.DataFrame({
                    "open": [event_data["open"]],
                    "high": [event_data["high"]],
                    "low": [event_data["low"]],
                    "close": [event_data["close"]],
                    "volume": [event_data["volume"]],
                    "dt": [pd.Timestamp('2006-01-05 14:31', tz='UTC')]
                }).set_index("dt")
            }

            write_bcolz_minute_data(
                self.env,
                pd.date_range(
                    start=normalize_date(self.minutes[0]),
                    end=normalize_date(self.minutes[-1])
                ),
                tempdir.path,
                assets
            )

            equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

            data_portal = DataPortal(
                self.env,
                equity_minute_reader=equity_minute_reader,
            )

            slippage_model = VolumeShareSlippage()

            try:
                dt = pd.Timestamp('2006-01-05 14:31', tz='UTC')
                bar_data = BarData(data_portal,
                                   lambda: dt,
                                   'minute')
                _, txn = next(slippage_model.simulate(
                    bar_data,
                    self.ASSET133,
                    [order],
                ))
            except StopIteration:
                txn = None

            if expected['transaction'] is None:
                self.assertIsNone(txn)
            else:
                self.assertIsNotNone(txn)

                for key, value in expected['transaction'].items():
                    self.assertEqual(value, txn[key])
        finally:
            tempdir.cleanup()
class Test_SnapshotArchive_Repository(TestCase):
    def setUp(self):
        store = MemoryCommitStorage()
        self.repo = BBRepository(store)
        self.tempdir = TempDirectory()
        self.setup_archive_a_snapshot()

    def setup_archive_a_snapshot(self):
        archive_name = 'somearchive.tgz'
        self.archive_contents = b'123'
        self.archive_path = self.tempdir.write(archive_name,
            self.archive_contents)
        self.tag = generate_tag()
        self.first_WAL = '01234'
        self.last_WAL = '45678'
        commit_snapshot_to_repository(self.repo, self.archive_path, self.tag,
            self.first_WAL, self.last_WAL)

    def tearDown(self):
        self.tempdir.cleanup()

    def test_can_retrieve_snapshot_contents_with_tag(self):
        commit = [i for i in self.repo][-1]
        restore_path = self.tempdir.getpath('restorearchive.tgz')
        commit.get_contents_to_filename(restore_path)
        with open(restore_path, 'rb') as f:
            self.assertEqual(self.archive_contents, f.read())

    def test_get_first_WAL_file_for_archived_snapshot_with_tag(self):
        self.assertEqual(self.first_WAL, get_first_WAL(self.repo, self.tag))

    def test_get_last_WAL_file_for_archived_snapshot_with_tag(self):
        self.assertEqual(self.last_WAL, get_last_WAL(self.repo, self.tag))
Example #4
class TestConfig(TestCase):

    def setUp(self):
        self.dir = TempDirectory()

    def tearDown(self):
        self.dir.cleanup()
        
    def test_table_in_multiple_sources(self):
        m1 = MetaData()
        t1 = Table('table', m1)
        m2 = MetaData()
        t2 = Table('table', m2)

        with ShouldRaise(
            ValueError("Tables present in more than one Source: table")
            ):
            Config(
                Source(t1),
                Source(t2),
                )

    def test_table_excludes(self):
        m1 = MetaData()
        t1 = Table('t1', m1)
        s1 = Source(t1)
        m2 = MetaData()
        t2 = Table('t2', m2)
        s2 = Source(t2)
        
        c = Config(s1, s2)

        compare({'t2'}, c.excludes[s1])
        compare({'t1'}, c.excludes[s2])
Example #5
class WithTempDir(object):

    def setUp(self):
        self.dir = TempDirectory()

    def tearDown(self):
        self.dir.cleanup()
Example #6
class GitHelper(object):

    repo = 'local/'

    def setUp(self):
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)

    def git(self, command, repo=None):
        repo_path = self.dir.getpath(repo or self.repo)
        try:
            return check_output(['git'] + command.split(), cwd=repo_path, stderr=STDOUT)
        except CalledProcessError as e:
            self.fail(e.output)

    def git_rev_parse(self, label, repo=None):
        return self.git('rev-parse --verify -q --short '+label, repo).strip()

    def check_tags(self, expected, repo=None):
        actual = {}
        for tag in self.git('tag', repo).split():
            actual[tag] = self.git_rev_parse(tag, repo)
        compare(expected, actual=actual)

    def make_repo_with_content(self, repo):
        if not os.path.exists(self.dir.getpath(repo)):
            self.dir.makedir(repo)
        self.git('init', repo)
        self.dir.write(repo + 'a', 'some content')
        self.dir.write(repo + 'b', 'other content')
        self.dir.write(repo + 'c', 'more content')
        self.git('add .', repo)
        self.git('commit -m initial', repo)
Example #7
class TestPathSource(TestCase):

    def setUp(self):
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)

    def test_abc(self):
        self.assertTrue(issubclass(Plugin, Source))

    def test_schema_ok(self):
        p1 = self.dir.write('foo', b'f')
        p2 = self.dir.write('bar', b'b')
        compare(
            dict(type='paths', values=[p1, p2], repo='config'),
            Plugin.schema(
                dict(type='paths', values=[p1, p2], repo='config')
            ))

    def test_schema_wrong_type(self):
        text = "not a valid value for dictionary value @ data['type']"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='bar', values=['/']))

    def test_schema_extra_keys(self):
        with ShouldFailSchemaWith("extra keys not allowed @ data['foo']"):
            Plugin.schema(dict(type='paths', foo='bar'))

    def test_name_supplied(self):
        text = "not a valid value for dictionary value @ data['name']"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='paths', name='foo'))

    def test_no_paths(self):
        text = "length of value must be at least 1 for dictionary value " \
               "@ data['values']"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='paths', values=[]))

    def test_path_not_string(self):
        text = "invalid list value @ data['values'][0]"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='paths', values=[1]))

    def test_path_not_starting_with_slash(self):
        text = "invalid list value @ data['values'][0]"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='paths', values=['foo']))

    def test_path_not_there(self):
        text = "invalid list value @ data['values'][0]"
        with ShouldFailSchemaWith(text):
            Plugin.schema(dict(type='paths', values=[self.dir.getpath('bad')]))

    def test_interface(self):
        plugin = Plugin('source', name=None, repo='config',
                        values=['/foo/bar'])
        compare(plugin.type, 'source')
        compare(plugin.name, None)
        compare(plugin.repo, 'config')
        compare(plugin.source_paths, ['/foo/bar'])
 def test_cleanup(self):
     d = TempDirectory()
     p = d.path
     assert os.path.exists(p) is True
     p = d.write('something', b'stuff')
     d.cleanup()
     assert os.path.exists(p) is False
    def test_activity(self, fake_session_mock, fake_s3_mock, fake_key_mock,
                      mock_sqs_message, mock_sqs_connect):
        directory = TempDirectory()

        for testdata in self.do_activity_passes:

            fake_session_mock.return_value = FakeSession(test_data.PreparePost_session_example(
                testdata["update_date"]))
            mock_sqs_connect.return_value = FakeSQSConn(directory)
            mock_sqs_message.return_value = FakeSQSMessage(directory)
            fake_s3_mock.return_value = FakeS3Connection()
            self.activity_PreparePostEIF.logger = mock.MagicMock()
            self.activity_PreparePostEIF.set_monitor_property = mock.MagicMock()
            self.activity_PreparePostEIF.emit_monitor_event = mock.MagicMock()

            success = self.activity_PreparePostEIF.do_activity(test_data.PreparePostEIF_data)

            fake_sqs_queue = FakeSQSQueue(directory)
            data_written_in_test_queue = fake_sqs_queue.read(test_data.PreparePostEIF_test_dir)

            self.assertEqual(True, success)
            self.assertEqual(json.dumps(testdata["message"]), data_written_in_test_queue)

            output_json = json.loads(directory.read(test_data.PreparePostEIF_test_dir))
            expected = testdata["expected"]
            self.assertDictEqual(output_json, expected)
class Test_incorrect_invocation(TestCase):
    mainMsg = '''You have invoked this script as bbpgsql.
This script is supposed to be invoked through the commands archivepgsql
and archivewal.  Please check with your adminstrator to make sure these
commands were installed correctly.
'''
    unknownMsg = 'Unknown command: unknown\n'

    def setUp(self):
        self.tempdir = TempDirectory()
        self.config_dict = {
        }
        self.config_path = os.path.join(self.tempdir.path, 'config.ini')
        write_config_to_filename(self.config_dict, self.config_path)

    def tearDown(self):
        self.tempdir.cleanup()

    @patch('bbpgsql.bbpgsql_main.exit')
    def test_invocation_using_main_script_fails(self, mock_exit):
        bbpgsql_main(['bbpgsql', '-c', self.config_path])
        mock_exit.assert_called_with(1)

    @patch('bbpgsql.bbpgsql_main.stdout.write')
    @patch('bbpgsql.bbpgsql_main.exit')
    def test_invocation_using_unknown_fails(self, mock_exit,
                                            mock_stdout_write):
        bbpgsql_main(['unknown'])
        mock_stdout_write.assert_called_once_with(self.unknownMsg)
        mock_exit.assert_called_once_with(1)
class Test_archivewal_requires_WAL_file(TestCase):
    def setUp(self):
        self.tempdir = TempDirectory()
        self.config_dict = {
            'General': {
                'pgsql_data_directory': self.tempdir.path,
            },
        }
        self.config_file = os.path.join(self.tempdir.path, 'config_file')
        write_config_to_filename(self.config_dict, self.config_file)
        parser, self.options, self.args = archivewal_parse_args(['-c',
            self.config_file])

    def tearDown(self):
        self.tempdir.cleanup()

    def test_will_raise_exception_with_no_WAL_file(self):

        def will_raise_Exception():
            archivewal_validate_options_and_args(self.options, [])
        self.assertRaises(Exception, will_raise_Exception)

    def test_exception_is_explicit_about_error(self):
        try:
            archivewal_validate_options_and_args(self.options, [])
        except Exception as e:
            print('Exception', e)
            self.assertTrue('path to a WAL file' in str(e))
        else:
            self.fail('expected an exception mentioning the WAL file path')
Example #12
class Test_archivepgsql_backup_invocation(TestCase):
    ARCHIVEPGSQL_PATH = os.path.join('bbpgsql', 'cmdline_scripts')
    CONFIG_FILE = 'config.ini'
    exe_script = 'archivepgsql'

    def setUp(self):
        self.setup_environment()
        self.setup_config()
        self.execution_sequence = 0

    def setup_environment(self):
        self.env = deepcopy(os.environ)
        self.env['PATH'] = ''.join([
            self.env['PATH'],
            ':',
            self.ARCHIVEPGSQL_PATH])
        self.tempdir = TempDirectory()
        self.data_dir = self.tempdir.makedir('pgsql_data')
        self.archive_dir = self.tempdir.makedir('pgsql_archive')

    def setup_config(self):
        self.config_path = os.path.join(self.tempdir.path, self.CONFIG_FILE)
        self.config_dict = {
            'General': {
                'pgsql_data_directory': self.data_dir,
            },
            'Snapshot': {
                'driver': 'memory',
            },
        }
        write_config_to_filename(self.config_dict, self.config_path)
        self.config = get_config_from_filename_and_set_up_logging(
            self.config_path
        )

    def tearDown(self):
        self.tempdir.cleanup()

    @patch('bbpgsql.archive_pgsql.commit_snapshot_to_repository')
    @patch('bbpgsql.archive_pgsql.create_archive')
    @patch('bbpgsql.archive_pgsql.pg_stop_backup')
    @patch('bbpgsql.archive_pgsql.pg_start_backup')
    def test_perform_backup(self, mock_pg_start_backup, mock_pg_stop_backup,
        mock_create_archive, mock_commit_snapshot_to_repository):
        first_WAL = '000000D0'
        second_WAL = '000000D1'
        mock_pg_start_backup.return_value = first_WAL
        mock_pg_stop_backup.return_value = second_WAL
        archiveFile = os.path.join(self.archive_dir, 'pgsql.snapshot.tar')
        tag = bbpgsql.archive_pgsql.generate_tag()
        repo = get_Snapshot_repository(self.config)
        bbpgsql.archive_pgsql.perform_backup(self.data_dir,
            archiveFile, tag, repo)
        mock_pg_start_backup.assert_called_once_with(tag)
        mock_create_archive.assert_called_once_with(self.data_dir, archiveFile)
        self.assertEqual(mock_pg_stop_backup.called, True)
        self.assertEqual(mock_pg_stop_backup.call_count, 1)
        mock_commit_snapshot_to_repository.assert_called_once_with(
            repo, archiveFile, tag, first_WAL, second_WAL)
Example #13
 def test_evaluate_read_same(self):
     dir = TempDirectory()
     dir.write('foo', b'content')
     d = TestContainer('parsed', FileBlock('foo', 'content', 'read'))
     d.evaluate_with(Files('td'), globs={'td': dir})
     compare([C(FileResult,
                passed=True,
                expected=None,
                actual=None)],
             [r.evaluated for r in d])
Example #14
class MailTestCaseMixin(TestCase):

    def _pre_setup(self):
        super(MailTestCaseMixin, self)._pre_setup()
        self.tempdir = TempDirectory()
        self.settings_override = override_settings(
            MEDIA_ROOT=self.tempdir.path,
            EMAIL_BACKEND=u'poleno.mail.backend.EmailBackend',
            )
        self.settings_override.enable()

    def _post_teardown(self):
        self.settings_override.disable()
        self.tempdir.cleanup()
        super(MailTestCaseMixin, self)._post_teardown()


    def _call_with_defaults(self, func, kwargs, defaults):
        omit = kwargs.pop(u'omit', [])
        defaults.update(kwargs)
        for key in omit:
            defaults.pop(key, None)
        return func(**defaults)

    def _create_attachment(self, **kwargs):
        content = kwargs.pop(u'content', u'Default Testing Content')
        return self._call_with_defaults(Attachment.objects.create, kwargs, {
            u'file': ContentFile(content, name=u'overridden-file-name.bin'),
            u'name': u'default_testing_filename.txt',
            u'content_type': u'text/plain',
            })

    def _create_recipient(self, **kwargs):
        return self._call_with_defaults(Recipient.objects.create, kwargs, {
            u'name': u'Default Testing Name',
            u'mail': u'*****@*****.**',
            u'type': Recipient.TYPES.TO,
            u'status': Recipient.STATUSES.INBOUND,
            u'status_details': u'',
            u'remote_id': u'',
            })

    def _create_message(self, **kwargs):
        return self._call_with_defaults(Message.objects.create, kwargs, {
            u'type': Message.TYPES.INBOUND,
            u'processed': utc_now(),
            u'from_name': u'Default Testing From Name',
            u'from_mail': u'*****@*****.**',
            u'received_for': u'*****@*****.**',
            u'subject': u'Default Testing Subject',
            u'text': u'Default Testing Text Content',
            u'html': u'<p>Default Testing HTML Content</p>',
            u'headers': {'X-Default-Testing-Extra-Header': 'Default Testing Value'},
            })
Example #15
    def test_inifile_discovery_should_ignore_invalid_files_without_raising_exception(self):
        root_dir = TempDirectory()
        self.tmpdirs.append(root_dir)

        cfg_file = root_dir.write(('some', 'strange', 'config.cfg'), '&ˆ%$#$%ˆ&*()(*&ˆ'.encode('utf8'))
        root_dir.write(('some', 'config.ini'), '$#%ˆ&*((*&ˆ%'.encode('utf8'))

        discovery = ConfigurationDiscovery(
            os.path.realpath(os.path.dirname(cfg_file)), filetypes=(IniFileConfigurationLoader, ))

        self.assertEqual(discovery.config_files, [])
Example #16
    def test_create_monitor_with_watch_path(self):
        wd = TempDirectory()
        source_path = wd.makedir('source')
        wd.makedir('build')

        with chdir(wd.path):
            m = create_monitor(source_path)
        assert len(m.reporters) == 1
        reporter = m.reporters[0]
        assert reporter.watch_path == source_path
        assert reporter.build_path == '{}-build'.format(os.path.realpath(source_path)) # noqa
Example #17
 def test_evaluate_write(self):
     dir = TempDirectory()
     d = TestContainer('parsed', FileBlock('foo', 'content', 'write'))
     d.evaluate_with(Files('td'), globs={'td': dir})
     compare([C(FileResult,
                passed=True,
                expected=None,
                actual=None)],
             [r.evaluated for r in d])
     dir.compare(['foo'])
     compare(dir.read('foo', 'ascii'), 'content')
Example #18
 def test_evaluate_read_difference(self):
     dir = TempDirectory()
     dir.write('foo', b'actual')
     d = TestContainer('parsed', FileBlock('foo', 'expected', 'read'))
     d.evaluate_with(Files('td'), globs={'td': dir})
     compare([C(FileResult,
                passed=False,
                path='foo',
                expected='expected',
                actual='actual')],
             [r.evaluated for r in d])
 def tearDown(self):
     """
     Delete all the temporary directories and files created during this
     testing session.
     """
     for replicate in range(1, 6):
         file_name = "original_mutated_%d.fasta" % replicate
         if os.path.exists(file_name):
             os.remove(file_name)
     if os.path.exists("original_snpListMutated.txt"):
         os.remove("original_snpListMutated.txt")
     TempDirectory.cleanup_all()
    def test_cleanup_all(self):
        d1 = TempDirectory()
        d2 = TempDirectory()

        assert os.path.exists(d1.path) is True
        p1 = d1.path
        assert os.path.exists(d2.path) is True
        p2 = d2.path

        TempDirectory.cleanup_all()

        assert os.path.exists(p1) is False
        assert os.path.exists(p2) is False
    def test_activity(self, mock_set_monitor_property, mock_emit_monitor_event, fake_s3_mock, fake_key_mock):
        directory = TempDirectory()

        fake_key_mock.return_value = FakeKey(directory, data.bucket_dest_file_name,
                                             data.RewriteEIF_json_input_string)
        fake_s3_mock.return_value = FakeS3Connection()

        success = self.activity_PreparePostEIF.do_activity(data.RewriteEIF_data)
        self.assertEqual(True, success)

        output_json = json.loads(directory.read(data.bucket_dest_file_name))
        expected = data.RewriteEIF_json_output
        self.assertDictEqual(output_json, expected)
Example #22
    def test_use_configuration_from_root_path_when_no_other_was_found(self):
        root_dir = TempDirectory()
        self.tmpdirs.append(root_dir)

        start_path = root_dir.makedir('some/directories/to/start/looking/for/settings')
        test_file = os.path.realpath(os.path.join(root_dir.path, 'settings.ini'))
        with open(test_file, 'a') as file_:
            file_.write('[settings]')
        self.files.append(test_file)  # Required so it can be removed at tearDown

        discovery = ConfigurationDiscovery(start_path, root_path=root_dir.path)
        filenames = [cfg.filename for cfg in discovery.config_files]
        self.assertEqual([test_file], filenames)
Example #23
class Test_OptionParsing_and_Validation(TestCase):
    def setUp(self):
        self.tempdir = TempDirectory()
        self.config_dict = {
        }
        self.config_path = os.path.join(self.tempdir.path, 'config.ini')
        write_config_to_filename(self.config_dict, self.config_path)

    def tearDown(self):
        self.tempdir.cleanup()

    def test_non_destructive_with_sys_argv(self):
        expected_sys_argv = ['', '-c', self.config_path]
        sys.argv = expected_sys_argv[:]
        non_destructive_minimal_parse_and_validate_args()
        self.assertEqual(expected_sys_argv, sys.argv)

    def test_validation_raises_exception_if_config_file_does_not_exist(self):
        def validate():
            parser, options, args = common_parse_args(args=[
                '--config', '/tmp/blah/blah/bbpgsql.ini'])
            common_validate_options_and_args(options, args)
        self.assertRaises(Exception, validate)

    def test_validation_raises_exception_if_config_file_permissions_too_open(
        self):
        with TempDirectory() as d:
            self.parent_dir = d.makedir('parent_dir')
            self.config_path = d.write('parent_dir/config.ini', '')
            self.open_perm = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
            os.chmod(self.config_path, self.open_perm)

            def validate(config_path):
                parser, options, args = common_parse_args(args=[
                    '--config', config_path])
                common_validate_options_and_args(options, args)
            self.assertRaises(Exception, validate, self.config_path)

    def test_options_validate_if_config_file_exists(self):
        parser, options, args = common_parse_args(args=[
            '--config', self.config_path])
        self.assertTrue(common_validate_options_and_args(options, args))

    def test_validation_raises_exception_if_cannot_read_config_file(self):
        def validate():
            parser, options, args = common_parse_args(args=[
                '--config', self.config_path])
            self.no_perm = 0
            os.chmod(self.config_path, self.no_perm)
            common_validate_options_and_args(options, args)
        self.assertRaises(Exception, validate)
class Test_FilesystemCommitStorage(TestCase):
    def setUp(self):
        self.tempdir = TempDirectory()
        self.commit_storage_path = self.tempdir.makedir('commit_storage')
        self.config = config()
        self.config.set('WAL', 'driver', 'filesystem')
        self.config.set('WAL', 'path', self.commit_storage_path)

    def tearDown(self):
        self.tempdir.cleanup()

    def test_will_build_storage_from_config(self):
        self.assertEqual(FilesystemCommitStorage,
            type(get_repository_storage_from_config(self.config, 'WAL')))
Example #25
class copySpecialTestCase(unittest.TestCase):

    def setUp(self):
        self.inputDirectory = []
        self.specialFilesNamesAbsolutePath = []
        self.specialFileNames = []
        self.inputDirectoriesPath = []
        for i in range(2):
            self.inputDirectory.append(TempDirectory())
            self.inputDirectory[i].write('india__{0}{0}{0}__.txt'.format(i), 'some foo thing')
            self.inputDirectory[i].write('god__{0}{0}{0}__.txt'.format(i), 'some foo thing')
            self.inputDirectory[i].write('hi.txt', 'some foo thing')

            self.specialFilesNamesAbsolutePath.append('{0}/india__{1}{1}{1}__.txt'.format(self.inputDirectory[i].path, i))
            self.specialFilesNamesAbsolutePath.append('{0}/god__{1}{1}{1}__.txt'.format(self.inputDirectory[i].path, i))
            self.specialFileNames.append('india__{0}{0}{0}__.txt'.format(i))
            self.specialFileNames.append('god__{0}{0}{0}__.txt'.format(i))
            self.inputDirectoriesPath.append(self.inputDirectory[i].path)

        self.outputDirectory = TempDirectory()

    def test_getSpecialFilesNames(self):
        copySpecialFiles = CopySpecialFiles()
        for path in self.inputDirectoriesPath:
            copySpecialFiles.specialFileNames.extend(copySpecialFiles.getSpecialFilesNames(path))
        self.assertEqual(sorted(self.specialFilesNamesAbsolutePath), sorted(copySpecialFiles.specialFileNames))

    def test_copyToDirectory(self):
        copySpecialFiles = CopySpecialFiles()
        for path in self.inputDirectoriesPath:
            copySpecialFiles.specialFileNames.extend(copySpecialFiles.getSpecialFilesNames(path))
        copySpecialFiles.outputDirectory = self.outputDirectory.path
        copySpecialFiles.copyToDirectory()
        self.assertEqual(sorted(os.listdir(self.outputDirectory.path)), sorted(self.specialFileNames))

    def test_copyToZip(self):
        copySpecialFiles = CopySpecialFiles()
        for path in self.inputDirectoriesPath:
            copySpecialFiles.specialFileNames.extend(copySpecialFiles.getSpecialFilesNames(path))
        copySpecialFiles.zipFile = "dummy.zip"
        copySpecialFiles.copyToZip()
        with zipfile.ZipFile('dummy.zip', 'r') as myzip:
            self.assertEqual(sorted(myzip.namelist()), sorted(self.specialFileNames))
        os.remove("dummy.zip")

    def tearDown(self):
        for inputDirectory in self.inputDirectory:
            inputDirectory.cleanup()
        self.outputDirectory.cleanup()
 def test_dont_cleanup_with_path(self):
     d = mkdtemp()
     fp = os.path.join(d, 'test')
     with open(fp, 'w') as f:
         f.write('foo')
     try:
         td = TempDirectory(path=d)
         self.assertEqual(d, td.path)
         td.cleanup()
         # checks
         self.assertEqual(os.listdir(d), ['test'])
         with open(fp) as f:
             self.assertEqual(f.read(), 'foo')
     finally:
         rmtree(d)
Example #28
 def setUp(self):
     self.dir = TempDirectory()
     self.cleanups.extend([self.dir.cleanup, self.removeAtExit])
     # This one is so that we can check what log handlers
     # get added.
     self.handlers = []
     self.r.replace('checker.logger.handlers', self.handlers)
Example #29
 def setUp(self):
     self.tempdir = TempDirectory()
     self.tempdir.write('a_004', b'some text a4')
     self.tempdir.write('a_005', b'some text a5')
     self.tempdir.write('b_002.txt', b'some text b2')
     self.tempdir.write('b_008.out', b'some text b8')
     self.tempdir.write(('c_010', 'por'), b'some text c5por')
Example #30
 def setup_environment(self):
     self.env = deepcopy(os.environ)
     self.env['PATH'] = ''.join([
         self.env['PATH'],
         ':',
         self.ARCHIVEPGSQL_PATH])
     self.tempdir = TempDirectory()
Example #31
 def setUp(self):
     self.dir = TempDirectory()
     self.energy = [7.5, 8]
 def test_files_only(self):
     with TempDirectory() as d:
         d.write('a/b/c', b'')
         d.compare(['a/b/c'], files_only=True)
 def test_read_no_decode(self):
     with TempDirectory() as d:
         with open(os.path.join(d.path, 'test.file'), 'wb') as f:
             f.write(b'\xc2\xa3')
         compare(d.read('test.file'), b'\xc2\xa3')
Example #34
def tmp_dir():
    with TempDirectory() as tmpdir:
        yield tmpdir
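# Hedged usage sketch: assuming tmp_dir above is registered as a pytest
# fixture (e.g. decorated with @pytest.fixture in the real project; the
# decorator is not shown here), a test receives the TempDirectory directly:
def test_write_and_read(tmp_dir):
    tmp_dir.write('sample.txt', b'payload')          # create a file in the temp dir
    assert tmp_dir.read('sample.txt') == b'payload'  # read it back as bytes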
def tearDown(test):
    TempDirectory.cleanup_all()
Example #36
class BcolzMinuteBarTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        all_market_opens = cls.env.open_and_closes.market_open
        all_market_closes = cls.env.open_and_closes.market_close
        indexer = all_market_opens.index.slice_indexer(
            start=TEST_CALENDAR_START, end=TEST_CALENDAR_STOP)
        cls.market_opens = all_market_opens[indexer]
        cls.market_closes = all_market_closes[indexer]
        cls.test_calendar_start = cls.market_opens.index[0]
        cls.test_calendar_stop = cls.market_opens.index[-1]

    def setUp(self):

        self.dir_ = TempDirectory()
        self.dest = self.dir_.getpath('minute_bars')
        os.makedirs(self.dest)
        self.writer = BcolzMinuteBarWriter(
            TEST_CALENDAR_START,
            self.dest,
            self.market_opens,
            self.market_closes,
            US_EQUITIES_MINUTES_PER_DAY,
        )
        self.reader = BcolzMinuteBarReader(self.dest)

    def tearDown(self):
        self.dir_.cleanup()

    def test_write_one_ohlcv(self):
        minute = self.market_opens[self.test_calendar_start]
        sid = 1
        data = DataFrame(data={
            'open': [10.0],
            'high': [20.0],
            'low': [30.0],
            'close': [40.0],
            'volume': [50.0]
        },
                         index=[minute])
        self.writer.write(sid, data)

        # Check all five OHLCV fields in one loop.
        for field, expected in [('open', 10.0), ('high', 20.0), ('low', 30.0),
                                ('close', 40.0), ('volume', 50.0)]:
            self.assertEqual(expected,
                             self.reader.get_value(sid, minute, field))

    def test_write_two_bars(self):
        minute_0 = self.market_opens[self.test_calendar_start]
        minute_1 = minute_0 + timedelta(minutes=1)
        sid = 1
        data = DataFrame(data={
            'open': [10.0, 11.0],
            'high': [20.0, 21.0],
            'low': [30.0, 31.0],
            'close': [40.0, 41.0],
            'volume': [50.0, 51.0]
        },
                         index=[minute_0, minute_1])
        self.writer.write(sid, data)

        expected = {
            minute_0: (10.0, 20.0, 30.0, 40.0, 50.0),
            minute_1: (11.0, 21.0, 31.0, 41.0, 51.0),
        }
        for minute, values in expected.items():
            for field, value in zip(
                    ('open', 'high', 'low', 'close', 'volume'), values):
                self.assertEqual(value,
                                 self.reader.get_value(sid, minute, field))

    def test_write_on_second_day(self):
        second_day = self.test_calendar_start + 1
        minute = self.market_opens[second_day]
        sid = 1
        data = DataFrame(data={
            'open': [10.0],
            'high': [20.0],
            'low': [30.0],
            'close': [40.0],
            'volume': [50.0]
        },
                         index=[minute])
        self.writer.write(sid, data)

        for field, expected in [('open', 10.0), ('high', 20.0), ('low', 30.0),
                                ('close', 40.0), ('volume', 50.0)]:
            self.assertEqual(expected,
                             self.reader.get_value(sid, minute, field))

    def test_write_empty(self):
        minute = self.market_opens[self.test_calendar_start]
        sid = 1
        data = DataFrame(data={
            'open': [0],
            'high': [0],
            'low': [0],
            'close': [0],
            'volume': [0]
        },
                         index=[minute])
        self.writer.write(sid, data)

        # Zero-filled bars read back as nan prices with zero volume.
        for field, expected in [('open', nan), ('high', nan), ('low', nan),
                                ('close', nan), ('volume', 0)]:
            assert_almost_equal(expected,
                                self.reader.get_value(sid, minute, field))

    def test_write_on_multiple_days(self):

        tds = self.market_opens.index
        days = tds[tds.slice_indexer(start=self.test_calendar_start + 1,
                                     end=self.test_calendar_start + 3)]
        minutes = DatetimeIndex([
            self.market_opens[days[0]] + timedelta(minutes=60),
            self.market_opens[days[1]] + timedelta(minutes=120),
        ])
        sid = 1
        data = DataFrame(data={
            'open': [10.0, 11.0],
            'high': [20.0, 21.0],
            'low': [30.0, 31.0],
            'close': [40.0, 41.0],
            'volume': [50.0, 51.0]
        },
                         index=minutes)
        self.writer.write(sid, data)

        expected = {
            minutes[0]: (10.0, 20.0, 30.0, 40.0, 50.0),
            minutes[1]: (11.0, 21.0, 31.0, 41.0, 51.0),
        }
        for minute, values in expected.items():
            for field, value in zip(
                    ('open', 'high', 'low', 'close', 'volume'), values):
                self.assertEqual(value,
                                 self.reader.get_value(sid, minute, field))

    def test_no_overwrite(self):
        minute = self.market_opens[TEST_CALENDAR_START]
        sid = 1
        data = DataFrame(data={
            'open': [10.0],
            'high': [20.0],
            'low': [30.0],
            'close': [40.0],
            'volume': [50.0]
        },
                         index=[minute])
        self.writer.write(sid, data)

        with self.assertRaises(BcolzMinuteOverlappingData):
            self.writer.write(sid, data)

    def test_write_multiple_sids(self):
        """
        Test writing multiple sids.

        Tests both that the data is written to the correct sid, as well as
        ensuring that the logic for creating the subdirectory path to each sid
        does not cause issues from attempts to recreate existing paths.
        (Calling out this coverage, because an assertion of that logic does not
        show up in the test itself, but is exercised by the act of attempting
        to write two consecutive sids, which would be written to the same
        containing directory, `00/00/000001.bcolz` and `00/00/000002.bcolz`.)

        Before applying a check to make sure the path writing did not
        re-attempt directory creation an OSError like the following would
        occur:

        ```
        OSError: [Errno 17] File exists: '/tmp/tmpR7yzzT/minute_bars/00/00'
        ```
        """
        minute = self.market_opens[TEST_CALENDAR_START]
        sids = [1, 2]
        data = DataFrame(data={
            'open': [15.0],
            'high': [17.0],
            'low': [11.0],
            'close': [15.0],
            'volume': [100.0]
        },
                         index=[minute])
        self.writer.write(sids[0], data)

        data = DataFrame(data={
            'open': [25.0],
            'high': [27.0],
            'low': [21.0],
            'close': [25.0],
            'volume': [200.0]
        },
                         index=[minute])
        self.writer.write(sids[1], data)

        expected = {
            sids[0]: (15.0, 17.0, 11.0, 15.0, 100.0),
            sids[1]: (25.0, 27.0, 21.0, 25.0, 200.0),
        }
        for sid, values in expected.items():
            for field, value in zip(
                    ('open', 'high', 'low', 'close', 'volume'), values):
                self.assertEqual(value,
                                 self.reader.get_value(sid, minute, field))

    def test_pad_data(self):
        """
        Test that padding a sid advances its last recorded date and that
        data can then be written for a following minute.
        """
        sid = 1
        last_date = self.writer.last_date_in_output_for_sid(sid)
        self.assertIs(last_date, NaT)

        self.writer.pad(sid, TEST_CALENDAR_START)

        last_date = self.writer.last_date_in_output_for_sid(sid)
        self.assertEqual(last_date, TEST_CALENDAR_START)

        freq = self.market_opens.index.freq
        minute = self.market_opens[TEST_CALENDAR_START + freq]
        data = DataFrame(data={
            'open': [15.0],
            'high': [17.0],
            'low': [11.0],
            'close': [15.0],
            'volume': [100.0]
        },
                         index=[minute])
        self.writer.write(sid, data)

        for field, expected in [('open', 15.0), ('high', 17.0), ('low', 11.0),
                                ('close', 15.0), ('volume', 100.0)]:
            self.assertEqual(expected,
                             self.reader.get_value(sid, minute, field))

    def test_write_cols(self):
        minute_0 = self.market_opens[self.test_calendar_start]
        minute_1 = minute_0 + timedelta(minutes=1)
        sid = 1
        cols = {
            'open': array([10.0, 11.0]),
            'high': array([20.0, 21.0]),
            'low': array([30.0, 31.0]),
            'close': array([40.0, 41.0]),
            'volume': array([50.0, 51.0])
        }
        dts = array([minute_0, minute_1], dtype='datetime64[s]')
        self.writer.write_cols(sid, dts, cols)

        expected = {
            minute_0: (10.0, 20.0, 30.0, 40.0, 50.0),
            minute_1: (11.0, 21.0, 31.0, 41.0, 51.0),
        }
        for minute, values in expected.items():
            for field, value in zip(
                    ('open', 'high', 'low', 'close', 'volume'), values):
                self.assertEqual(value,
                                 self.reader.get_value(sid, minute, field))

    def test_unadjusted_minutes(self):
        """
        Test unadjusted minutes.
        """
        start_minute = self.market_opens[TEST_CALENDAR_START]
        minutes = [
            start_minute, start_minute + Timedelta('1 min'),
            start_minute + Timedelta('2 min')
        ]
        sids = [1, 2]
        data_1 = DataFrame(data={
            'open': [15.0, nan, 15.1],
            'high': [17.0, nan, 17.1],
            'low': [11.0, nan, 11.1],
            'close': [14.0, nan, 14.1],
            'volume': [1000, 0, 1001]
        },
                           index=minutes)
        self.writer.write(sids[0], data_1)

        data_2 = DataFrame(data={
            'open': [25.0, nan, 25.1],
            'high': [27.0, nan, 27.1],
            'low': [21.0, nan, 21.1],
            'close': [24.0, nan, 24.1],
            'volume': [2000, 0, 2001]
        },
                           index=minutes)
        self.writer.write(sids[1], data_2)

        reader = BcolzMinuteBarReader(self.dest)

        columns = ['open', 'high', 'low', 'close', 'volume']
        arrays = reader.unadjusted_window(columns, minutes[0], minutes[-1],
                                          sids)

        data = {sids[0]: data_1, sids[1]: data_2}

        for i, col in enumerate(columns):
            for j, sid in enumerate(sids):
                assert_almost_equal(data[sid][col], arrays[i][j])
Example #37
 def setUp(self):
     self.tmp_dir = TempDirectory()
     self.tmp_dir.write('foo.txt', 'bar', encoding='utf-8')
Example #38
class TestPrepareTarget(TestCase):
    def setUp(self):
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)
        replace = Replacer()
        replace('workfront.generate.TARGET_ROOT', self.dir.path)
        self.addCleanup(replace.restore)
        self.session = Session('test')

    def test_from_scratch(self):
        path = prepare_target(self.session)

        compare(path, expected=self.dir.getpath('unsupported.py'))
        self.dir.compare(expected=[])

    def test_everything(self):
        self.dir.write('unsupported.py', b'yy')
        path = prepare_target(self.session)

        compare(path, expected=self.dir.getpath('unsupported.py'))
        self.dir.compare(expected=['unsupported.py'])
        compare(self.dir.read('unsupported.py'), b"yy")

    def test_dots_in_version(self):
        path = prepare_target(Session('test', api_version='v4.0'))

        compare(path, expected=self.dir.getpath('v40.py'))
        self.dir.compare(expected=[])
Example #39
class FunctionalTest(MockOpenHelper, TestCase):

    base = 'https://api-cl01.attask-ondemand.com/attask/api/unsupported'

    def setUp(self):
        super(FunctionalTest, self).setUp()
        self.log = LogCapture()
        self.addCleanup(self.log.uninstall)
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)
        self.replace('logging.basicConfig', Mock())
        self.replace('workfront.generate.TARGET_ROOT', self.dir.path)

    def test_functional(self):
        self.replace('sys.argv', ['x'])

        self.server.add(
            url='/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objects=dict(
                    SomeThing=dict(objCode='BAR', name='SomeThing'),
                    OtherThing=dict(objCode='FOO', name='OtherThing'),
                )))))

        self.server.add(url='/foo/metadata',
                        params='method=GET',
                        response=json.dumps(
                            dict(data=dict(
                                objCode='FOO',
                                name='OtherThing',
                                fields={
                                    "ID": {},
                                    "anotherField": {}
                                },
                                references={},
                                collections={},
                                actions={},
                            ))))

        self.server.add(
            url='/bar/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objCode='BAR',
                               name='SomeThing',
                               fields={
                                   "ID": {},
                                   "theField": {}
                               },
                               references={"accessRules": {}},
                               collections={"assignedTo": {}},
                               actions={
                                   "doSomething": {
                                       "arguments": [{
                                           "name": "anOption",
                                           "type": "Task"
                                       }, {
                                           "name": "options",
                                           "type": "string[]"
                                       }],
                                       "resultType":
                                       "string",
                                       "label":
                                       "doSomething"
                                   }
                               }))))

        with OutputCapture() as output:
            output.disable()
            main()

        output.compare("")

        self.dir.compare(expected=['unsupported.py'])

        compare(self.dir.read('unsupported.py').decode('ascii'),
                expected=u'''\
# generated from https://api-cl01.attask-ondemand.com/attask/api/unsupported/metadata
from ..meta import APIVersion, Object, Field, Reference, Collection

api = APIVersion('unsupported')


class OtherThing(Object):
    code = 'FOO'
    another_field = Field('anotherField')

api.register(OtherThing)


class SomeThing(Object):
    code = 'BAR'
    the_field = Field('theField')
    access_rules = Reference('accessRules')
    assigned_to = Collection('assignedTo')

    def do_something(self, an_option=None, options=None):
        """
        The ``doSomething`` action.

        :param an_option: anOption (type: ``Task``)
        :param options: options (type: ``string[]``)
        :return: ``string``
        """
        params = {}
        if an_option is not None: params['anOption'] = an_option
        if options is not None: params['options'] = options
        data = self.session.put(self.api_url()+'/doSomething', params)
        return data['result']

api.register(SomeThing)
''',
                trailing_whitespace=False)
Example #40
class TestDecoratedObjectTypes(MockOpenHelper, TestCase):
    def setUp(self):
        super(TestDecoratedObjectTypes, self).setUp()
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)

    def test_normal(self):
        base_url = 'https://test.attask-ondemand.com/attask/api/v4.0'
        session = Session('test', api_version='v4.0')
        self.server.add(
            url=base_url + '/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objects=dict(
                    SomeThing=dict(objCode='SMTHING', name='SomeThing'))))))
        expected = dict(objCode='SMTHING', name='SomeThing', stuff='a value')
        self.server.add(url=base_url + '/smthing/metadata',
                        params='method=GET',
                        response=json.dumps(dict(data=expected)))
        compare(decorated_object_types(session, None),
                expected=[('SomeThing', 'SMTHING', expected)])

    def test_cache_write(self):
        base_url = 'https://test.attask-ondemand.com/attask/api/v4.0'
        session = Session('test', api_version='v4.0')
        self.server.add(
            url=base_url + '/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objects=dict(
                    SomeThing=dict(objCode='SMTHING', name='SomeThing'))))))
        expected = dict(objCode='SMTHING', name='SomeThing', stuff='a value')
        self.server.add(url=base_url + '/smthing/metadata',
                        params='method=GET',
                        response=json.dumps(dict(data=expected)))
        compare(decorated_object_types(session, self.dir.path),
                expected=[('SomeThing', 'SMTHING', expected)])
        self.dir.compare(
            expected=['v4.0_metadata.json', 'v4.0_smthing_metadata.json'])
        compare(json.loads(
            self.dir.read('v4.0_metadata.json').decode('ascii')),
                expected=dict(objects=dict(
                    SomeThing=dict(objCode='SMTHING', name='SomeThing'))))
        compare(json.loads(
            self.dir.read('v4.0_smthing_metadata.json').decode('ascii')),
                expected=expected)

    def test_cache_read(self):
        expected = dict(objCode='SMTHING', name='SomeThing', stuff='a value')

        self.dir.write(
            'v4.0_metadata.json',
            json.dumps(
                dict(objects=dict(
                    SomeThing=dict(objCode='SMTHING', name='SomeThing')))),
            encoding='ascii')

        self.dir.write('v4.0_smthing_metadata.json',
                       json.dumps(expected),
                       encoding='ascii')

        session = Session('test', api_version='v4.0')
        compare(decorated_object_types(session, self.dir.path),
                expected=[('SomeThing', 'SMTHING', expected)])

    def test_unsupported(self):
        base_url = 'https://test.attask-ondemand.com/attask/api/unsupported'
        session = Session('test')
        self.server.add(
            url=base_url + '/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objects=dict(
                    SomeThing=dict(objCode='SMTHING', name='SomeThing'))))))
        expected = dict(objCode='SMTHING', name='SomeThing', stuff='a value')
        self.server.add(url=base_url + '/smthing/metadata',
                        params='method=GET',
                        response=json.dumps(dict(data=expected)))
        compare(decorated_object_types(session, None),
                expected=[('SomeThing', 'SMTHING', expected)])

    def test_name_override(self):
        base_url = 'https://test.attask-ondemand.com/attask/api/unsupported'
        session = Session('test')
        self.server.add(
            url=base_url + '/metadata',
            params='method=GET',
            response=json.dumps(
                dict(data=dict(objects=dict(
                    SomeThing=dict(objCode='OPTASK', name='SomeThing'))))))
        expected = dict(objCode='SMTHING', name='SomeThing', stuff='a value')
        self.server.add(url=base_url + '/optask/metadata',
                        params='method=GET',
                        response=json.dumps(dict(data=expected)))
        compare(decorated_object_types(session, None),
                expected=[('Issue', 'OPTASK', expected)])
Example #41
    def test__run_step_pass_image_tar_file_exists(self, buildah_mock):
        with TempDirectory() as temp_dir:
            parent_work_dir_path = os.path.join(temp_dir.path, 'working')
            temp_dir.write('Dockerfile', b'''testing''')

            artifact_config = {
                'container-image-version': {
                    'description': '',
                    'value': '1.0-123abc'
                },
            }
            workflow_result = self.setup_previous_result(
                parent_work_dir_path, artifact_config)

            step_config = {
                'containers-config-auth-file': 'buildah-auth.json',
                'imagespecfile': 'Dockerfile',
                'context': temp_dir.path,
                'tls-verify': True,
                'format': 'oci',
                'service-name': 'service-name',
                'application-name': 'app-name'
            }
            step_implementer = self.create_step_implementer(
                step_config=step_config,
                step_name='create-container-image',
                implementer='Buildah',
                workflow_result=workflow_result,
                parent_work_dir_path=parent_work_dir_path)

            step_implementer.write_working_file(
                'image-app-name-service-name-1.0-123abc.tar')

            result = step_implementer._run_step()

            expected_step_result = StepResult(
                step_name='create-container-image',
                sub_step_name='Buildah',
                sub_step_implementer_name='Buildah')
            expected_step_result.add_artifact(
                name='container-image-version',
                value='localhost/app-name/service-name:1.0-123abc')
            expected_step_result.add_artifact(
                name='image-tar-file',
                value=f'{step_implementer.work_dir_path}/image-app-name-service-name-1.0-123abc.tar'
            )

            buildah_mock.bud.assert_called_once_with(
                '--storage-driver=vfs',
                '--format=oci',
                '--tls-verify=true',
                '--layers',
                '-f',
                'Dockerfile',
                '-t',
                'localhost/app-name/service-name:1.0-123abc',
                '--authfile',
                'buildah-auth.json',
                temp_dir.path,
                _out=sys.stdout,
                _err=sys.stderr,
                _tee='err')

            buildah_mock.push.assert_called_once_with(
                '--storage-driver=vfs',
                'localhost/app-name/service-name:1.0-123abc',
                f'docker-archive:{step_implementer.work_dir_path}/image-app-name-service-name-1.0-123abc.tar',
                _out=sys.stdout,
                _err=sys.stderr,
                _tee='err')
            self.assertEqual(result, expected_step_result)
 def setUp(self):
     self.temp_dir = TempDirectory()
Example #43
    def setUp(self):
        super(TestDecoratedObjectTypes, self).setUp()
        self.dir = TempDirectory()
        self.addCleanup(self.dir.cleanup)
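self.addCleanup registers the cleanup callable with unittest itself, so it runs even when a later setUp line raises, which tearDown does not; a minimal self-contained sketch (class name is illustrative):

import os
import unittest

from testfixtures import TempDirectory


class ExampleAddCleanup(unittest.TestCase):
    def setUp(self):
        self.dir = TempDirectory()
        # registered cleanups run in LIFO order after each test,
        # and also when setUp fails part-way through
        self.addCleanup(self.dir.cleanup)

    def test_dir_exists(self):
        self.assertTrue(os.path.isdir(self.dir.path))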
Example #44
    def setUp(self):
        """
        Set up temporary test directory and mock S3 bucket connection
        """
        # Test metadata parameters
        self.channel_idx = 1
        self.slice_idx = 2
        self.time_idx = 3
        self.channel_name = "TESTCHANNEL"
        # Mock S3 dir
        self.storage_dir = "raw_frames/ISP-2005-06-09-20-00-00-0001"
        # Create temporary directory and write temp image
        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        # Temporary frame
        self.im = np.ones((10, 15), dtype=np.uint16)
        self.im[2:5, 3:12] = 50000
        # Metadata
        mmmetadata = self._get_mmmeta()
        ijmeta = self._get_ijmeta()
        extra_tags = [('MicroManagerMetadata', 's', 0, mmmetadata, True)]
        # Save test ome tif file
        self.file_path1 = os.path.join(self.temp_path, "test_Pos1.ome.tif")
        tifffile.imsave(
            self.file_path1,
            self.im,
            ijmetadata=ijmeta,
            extratags=extra_tags,
        )
        mmmetadata = self._get_mmmeta(pos_idx=3)
        extra_tags = [('MicroManagerMetadata', 's', 0, mmmetadata, True)]
        # Save test ome tif file
        self.file_path3 = os.path.join(self.temp_path, "test_Pos3.ome.tif")
        tifffile.imsave(
            self.file_path3,
            self.im,
            ijmetadata=ijmeta,
            extratags=extra_tags,
        )
        # Setup mock S3 bucket
        self.mock = mock_s3()
        self.mock.start()
        self.conn = boto3.resource('s3', region_name='us-east-1')
        self.bucket_name = 'czbiohub-imaging'
        self.conn.create_bucket(Bucket=self.bucket_name)
        # Instantiate file parser class
        self.storage_class = aux_utils.get_storage_class('s3')
        self.frames_inst = ometif_splitter.OmeTiffSplitter(
            data_path=self.temp_path,
            storage_dir="raw_frames/ISP-2005-06-09-20-00-00-0001",
            storage_class=self.storage_class,
        )
        # Get path to json schema file
        dir_name = os.path.dirname(__file__)
        self.schema_file_path = os.path.realpath(
            os.path.join(dir_name, '..', '..', 'metadata_schema.json'))
        # Upload data
        self.frames_inst.get_frames_and_metadata(
            schema_filename=self.schema_file_path,
            positions='[1, 3]',
        )
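A matching tearDown for this fixture would typically remove the temp directory and stop the moto S3 mock; a sketch, assuming nothing beyond what setUp creates:

    def tearDown(self):
        TempDirectory.cleanup_all()
        self.mock.stop()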
Example #45
    def transaction_sim(self, **params):
        """This is a utility method that asserts expected
        results for conversion of orders to transactions given a
        trade history
        """
        trade_count = params['trade_count']
        trade_interval = params['trade_interval']
        order_count = params['order_count']
        order_amount = params['order_amount']
        order_interval = params['order_interval']
        expected_txn_count = params['expected_txn_count']
        expected_txn_volume = params['expected_txn_volume']

        # optional parameters
        # ---------------------
        # if present, alternate between long and short sales
        alternate = params.get('alternate')

        # if present, expect transaction amounts to match orders exactly.
        complete_fill = params.get('complete_fill')

        asset1 = self.asset_finder.retrieve_asset(1)
        with TempDirectory() as tempdir:

            if trade_interval < timedelta(days=1):
                sim_params = factory.create_simulation_parameters(
                    start=self.start, end=self.end, data_frequency="minute")

                minutes = self.trading_calendar.minutes_window(
                    sim_params.first_open,
                    int((trade_interval.total_seconds() / 60) * trade_count) +
                    100)

                price_data = np.array([10.1] * len(minutes))
                assets = {
                    asset1.sid:
                    pd.DataFrame({
                        "open": price_data,
                        "high": price_data,
                        "low": price_data,
                        "close": price_data,
                        "volume": np.array([100] * len(minutes)),
                        "dt": minutes
                    }).set_index("dt")
                }

                write_bcolz_minute_data(
                    self.trading_calendar,
                    self.trading_calendar.sessions_in_range(
                        self.trading_calendar.minute_to_session_label(
                            minutes[0]),
                        self.trading_calendar.minute_to_session_label(
                            minutes[-1])),
                    tempdir.path,
                    iteritems(assets),
                )

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    self.asset_finder,
                    self.trading_calendar,
                    first_trading_day=equity_minute_reader.first_trading_day,
                    equity_minute_reader=equity_minute_reader,
                )
            else:
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily")

                days = sim_params.sessions

                assets = {
                    1:
                    pd.DataFrame(
                        {
                            "open": [10.1] * len(days),
                            "high": [10.1] * len(days),
                            "low": [10.1] * len(days),
                            "close": [10.1] * len(days),
                            "volume": [100] * len(days),
                            "day": [day.value for day in days]
                        },
                        index=days)
                }

                path = os.path.join(tempdir.path, "testdata.bcolz")
                BcolzDailyBarWriter(path, self.trading_calendar, days[0],
                                    days[-1]).write(assets.items())

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    self.asset_finder,
                    self.trading_calendar,
                    first_trading_day=equity_daily_reader.first_trading_day,
                    equity_daily_reader=equity_daily_reader,
                )

            if "default_slippage" not in params or \
               not params["default_slippage"]:
                slippage_func = FixedBasisPointsSlippage()
            else:
                slippage_func = None

            blotter = SimulationBlotter(slippage_func)

            start_date = sim_params.first_open

            if alternate:
                alternator = -1
            else:
                alternator = 1

            tracker = MetricsTracker(
                trading_calendar=self.trading_calendar,
                first_session=sim_params.start_session,
                last_session=sim_params.end_session,
                capital_base=sim_params.capital_base,
                emission_rate=sim_params.emission_rate,
                data_frequency=sim_params.data_frequency,
                asset_finder=self.asset_finder,
                metrics=load_metrics_set('none'),
            )

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            if sim_params.data_frequency == "minute":
                ticks = minutes
            else:
                ticks = days

            transactions = []

            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick
                if tick >= order_date and len(order_list) < order_count:
                    # place an order
                    direction = alternator**len(order_list)
                    order_id = blotter.order(
                        asset1,
                        order_amount * direction,
                        MarketOrder(),
                    )
                    order_list.append(blotter.orders[order_id])
                    order_date = order_date + order_interval
                    # move after-market orders to just after the next
                    # market open
                    if order_date.hour >= 21:
                        order_date = order_date + timedelta(days=1)
                        order_date = order_date.replace(hour=14, minute=30)
                else:
                    bar_data = BarData(
                        data_portal=data_portal,
                        simulation_dt_func=lambda: tick,
                        data_frequency=sim_params.data_frequency,
                        trading_calendar=self.trading_calendar,
                        restrictions=NoRestrictions(),
                    )
                    txns, _, closed_orders = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

                    blotter.prune_orders(closed_orders)

            for i in range(order_count):
                order = order_list[i]
                self.assertEqual(order.asset, asset1)
                self.assertEqual(order.amount, order_amount * alternator**i)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            total_volume = 0
            for i in range(len(transactions)):
                txn = transactions[i]
                total_volume += txn.amount
                if complete_fill:
                    order = order_list[i]
                    self.assertEqual(order.amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            if total_volume == 0:
                self.assertRaises(KeyError, lambda: tracker.positions[asset1])
            else:
                cumulative_pos = tracker.positions[asset1]
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain the asset.
            oo = blotter.open_orders
            self.assertNotIn(asset1, oo,
                             "Entry is removed when no open orders")
Example #46
class TestLocalFile(unittest.TestCase):
    """
    python -m unittest -v test.util.test_file_strategies.TestLocalFile
    """
    def setUp(self):
        self.tmp_dir = TempDirectory()
        self.tmp_dir.write('foo.txt', 'bar', encoding='utf-8')

    def tearDown(self):
        self.tmp_dir.cleanup()

    def test_attributes(self):
        """should build attribute data to describe file"""
        path = os.path.join(self.tmp_dir.path, 'foo.txt')
        subject = LocalFile(path)

        self.assertEqual(subject.protocol, 'file://')
        self.assertEqual(subject.basename, 'foo.txt')
        self.assertEqual(subject.dir_path, self.tmp_dir.path)

    def test_get_contents(self):
        """should return contents as bytes from local file"""
        path = os.path.join(self.tmp_dir.path, 'foo.txt')
        subject = LocalFile(path)

        actual = subject.get_contents().getvalue()
        expected = b'bar'

        self.assertEqual(actual, expected)

    def test_put_byte_contents(self):
        """should store contents passed in as bytes to local file"""
        path = os.path.join(self.tmp_dir.path, 'my_dir', 'foo.txt')
        subject = LocalFile(path)

        subject.put_contents(b'baz')
        actual = self.tmp_dir.read('my_dir/foo.txt', encoding='utf-8')
        expected = 'baz'

        self.assertEqual(actual, expected)

    def test_put_str_contents(self):
        """should store contents passed in as str to local file"""
        path = os.path.join(self.tmp_dir.path, 'my_dir', 'foo.txt')
        subject = LocalFile(path)

        subject.put_contents('baz')
        actual = self.tmp_dir.read('my_dir/foo.txt', encoding='utf-8')
        expected = 'baz'

        self.assertEqual(actual, expected)

    def test_exists(self):
        """should determine if a local file exists or not"""
        present_path = os.path.join(self.tmp_dir.path, 'foo.txt')
        absent_path = os.path.join(self.tmp_dir.path, 'missing.txt')
        present_subject = LocalFile(present_path)
        absent_subject = LocalFile(absent_path)

        self.assertTrue(present_subject.exists())
        self.assertFalse(absent_subject.exists())
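Taken together, these assertions pin down the LocalFile interface; a minimal sketch that would satisfy them (the real implementation lives in the project's file-strategy module, so this is illustrative only):

import io
import os


class LocalFile:
    protocol = 'file://'

    def __init__(self, path):
        self.path = path
        self.basename = os.path.basename(path)
        self.dir_path = os.path.dirname(path)

    def get_contents(self):
        # return the file body wrapped in a BytesIO, per test_get_contents
        with open(self.path, 'rb') as f:
            return io.BytesIO(f.read())

    def put_contents(self, contents):
        # accept str or bytes and create intermediate directories,
        # per the two put_contents tests
        if isinstance(contents, str):
            contents = contents.encode('utf-8')
        os.makedirs(self.dir_path, exist_ok=True)
        with open(self.path, 'wb') as f:
            f.write(contents)

    def exists(self):
        return os.path.exists(self.path)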
Example #47
class BcolzDailyBarTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        all_trading_days = TradingEnvironment().trading_days
        cls.trading_days = all_trading_days[
            all_trading_days.get_loc(TEST_CALENDAR_START):
            all_trading_days.get_loc(TEST_CALENDAR_STOP) + 1]

    def setUp(self):

        self.asset_info = EQUITY_INFO
        self.writer = SyntheticDailyBarWriter(
            self.asset_info,
            self.trading_days,
        )

        self.dir_ = TempDirectory()
        self.dir_.create()
        self.dest = self.dir_.getpath('daily_equity_pricing.bcolz')

    def tearDown(self):
        self.dir_.cleanup()

    @property
    def assets(self):
        return self.asset_info.index

    def trading_days_between(self, start, end):
        return self.trading_days[self.trading_days.slice_indexer(start, end)]

    def asset_start(self, asset_id):
        return self.writer.asset_start(asset_id)

    def asset_end(self, asset_id):
        return self.writer.asset_end(asset_id)

    def dates_for_asset(self, asset_id):
        start, end = self.asset_start(asset_id), self.asset_end(asset_id)
        return self.trading_days_between(start, end)

    def test_write_ohlcv_content(self):
        result = self.writer.write(self.dest, self.trading_days, self.assets)
        for column in SyntheticDailyBarWriter.OHLCV:
            idx = 0
            data = result[column][:]
            multiplier = 1 if column == 'volume' else 1000
            for asset_id in self.assets:
                for date in self.dates_for_asset(asset_id):
                    self.assertEqual(
                        SyntheticDailyBarWriter.expected_value(
                            asset_id, date, column) * multiplier,
                        data[idx],
                    )
                    idx += 1
            self.assertEqual(idx, len(data))

    def test_write_day_and_id(self):
        result = self.writer.write(self.dest, self.trading_days, self.assets)
        idx = 0
        ids = result['id']
        days = result['day']
        for asset_id in self.assets:
            for date in self.dates_for_asset(asset_id):
                self.assertEqual(ids[idx], asset_id)
                self.assertEqual(date, seconds_to_timestamp(days[idx]))
                idx += 1

    def test_write_attrs(self):
        result = self.writer.write(self.dest, self.trading_days, self.assets)
        expected_first_row = {
            '1': 0,
            '2': 5,  # Asset 1 has 5 trading days.
            '3': 12,  # Asset 2 has 7 trading days.
            '4': 33,  # Asset 3 has 21 trading days.
            '5': 44,  # Asset 4 has 11 trading days.
            '6': 49,  # Asset 5 has 5 trading days.
        }
        expected_last_row = {
            '1': 4,
            '2': 11,
            '3': 32,
            '4': 43,
            '5': 48,
            '6': 57,  # Asset 6 has 9 trading days.
        }
        expected_calendar_offset = {
            '1': 0,  # Starts on 6-01, 1st trading day of month.
            '2': 15,  # Starts on 6-22, 16th trading day of month.
            '3': 1,  # Starts on 6-02, 2nd trading day of month.
            '4': 0,  # Starts on 6-01, 1st trading day of month.
            '5': 9,  # Starts on 6-12, 10th trading day of month.
            '6': 10,  # Starts on 6-15, 11th trading day of month.
        }
        self.assertEqual(result.attrs['first_row'], expected_first_row)
        self.assertEqual(result.attrs['last_row'], expected_last_row)
        self.assertEqual(
            result.attrs['calendar_offset'],
            expected_calendar_offset,
        )
        assert_index_equal(
            self.trading_days,
            DatetimeIndex(result.attrs['calendar'], tz='UTC'),
        )

    def _check_read_results(self, columns, assets, start_date, end_date):
        table = self.writer.write(self.dest, self.trading_days, self.assets)
        reader = BcolzDailyBarReader(table)
        results = reader.load_raw_arrays(columns, start_date, end_date, assets)
        dates = self.trading_days_between(start_date, end_date)
        for column, result in zip(columns, results):
            assert_array_equal(
                result,
                self.writer.expected_values_2d(
                    dates,
                    assets,
                    column.name,
                ))

    @parameterized.expand([
        ([USEquityPricing.open], ),
        ([USEquityPricing.close, USEquityPricing.volume], ),
        ([USEquityPricing.volume, USEquityPricing.high,
          USEquityPricing.low], ),
        (USEquityPricing.columns, ),
    ])
    def test_read(self, columns):
        self._check_read_results(
            columns,
            self.assets,
            TEST_QUERY_START,
            TEST_QUERY_STOP,
        )

    def test_start_on_asset_start(self):
        """
        Test loading with queries that start on the first day of each asset's
        lifetime.
        """
        columns = [USEquityPricing.high, USEquityPricing.volume]
        for asset in self.assets:
            self._check_read_results(
                columns,
                self.assets,
                start_date=self.asset_start(asset),
                end_date=self.trading_days[-1],
            )

    def test_start_on_asset_end(self):
        """
        Test loading with queries that start on the last day of each asset's
        lifetime.
        """
        columns = [USEquityPricing.close, USEquityPricing.volume]
        for asset in self.assets:
            self._check_read_results(
                columns,
                self.assets,
                start_date=self.asset_end(asset),
                end_date=self.trading_days[-1],
            )

    def test_end_on_asset_start(self):
        """
        Test loading with queries that end on the first day of each asset's
        lifetime.
        """
        columns = [USEquityPricing.close, USEquityPricing.volume]
        for asset in self.assets:
            self._check_read_results(
                columns,
                self.assets,
                start_date=self.trading_days[0],
                end_date=self.asset_start(asset),
            )

    def test_end_on_asset_end(self):
        """
        Test loading with queries that end on the last day of each asset's
        lifetime.
        """
        columns = [USEquityPricing.close, USEquityPricing.volume]
        for asset in self.assets:
            self._check_read_results(
                columns,
                self.assets,
                start_date=self.trading_days[0],
                end_date=self.asset_end(asset),
            )

    def test_unadjusted_spot_price(self):
        table = self.writer.write(self.dest, self.trading_days, self.assets)
        reader = BcolzDailyBarReader(table)
        # At beginning
        price = reader.spot_price(1, Timestamp('2015-06-01', tz='UTC'),
                                  'close')
        # The synthetic writer assigns a deterministic price per sid and date.
        self.assertEqual(135630.0, price)

        # Middle
        price = reader.spot_price(1, Timestamp('2015-06-02', tz='UTC'),
                                  'close')
        self.assertEqual(135631.0, price)
        # End
        price = reader.spot_price(1, Timestamp('2015-06-05', tz='UTC'),
                                  'close')
        self.assertEqual(135634.0, price)

        # Another sid at beginning.
        price = reader.spot_price(2, Timestamp('2015-06-22', tz='UTC'),
                                  'close')
        self.assertEqual(235651.0, price)

        # Ensure that volume does not have float adjustment applied.
        volume = reader.spot_price(1, Timestamp('2015-06-02', tz='UTC'),
                                   'volume')
        self.assertEqual(145631, volume)

    def test_unadjusted_spot_price_no_data(self):
        table = self.writer.write(self.dest, self.trading_days, self.assets)
        reader = BcolzDailyBarReader(table)
        # before
        with self.assertRaises(NoDataOnDate):
            reader.spot_price(2, Timestamp('2015-06-08', tz='UTC'), 'close')

        # after
        with self.assertRaises(NoDataOnDate):
            reader.spot_price(4, Timestamp('2015-06-16', tz='UTC'), 'close')

    def test_unadjusted_spot_price_empty_value(self):
        table = self.writer.write(self.dest, self.trading_days, self.assets)
        reader = BcolzDailyBarReader(table)

        # A sid, day and corresponding index into which to overwrite a zero.
        zero_sid = 1
        zero_day = Timestamp('2015-06-02', tz='UTC')
        zero_ix = reader.sid_day_index(zero_sid, zero_day)

        # Write a zero into the synthetic pricing data at the day and sid,
        # so that a read should now return -1.
        # This is a little hacky, in lieu of changing the synthetic data set.
        reader._spot_col('close')[zero_ix] = 0

        close = reader.spot_price(zero_sid, zero_day, 'close')
        self.assertEqual(-1, close)
Example #48
def setUp(test):
    test.globs['temp_dir'] = TempDirectory()
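The matching module-level teardown hook for a doctest suite wired this way would dispose of the directory again; a sketch consistent with the setUp above:

def tearDown(test):
    test.globs['temp_dir'].cleanup()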
Example #49
def test_generateExInv(dirpath, outfilePrefix):

    data = {
        1: {
            'invoice_id': 1,
            'customer_id': 1,
            'invoice_date': datetime.datetime(2020, 4, 18, 0, 0),
            'first_name': 'John',
            'last_name': 'Doe',
            'phone': 4031234567,
            'address': '123 Fake Street',
            'city': 'Calgary',
            'province': 'AB',
            'postal_code': 'T1X1N1',
            'invoice_line_Item_id': 1,
            'product_id': 1,
            'item_ref': 'Item 1',
            'quantity': 3,
            'name': 'Pen',
            'description': 'Ball Pointed, Black Ink',
            'unit_price': 3
        },
        2: {
            'invoice_id': 1,
            'customer_id': 1,
            'invoice_date': datetime.datetime(2020, 4, 18, 0, 0),
            'first_name': 'John',
            'last_name': 'Doe',
            'phone': 4031234567,
            'address': '123 Fake Street',
            'city': 'Calgary',
            'province': 'AB',
            'postal_code': 'T1X1N1',
            'invoice_line_Item_id': 2,
            'product_id': 2,
            'item_ref': 'Item 2',
            'quantity': 1,
            'name': 'Pencil',
            'description': 'Mechanical, 0.3mm',
            'unit_price': 5
        },
        3: {
            'invoice_id': 1,
            'customer_id': 1,
            'invoice_date': datetime.datetime(2020, 4, 18, 0, 0),
            'first_name': 'John',
            'last_name': 'Doe',
            'phone': 4031234567,
            'address': '123 Fake Street',
            'city': 'Calgary',
            'province': 'AB',
            'postal_code': 'T1X1N1',
            'invoice_line_Item_id': 3,
            'product_id': 3,
            'item_ref': 'Item 3',
            'quantity': 1,
            'name': 'Eraser',
            'description': 'White',
            'unit_price': 2
        }
    }

    outfile = f"{outfilePrefix} 1.xlsx"

    with TempDirectory() as d:
        create_status = generateExInv(d.path, outfilePrefix, data)
        filename = os.path.join(d.path, outfile)
        existEX = os.path.exists(filename)
        wb = load_workbook(filename)
        ws = wb.active
        d.listdir()

    assert existEX
    assert create_status
    assert ws['A1'].value == "Invoice Date:"
    assert ws['B1'].value == datetime.datetime(2020, 4, 18, 0, 0)
    assert ws['A2'].value == "Invoice Number:"
    assert ws['B2'].value == 1
    assert ws['E1'].value == "Billed to:"
    assert ws['F1'].value == "John Doe"

    assert ws['A9'].value == "Item 1"
    assert ws['B9'].value == "Pen"
    assert ws['C9'].value == "Ball Pointed, Black Ink"
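test_generateExInv receives dirpath and outfilePrefix as plain arguments, which points at pytest fixtures; a hypothetical conftest.py that would satisfy them (fixture bodies and values are assumptions for illustration):

import pytest


@pytest.fixture
def dirpath(tmp_path):
    # hypothetical: the test builds its own TempDirectory, so this
    # fixture only needs to supply a valid path
    return str(tmp_path)


@pytest.fixture
def outfilePrefix():
    return 'Invoice'  # hypothetical prefix; the test appends ' 1.xlsx'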
Example #50
    def test_ensure_path(self):
        with TempDirectory() as d:
            p = utils.ensure_path(d.path, 'foo', 'bar')
            expected = os.path.join(d.path, 'foo', 'bar')
            assert os.path.exists(expected)
            self.assertEqual(p, expected)
Example #51
    def test_write_unicode(self):
        with TempDirectory() as d:
            d.write('test.file', some_text, 'utf8')
            with open(os.path.join(d.path, 'test.file'), 'rb') as f:
                compare(f.read(), b'\xc2\xa3')
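The mirror-image read path decodes with the encoding handed to TempDirectory.read; a sketch assuming the same some_text constant used above:

    def test_read_unicode(self):
        with TempDirectory() as d:
            d.write('test.file', some_text, 'utf8')
            compare(d.read('test.file', encoding='utf8'), some_text)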
Example #52
    def tearDown(self):
        """Tear down temporary folder and file structure"""

        TempDirectory.cleanup_all()
        nose.tools.assert_equal(os.path.isdir(self.temp_path), False)
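TempDirectory.cleanup_all() is a classmethod that disposes of every instance created so far, which is what lets the assertion above verify that self.temp_path no longer exists; a minimal demonstration:

import os

from testfixtures import TempDirectory

d1 = TempDirectory()
d2 = TempDirectory()
TempDirectory.cleanup_all()
assert not os.path.isdir(d1.path)
assert not os.path.isdir(d2.path)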
Example #53
    def test_write_bytes(self):
        with TempDirectory() as d:
            d.write('test.file', b'\xc2\xa3')
            with open(os.path.join(d.path, 'test.file'), 'rb') as f:
                compare(f.read(), b'\xc2\xa3')
Example #54
    def setUp(self):
        """Set up a dir for tiling with flatfield"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        frames_meta = aux_utils.make_dataframe()
        # Write images
        self.im = 127 * np.ones((15, 11), dtype=np.uint8)
        self.im2 = 234 * np.ones((15, 11), dtype=np.uint8)
        self.channel_idx = 1
        self.time_idx = 5
        self.pos_idx1 = 7
        self.pos_idx2 = 8
        self.int2str_len = 3

        # Write test images with 4 z and 2 pos idx
        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx2,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im2,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        # Write metadata
        frames_meta.to_csv(
            os.path.join(self.temp_path, self.meta_name),
            sep=',',
        )
        # Add flatfield
        self.flat_field_dir = os.path.join(self.temp_path, 'ff_dir')
        self.tempdir.makedir('ff_dir')
        self.ff_im = 4. * np.ones((15, 11))
        self.ff_name = os.path.join(
            self.flat_field_dir,
            'flat-field_channel-1.npy',
        )
        np.save(self.ff_name, self.ff_im, allow_pickle=True, fix_imports=True)
        # Instantiate tiler class
        self.output_dir = os.path.join(self.temp_path, 'tile_dir')
        self.tile_inst = tile_images.ImageTilerUniform(
            input_dir=self.temp_path,
            output_dir=self.output_dir,
            tile_size=[5, 5],
            step_size=[4, 4],
            depths=3,
            channel_ids=[1],
            normalize_channels=[True],
            flat_field_dir=self.flat_field_dir,
        )
        exp_fnames = [
            'im_c001_z015_t005_p007.png', 'im_c001_z016_t005_p007.png',
            'im_c001_z017_t005_p007.png'
        ]
        self.exp_fnames = [
            os.path.join(self.temp_path, fname) for fname in exp_fnames
        ]
        self.exp_tile_indices = [
            [0, 5, 0, 5],
            [0, 5, 4, 9],
            [0, 5, 6, 11],
            [10, 15, 0, 5],
            [10, 15, 4, 9],
            [10, 15, 6, 11],
            [4, 9, 0, 5],
            [4, 9, 4, 9],
            [4, 9, 6, 11],
            [8, 13, 0, 5],
            [8, 13, 4, 9],
            [8, 13, 6, 11],
        ]

        # create a mask
        mask_dir = os.path.join(self.temp_path, 'mask_dir')
        os.makedirs(mask_dir, exist_ok=True)
        mask_images = np.zeros((15, 11, 5), dtype='bool')
        mask_images[4:12, 4:9, 2:4] = 1

        # write mask images and add meta to frames_meta
        self.mask_channel = 3
        mask_meta = []
        for z in range(5):
            cur_im = mask_images[:, :, z]
            im_name = aux_utils.get_im_name(
                channel_idx=3,
                slice_idx=z + 15,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
                ext='.npy',
            )
            np.save(os.path.join(mask_dir, im_name), cur_im)
            cur_meta = {
                'channel_idx': 3,
                'slice_idx': z + 15,
                'time_idx': self.time_idx,
                'pos_idx': self.pos_idx1,
                'file_name': im_name
            }
            mask_meta.append(cur_meta)
        mask_meta_df = pd.DataFrame.from_dict(mask_meta)
        mask_meta_df.to_csv(os.path.join(mask_dir, 'frames_meta.csv'), sep=',')
        self.mask_dir = mask_dir

        exp_tile_indices = [[0, 5, 0, 5], [0, 5, 4, 9], [0, 5, 6, 11],
                            [10, 15, 0, 5], [10, 15, 4, 9], [10, 15, 6, 11],
                            [4, 9, 0, 5], [4, 9, 4, 9], [4, 9, 6, 11],
                            [8, 13, 0, 5], [8, 13, 4, 9], [8, 13, 6, 11]]
        self.exp_tile_indices = exp_tile_indices
Example #55
class TestImageTilerUniform(unittest.TestCase):
    def setUp(self):
        """Set up a dir for tiling with flatfield"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        frames_meta = aux_utils.make_dataframe()
        # Write images
        self.im = 127 * np.ones((15, 11), dtype=np.uint8)
        self.im2 = 234 * np.ones((15, 11), dtype=np.uint8)
        self.channel_idx = 1
        self.time_idx = 5
        self.pos_idx1 = 7
        self.pos_idx2 = 8
        self.int2str_len = 3

        # Write test images with 4 z and 2 pos idx
        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx2,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im2,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        # Write metadata
        frames_meta.to_csv(
            os.path.join(self.temp_path, self.meta_name),
            sep=',',
        )
        # Add flatfield
        self.flat_field_dir = os.path.join(self.temp_path, 'ff_dir')
        self.tempdir.makedir('ff_dir')
        self.ff_im = 4. * np.ones((15, 11))
        self.ff_name = os.path.join(
            self.flat_field_dir,
            'flat-field_channel-1.npy',
        )
        np.save(self.ff_name, self.ff_im, allow_pickle=True, fix_imports=True)
        # Instantiate tiler class
        self.output_dir = os.path.join(self.temp_path, 'tile_dir')
        self.tile_inst = tile_images.ImageTilerUniform(
            input_dir=self.temp_path,
            output_dir=self.output_dir,
            tile_size=[5, 5],
            step_size=[4, 4],
            depths=3,
            channel_ids=[1],
            normalize_channels=[True],
            flat_field_dir=self.flat_field_dir,
        )
        exp_fnames = [
            'im_c001_z015_t005_p007.png', 'im_c001_z016_t005_p007.png',
            'im_c001_z017_t005_p007.png'
        ]
        self.exp_fnames = [
            os.path.join(self.temp_path, fname) for fname in exp_fnames
        ]
        self.exp_tile_indices = [
            [0, 5, 0, 5],
            [0, 5, 4, 9],
            [0, 5, 6, 11],
            [10, 15, 0, 5],
            [10, 15, 4, 9],
            [10, 15, 6, 11],
            [4, 9, 0, 5],
            [4, 9, 4, 9],
            [4, 9, 6, 11],
            [8, 13, 0, 5],
            [8, 13, 4, 9],
            [8, 13, 6, 11],
        ]

        # create a mask
        mask_dir = os.path.join(self.temp_path, 'mask_dir')
        os.makedirs(mask_dir, exist_ok=True)
        mask_images = np.zeros((15, 11, 5), dtype='bool')
        mask_images[4:12, 4:9, 2:4] = 1

        # write mask images and add meta to frames_meta
        self.mask_channel = 3
        mask_meta = []
        for z in range(5):
            cur_im = mask_images[:, :, z]
            im_name = aux_utils.get_im_name(
                channel_idx=3,
                slice_idx=z + 15,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
                ext='.npy',
            )
            np.save(os.path.join(mask_dir, im_name), cur_im)
            cur_meta = {
                'channel_idx': 3,
                'slice_idx': z + 15,
                'time_idx': self.time_idx,
                'pos_idx': self.pos_idx1,
                'file_name': im_name
            }
            mask_meta.append(cur_meta)
        mask_meta_df = pd.DataFrame.from_dict(mask_meta)
        mask_meta_df.to_csv(os.path.join(mask_dir, 'frames_meta.csv'), sep=',')
        self.mask_dir = mask_dir

        exp_tile_indices = [[0, 5, 0, 5], [0, 5, 4, 9], [0, 5, 6, 11],
                            [10, 15, 0, 5], [10, 15, 4, 9], [10, 15, 6, 11],
                            [4, 9, 0, 5], [4, 9, 4, 9], [4, 9, 6, 11],
                            [8, 13, 0, 5], [8, 13, 4, 9], [8, 13, 6, 11]]
        self.exp_tile_indices = exp_tile_indices

    def tearDown(self):
        """Tear down temporary folder and file structure"""

        TempDirectory.cleanup_all()
        nose.tools.assert_equal(os.path.isdir(self.temp_path), False)

    def test_init(self):
        """Test init"""

        nose.tools.assert_equal(self.tile_inst.depths, 3)
        nose.tools.assert_equal(self.tile_inst.tile_size, [5, 5])
        nose.tools.assert_equal(self.tile_inst.step_size, [4, 4])
        nose.tools.assert_equal(self.tile_inst.hist_clip_limits, None)
        nose.tools.assert_equal(self.tile_inst.image_format, 'zyx')
        nose.tools.assert_equal(self.tile_inst.num_workers, 4)
        nose.tools.assert_equal(
            self.tile_inst.str_tile_step,
            'tiles_5-5_step_4-4',
        )
        nose.tools.assert_equal(self.tile_inst.channel_ids, [self.channel_idx])
        nose.tools.assert_equal(self.tile_inst.time_ids, [self.time_idx])
        nose.tools.assert_equal(
            self.tile_inst.flat_field_dir,
            self.flat_field_dir,
        )
        # Depth is 3 so first and last frame will not be used
        np.testing.assert_array_equal(
            self.tile_inst.slice_ids,
            np.asarray([16, 17, 18]),
        )
        np.testing.assert_array_equal(
            self.tile_inst.pos_ids,
            np.asarray([7, 8]),
        )
        # channel_depth should be a dict containing depths for each channel
        self.assertListEqual(
            list(self.tile_inst.channel_depth),
            [self.channel_idx],
        )
        nose.tools.assert_equal(
            self.tile_inst.channel_depth[self.channel_idx],
            3,
        )

    def test_tile_dir(self):
        nose.tools.assert_equal(
            self.tile_inst.get_tile_dir(),
            os.path.join(self.output_dir, "tiles_5-5_step_4-4"))

    def test_get_dataframe(self):
        df = self.tile_inst._get_dataframe()
        self.assertListEqual(list(df), [
            'channel_idx', 'slice_idx', 'time_idx', 'file_name', 'pos_idx',
            'row_start', 'col_start'
        ])

    def test_get_flat_field(self):
        flat_field_im = self.tile_inst._get_flat_field(channel_idx=1)
        np.testing.assert_array_equal(flat_field_im, self.ff_im)

    def test_get_tile_indices(self):
        """Test get_tiled_indices"""

        self.tile_inst.tile_stack()
        # Read the saved metadata
        tile_dir = self.tile_inst.get_tile_dir()
        tile_meta = pd.read_csv(os.path.join(tile_dir, "frames_meta.csv"))

        tile_indices = self.tile_inst._get_tile_indices(
            tiled_meta=tile_meta,
            time_idx=self.time_idx,
            channel_idx=self.channel_idx,
            pos_idx=7,
            slice_idx=16)
        exp_tile_indices = np.asarray(self.exp_tile_indices, dtype='uint8')
        row_ids = list(range(len(exp_tile_indices)))
        for ret_idx in tile_indices:
            row_idx = np.where((exp_tile_indices[:, 0] == ret_idx[0])
                               & (exp_tile_indices[:, 1] == ret_idx[1])
                               & (exp_tile_indices[:, 2] == ret_idx[2])
                               & (exp_tile_indices[:, 3] == ret_idx[3]))
            nose.tools.assert_in(row_idx[0], row_ids)

    def test_get_tiled_data(self):
        """Test get_tiled_indices"""

        # no tiles_exist
        tile_meta, tile_indices = self.tile_inst._get_tiled_data()
        nose.tools.assert_equal(tile_indices, None)
        init_df = pd.DataFrame(columns=[
            'channel_idx', 'slice_idx', 'time_idx', 'file_name', 'pos_idx',
            'row_start', 'col_start'
        ])
        pd.testing.assert_frame_equal(tile_meta, init_df)
        # tile exists
        self.tile_inst.tile_stack()
        self.tile_inst.tiles_exist = True
        self.tile_inst.channel_ids = [1, 2]
        tile_meta, _ = self.tile_inst._get_tiled_data()

        exp_tile_meta = []
        for exp_idx in self.exp_tile_indices:
            for z in [16, 17, 18]:
                cur_img_id = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                    exp_idx[0], exp_idx[1], exp_idx[2], exp_idx[3], 0, 3)
                pos1_fname = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.channel_idx,
                    slice_idx=z,
                    pos_idx=self.pos_idx1,
                    extra_field=cur_img_id,
                    ext='.npy',
                )
                pos1_meta = {
                    'channel_idx': self.channel_idx,
                    'slice_idx': z,
                    'time_idx': self.time_idx,
                    'file_name': pos1_fname,
                    'pos_idx': self.pos_idx1,
                    'row_start': exp_idx[0],
                    'col_start': exp_idx[2]
                }
                exp_tile_meta.append(pos1_meta)
                pos2_fname = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.channel_idx,
                    slice_idx=z,
                    pos_idx=self.pos_idx2,
                    extra_field=cur_img_id,
                    ext='.npy',
                )
                pos2_meta = {
                    'channel_idx': self.channel_idx,
                    'slice_idx': z,
                    'time_idx': self.time_idx,
                    'file_name': pos2_fname,
                    'pos_idx': self.pos_idx2,
                    'row_start': exp_idx[0],
                    'col_start': exp_idx[2]
                }
                exp_tile_meta.append(pos2_meta)
        exp_tile_meta_df = pd.DataFrame.from_dict(exp_tile_meta)
        exp_tile_meta_df = exp_tile_meta_df.sort_values(by=['file_name'])
        exp_tile_meta_df.reset_index(drop=True, inplace=True)
        tile_meta = tile_meta.sort_values(by=['file_name'])
        tile_meta.reset_index(drop=True, inplace=True)
        pd.testing.assert_frame_equal(tile_meta, exp_tile_meta_df)

    def test_get_input_fnames(self):
        """Test get_input_fnames"""

        im_fnames = self.tile_inst._get_input_fnames(
            time_idx=self.time_idx,
            channel_idx=self.channel_idx,
            slice_idx=16,
            pos_idx=self.pos_idx1)
        nose.tools.assert_list_equal(self.exp_fnames, im_fnames)

    def test_get_crop_args(self):
        """Test get_crop_tile_args with task_type=crop"""

        cur_args = self.tile_inst.get_crop_tile_args(
            channel_idx=self.channel_idx,
            time_idx=self.time_idx,
            slice_idx=16,
            pos_idx=7,
            task_type='crop',
            tile_indices=self.exp_tile_indices)
        nose.tools.assert_list_equal(list(cur_args[0]), self.exp_fnames)
        nose.tools.assert_equal(cur_args[1], self.ff_name)
        nose.tools.assert_equal(cur_args[2], None)
        nose.tools.assert_equal(cur_args[3], self.time_idx)
        nose.tools.assert_equal(cur_args[4], self.channel_idx)
        nose.tools.assert_equal(cur_args[5], 7)
        nose.tools.assert_equal(cur_args[6], 16)
        nose.tools.assert_equal(cur_args[8], 'zyx')
        nose.tools.assert_equal(cur_args[9], self.tile_inst.tile_dir)
        nose.tools.assert_equal(cur_args[10], self.int2str_len)

    def test_tile_stack(self):
        """Test tile_stack"""

        self.tile_inst.tile_stack()
        # Read and validate the saved metadata
        tile_dir = self.tile_inst.get_tile_dir()
        frames_meta = pd.read_csv(os.path.join(tile_dir, 'frames_meta.csv'))

        self.assertSetEqual(set(frames_meta.channel_idx.tolist()), {1})
        self.assertSetEqual(set(frames_meta.slice_idx.tolist()), {16, 17, 18})
        self.assertSetEqual(set(frames_meta.time_idx.tolist()), {5})
        self.assertSetEqual(set(frames_meta.pos_idx.tolist()), {7, 8})
        # 15 rows and step size 4, so it can take 3 full steps and 1 short step
        self.assertSetEqual(set(frames_meta.row_start.tolist()), {0, 4, 8, 10})
        # 11 cols and step size 4, so it can take 2 full steps and 1 short step
        self.assertSetEqual(set(frames_meta.col_start.tolist()), {0, 4, 6})

        # Read and validate tiles
        im_val = np.mean(norm_util.zscore(self.im / self.ff_im))
        im_norm = im_val * np.ones((3, 5, 5))
        im_val = np.mean(norm_util.zscore(self.im2 / self.ff_im))
        im2_norm = im_val * np.ones((3, 5, 5))
        for i, row in frames_meta.iterrows():
            tile = np.load(os.path.join(tile_dir, row.file_name))
            if row.pos_idx == 7:
                np.testing.assert_array_equal(tile, im_norm)
            else:
                np.testing.assert_array_equal(tile, im2_norm)

    def test_get_tile_args(self):
        """Test get_crop_tile_args with task_type=tile"""

        self.tile_inst.mask_depth = 3
        cur_args = self.tile_inst.get_crop_tile_args(
            channel_idx=self.mask_channel,
            time_idx=self.time_idx,
            slice_idx=16,
            pos_idx=7,
            task_type='tile',
            mask_dir=self.mask_dir,
            min_fraction=0.3)

        exp_fnames = [
            'im_c003_z015_t005_p007.npy', 'im_c003_z016_t005_p007.npy',
            'im_c003_z017_t005_p007.npy'
        ]
        exp_fnames = [
            os.path.join(self.mask_dir, fname) for fname in exp_fnames
        ]

        nose.tools.assert_list_equal(list(cur_args[0]), exp_fnames)
        # flat field fname is None
        nose.tools.assert_equal(cur_args[1], None)
        # hist clip limits is None
        nose.tools.assert_equal(cur_args[2], None)
        nose.tools.assert_equal(cur_args[3], self.time_idx)
        nose.tools.assert_equal(cur_args[4], self.mask_channel)
        nose.tools.assert_equal(cur_args[5], 7)
        nose.tools.assert_equal(cur_args[6], 16)
        nose.tools.assert_list_equal(cur_args[7], self.tile_inst.tile_size)
        nose.tools.assert_list_equal(cur_args[8], self.tile_inst.step_size)
        nose.tools.assert_equal(cur_args[9], 0.3)
        nose.tools.assert_equal(cur_args[10], 'zyx')
        nose.tools.assert_equal(cur_args[11], self.tile_inst.tile_dir)
        nose.tools.assert_equal(cur_args[12], self.int2str_len)

        # not a mask channel
        cur_args = self.tile_inst.get_crop_tile_args(
            channel_idx=self.channel_idx,
            time_idx=self.time_idx,
            slice_idx=16,
            pos_idx=7,
            task_type='tile')
        nose.tools.assert_list_equal(list(cur_args[0]), self.exp_fnames)

        exp_ff_fname = os.path.join(
            self.flat_field_dir,
            'flat-field_channel-{}.npy'.format(self.channel_idx),
        )
        nose.tools.assert_equal(cur_args[1], exp_ff_fname)
        nose.tools.assert_equal(cur_args[9], None)

    def test_tile_mask_stack(self):
        """Test tile_mask_stack"""

        self.tile_inst.pos_ids = [7]
        self.tile_inst.normalize_channels = [True, True, True, True]

        # use the saved masks to tile other channels
        self.tile_inst.tile_mask_stack(mask_dir=self.mask_dir,
                                       mask_channel=3,
                                       min_fraction=0.5,
                                       mask_depth=3)

        # Read and validate the saved metadata
        tile_dir = self.tile_inst.get_tile_dir()
        frames_meta = pd.read_csv(os.path.join(tile_dir, 'frames_meta.csv'))

        self.assertSetEqual(set(frames_meta.channel_idx.tolist()), {1, 3})
        self.assertSetEqual(set(frames_meta.slice_idx.tolist()), {17, 18})
        self.assertSetEqual(set(frames_meta.time_idx.tolist()), {5})
        self.assertSetEqual(set(frames_meta.pos_idx.tolist()), {self.pos_idx1})

        # with min_fraction >= 0.5, 4 tiles are saved for each of the mask
        # channel (3) and channel 1, i.e. 8 rows in total:
        # [4,9,4,9,17], [8,13,4,9,17], [4,9,4,9,18], [8,13,4,9,18]
        nose.tools.assert_equal(len(frames_meta), 8)
        nose.tools.assert_list_equal(
            frames_meta['row_start'].unique().tolist(), [4, 8])
        nose.tools.assert_equal(frames_meta['col_start'].unique().tolist(),
                                [4])
Example #56
class TestSetting:
    """ Test the setting class """
    def setUp(self):
        """ Setup test fixtures """

        self.setting_id = 42
        self.setting_name = 'test_setting_name'
        self.setting_position = 0
        self.dt = {
            'type': 'string',
            'include':
            '#include <stdio.h>\n#include "foo.h"\n#include <stdlib.h>',
            'define': '',
            'init': '',
            'activation': ''
        }
        self.setting_code = 'code'
        self.function_ref = 1
        self.tempdir = TempDirectory()
        self.work_dir = self.tempdir.path

    def tearDown(self):
        """ Remove artefacts """
        self.tempdir.cleanup()

    def test_get_name(self):
        """get_name should return the supplied name"""
        TEST_NAME = 'test_name'
        setting = Setting(self.setting_id, TEST_NAME, self.setting_position,
                          self.dt, self.setting_code, self.function_ref,
                          self.work_dir)

        assert_equal(setting.get_name(), TEST_NAME)

    def test_get_id(self):
        """get_id should return the supplied id"""
        TEST_ID = 43
        setting = Setting(TEST_ID, self.setting_name, self.setting_position,
                          self.dt, self.setting_code, self.function_ref,
                          self.work_dir)

        assert_equal(setting.get_id(), TEST_ID)

    def test_get_position(self):
        """get_position should return the supplied position"""
        TEST_POSITION = 1
        setting = Setting(self.setting_id, self.setting_name, TEST_POSITION,
                          self.dt, self.setting_code, self.function_ref,
                          self.work_dir)

        assert_equal(setting.get_position(), TEST_POSITION)

    def test_get_code(self):
        """get_code should return the supplied code"""
        TEST_CODE = 'if (foo) { bar; }'
        setting = Setting(self.setting_id, self.setting_name,
                          self.setting_position, self.dt, TEST_CODE,
                          self.function_ref, self.work_dir)

        assert_equal(setting.get_code(), TEST_CODE)

    def test_get_function_ref(self):
        """get_function_ref should return the supplied function_ref"""
        TEST_F_REF = 22
        setting = Setting(self.setting_id, self.setting_name,
                          self.setting_position, self.dt, self.setting_code,
                          TEST_F_REF, self.work_dir)

        assert_equal(setting.get_function_ref(), TEST_F_REF)

    def test_get_datatype(self):
        """get_datatype should return the supplied datatype record"""
        TEST_DT = self.dt
        setting = Setting(self.setting_id, self.setting_name,
                          self.setting_position, TEST_DT, self.setting_code,
                          self.function_ref, self.work_dir)

        assert_equal(setting.get_datatype(), TEST_DT)

    def test_ballista_settings_get_removed(self):
        """all ballista includes should be removed from the datatype include
        string """
        setting = Setting(self.setting_id, self.setting_name,
                          self.setting_position, self.dt, self.setting_code,
                          self.function_ref, self.work_dir)

        assert_equal(setting.get_datatype()['include'],
                     '#include <stdio.h>\n\n#include <stdlib.h>')

    def test_generate_files(self):
        """ generating the setting files should create header and cpp files """
        setting = Setting(self.setting_id, 't_name_42', self.setting_position,
                          self.dt, self.setting_code, self.function_ref,
                          self.work_dir)
        setting.generate_files()
        self.tempdir.check('t_name_42.cpp', 't_name_42.h')
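tempdir.check asserts that the directory now holds exactly the named entries (in sorted order), so the test above confirms both that the files were generated and that nothing else was; a minimal sketch of the same mechanism:

from testfixtures import TempDirectory

with TempDirectory() as d:
    d.write('t_name_42.cpp', b'')
    d.write('t_name_42.h', b'')
    d.check('t_name_42.cpp', 't_name_42.h')  # exactly these two entries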
Example #57
    def test_compare_path_tuple(self):
        with TempDirectory() as d:
            d.write('a/b/c', b'')
            d.compare(path=('a', 'b'), expected=['c'])
Example #58
    def test_multiple_config_files_from_dir(self):
        args = ['--step', 'write-config-as-results']

        with TempDirectory() as temp_dir:
            config_files = [{
                'name':
                'tssc-config1.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    required-config-key: 'value'
                    '''
            }, {
                'name':
                'foo/a.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    keya: "a"

                    '''
            }, {
                'name':
                'foo/b.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    keyb: "b"
                    '''
            }, {
                'name':
                'foo/bar/c.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    keyc: "c"
                    '''
            }, {
                'name':
                'foo/bar2/c2.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    keyc2: "c2"
                    '''
            }, {
                'name':
                'foo/bar/meh/d.yaml',
                'contents':
                '''---
                        tssc-config:
                            write-config-as-results:
                                implementer: 'tests.helpers.sample_step_implementers.WriteConfigAsResultsStepImplementer'
                                config:
                                    keyd: "d"
                    '''
            }]
            for config_file in config_files:
                config_file_name = config_file['name']
                config_file_contents = config_file['contents']

                temp_dir.write(config_file_name,
                               bytes(config_file_contents, 'utf-8'))

            args.append('--config')
            args.append(os.path.join(temp_dir.path, 'foo'))
            args.append(os.path.join(temp_dir.path, 'tssc-config1.yaml'))
            self._run_main_test(
                args, None, None, {
                    'tssc-results': {
                        'write-config-as-results': {
                            'keya': 'a',
                            'keyb': 'b',
                            'keyc': 'c',
                            'keyc2': 'c2',
                            'keyd': 'd',
                            'required-config-key': 'value'
                        }
                    }
                })
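For reference, the files written above (taken from their 'name' fields) land in this layout under temp_dir.path, with the --config arguments pointing at the foo/ directory and the top-level YAML file:

temp_dir.path/
    tssc-config1.yaml
    foo/
        a.yaml
        b.yaml
        bar/
            c.yaml
            meh/
                d.yaml
        bar2/
            c2.yaml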
Example #59
class GPSPiTests(unittest.TestCase):
    def setUp(self):
        self.temp_dir = TempDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()
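
    # NOTE: every test method below takes a `mock_port` argument, which
    # implies a patch decorator around the class for the serial port that
    # GPSReader opens; a sketch, where the `serial.Serial` attribute is an
    # assumption (the module path matches the patches used further down):
    #
    #   @patch("hardware.gpsPi.gps_reader.serial.Serial")
    #   class GPSPiTests(unittest.TestCase):
    #       ...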

    def test_init_no_logs(self, mock_port):

        # Replace real object os.environ with mock dictionary
        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):
            gps_reader = GPSReader()
            mock_port.assert_called_with(
                os.environ["GPS_PORT"], os.environ["GPS_BAUDRATE"],
            )

            self.assertIsNotNone(gps_reader.logging)
            self.assertEqual(gps_reader.logging.name, "GPS_LOG_FILE")
            self.assertIsInstance(gps_reader.logging, Logger)

    def test_init_logs(self, mock_port):

        with patch.dict(
            os.environ,
            {
                "GPS_HAT_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            gps_reader = GPSReader("GPS_HAT_LOG_FILE")
            mock_port.assert_called_with(
                os.environ["GPS_PORT"], os.environ["GPS_BAUDRATE"],
            )

            self.assertIsNotNone(gps_reader.logging)
            self.assertEqual(gps_reader.logging.name, "GPS_HAT_LOG_FILE")
            self.assertIsInstance(gps_reader.logging, Logger)

    @patch("hardware.gpsPi.gps_reader.date_str_with_current_timezone")
    def test_get_location_valid_data(self, mock_date, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
        mock_port.return_value.readline.return_value = (
            "b'$GPRMC,194509.000,A,4042.6142,N,07400.4168,W,2.03,221.11,160412,,,A*77"
        )
        mock_date.return_value = "example date"

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            expected_data = {}
            expected_data["sensor_id"] = 1
            expected_data["values"] = {
                "latitude": 40.71023666666667,
                "longitude": -74.00694666666666,
            }
            expected_data["date"] = "example date"

            gps_reader = GPSReader()
            data = gps_reader.get_geolocation()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)

    @patch("hardware.gpsPi.gps_reader.date_str_with_current_timezone")
    def test_get_location_other_valid_data(self, mock_date, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
        mock_port.return_value.readline.return_value = (
            "b'$GPRMC,194509.000,A,4042.6142,S,07400.4168,W,2.03,221.11,160412,,,A*77"
        )
        mock_date.return_value = "example date"

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            expected_data = {}
            expected_data["sensor_id"] = 1
            expected_data["values"] = {
                "latitude": -40.71023666666667,
                "longitude": -74.00694666666666,
            }
            expected_data["date"] = "example date"

            gps_reader = GPSReader()
            data = gps_reader.get_geolocation()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)

    def test_get_location_invalid_nmeatype(self, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
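        # $GPGGA is a valid NMEA sentence type, but this reader evidently only
        # parses $GPRMC, so a non-RMC sentence yields no reading (None).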
        mock_port.return_value.readline.return_value = (
            "b'$GPGGA,194509.000,A,4042.6142,N,07400.4168,W,2.03,221.11,160412,,,A*77"
        )

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            expected_data = None

            gps_reader = GPSReader()
            data = gps_reader.get_geolocation()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)

    def test_get_location_invalid_data(self, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
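        # Status field "V" (void) flags the fix as invalid, so no data is produced.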
        mock_port.return_value.readline.return_value = (
            "b'$GPRMC,194509.000,V,4042.6142,N,07400.4168,W,2.03,221.11,160412,,,A*77"
        )

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            expected_data = None

            gps_reader = GPSReader()
            data = gps_reader.get_geolocation()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)

    @patch("hardware.gpsPi.gps_reader.date_str_with_current_timezone")
    def test_get_speed_in_mph(self, mock_date, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
        mock_port.return_value.readline.return_value = (
            "b'$GPRMC,194509.000,A,4042.6142,N,07400.4168,W,2.03,221.11,160412,,,A*77"
        )
        mock_date.return_value = "example date"

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            # NMEA reports speed over ground in knots; 1 knot is about 1.151 mph,
            # the conversion factor this test expects the reader to apply.
            speed_in_mph = 2.03 * 1.151

            expected_data = {}
            expected_data["sensor_id"] = 1
            expected_data["values"] = {
                "speed": speed_in_mph,
            }
            expected_data["date"] = "example date"

            gps_reader = GPSReader()
            data = gps_reader.get_speed_mph()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)

    def test_get_speed_in_mph_invalid_data(self, mock_port):

        mock_port.return_value.inWaiting.return_value = 1
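        # "$GP" alone is not a recognized sentence identifier, so parsing fails
        # and no speed reading is produced.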
        mock_port.return_value.readline.return_value = (
            "b'$GP,194509.000,A,4042.6142,N,07400.4168,W,2.03,221.11,160412,,,A*77"
        )

        with patch.dict(
            os.environ,
            {
                "GPS_LOG_FILE": "logger.txt",
                "LOG_DIRECTORY": self.temp_dir.path,
                "GPS_PORT": "/dev/serial0",
                "GPS_BAUDRATE": "9600",
            },
        ):

            expected_data = None

            gps_reader = GPSReader()
            data = gps_reader.get_speed_mph()

            mock_port.return_value.inWaiting.assert_called()
            mock_port.return_value.readline.assert_called()

            self.assertEqual(expected_data, data)
Example #60
class test_pymca(unittest.TestCase):
    def setUp(self):
        self.dir = TempDirectory()
        self.energy = [7.5, 8]

    def tearDown(self):
        self.dir.cleanup()

    @unittest.skipIf(
        xrfdetectors.compoundfromname.xraylib is None, "xraylib not installed"
    )
    def test_loadcfg(self):
        cfgfile = os.path.join(self.dir.path, "mca.cfg")

        h1 = xrf_setup.simple(energy=self.energy)
        h1.savepymca(cfgfile)

        source = xraysources.factory("synchrotron")
        detector = xrfdetectors.factory("XRFDetector", ehole=3.8)
        geometry = xrfgeometries.factory(
            "LinearXRFGeometry",
            detector=detector,
            source=source,
            zerodistance=0,
            detectorposition=0,
            positionunits="mm",
        )
        sample = multilayer.Multilayer(geometry=geometry)

        h2 = pymca.PymcaHandle(sample=sample)
        h2.loadfrompymca(cfgfile)
        np.testing.assert_allclose(h1.mca(), h2.mca())

    @unittest.skipIf(
        xrfdetectors.compoundfromname.xraylib is None, "xraylib not installed"
    )
    def test_rates(self):
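        # Keyword intent (inferred from the names): escape=0 disables escape
        # peaks, snip=False skips SNIP background removal, continuum=1 adds a
        # flat continuum.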
        h = xrf_setup.simple(energy=self.energy, escape=0, snip=False, continuum=1)

        # Debugging aid: flip to True to dump the config and spectrum to disk.
        if False:
            path = "/data/id21/inhouse/wout/tmp/pymcatst"
            cfgfile = os.path.join(path, "mca_mixture.cfg")
            mcafile = os.path.join(path, "spectrum.mca")
            h.savepymca(cfgfile)
            h.savemca(mcafile, func=lambda x: x + 1)

        y = h.mca()
        y += 1
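        # Offset by one count, presumably to avoid zero-count channels in the
        # fit (mirrors func=lambda x: x + 1 in the debug dump above).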

        # Fit data
        h.setdata(y)
        h.configurepymca()
        fitresult = h.fit()

        # Mass fractions are calculated as follows:
        #   grouparea = flux.time.grouprate
        #   grouprate = solidangle/(4.pi).sum_l[massfrac_l.grouprate_l]  where l loops over the layers
        #
        #   massfrac_l = self.mcafit._fluoRates[layer][element]["mass fraction"]  where layer in 1,...,n
        #   grouprate_l = self.mcafit._fluoRates[layer][element]["rates"][group]
        #   sum_l[massfrac_l.grouprate_l] = self.mcafit._fluoRates[0][element]["rates"][group]*
        #                                   self.mcafit._fluoRates[0][element]["mass fraction"]
        #
        # When element in one layer:
        #   massfrac = area/(flux.time.solidangle/(4.pi).grouprate_l)
        #
        # When element in more than one layer:
        #   grouprate_avg = solidangle/(4.pi).massfrac_avg.sum_l[grouprate_l]
        #
        # When element in more than one layer (per layer as if all intensity came from that layer?):
        #   massfrac_l = grouparea/(flux.time.solidangle/(4.pi).grouprate_l)
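        #
        # Worked single-layer example (illustrative numbers only, not from this
        # fit): with grouparea = 1e5 counts, flux.time = 1e10 and
        # solidangle/(4.pi) = 1e-3, a grouprate_l of 10 gives
        #   massfrac = 1e5 / (1e10 * 1e-3 * 10) = 1e-3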

        grouprates = h.xraygrouprates(scattering=False, method="fisx")
        safrac = h.sample.geometry.solidangle / (4 * np.pi)
        np.testing.assert_allclose(safrac, h._pymcainternals_solidanglefrac())

        for group in fitresult["fitareas"]:
            if not isinstance(group, xrayspectrum.FluoZLine):
                continue
            element, linegroup = group.element, group.group

            grouprate_avg = h.mcafit._fluoRates[0][element]["rates"][
                "{} xrays".format(linegroup)
            ]
            grouprate_avg *= safrac
            massfrac_avg = h.mcafit._fluoRates[0][element]["mass fraction"]

            grouprate = 0.0
            grouprate_avg2 = 0.0
            massfrac_avg2 = 0.0
            npresent = 0
            for j in range(h.sample.nlayers):
                i = j + 1
                if element in h.mcafit._fluoRates[i]:
                    massfrac_l = h.mcafit._fluoRates[i][element]["mass fraction"]
                    grouprate_l = h.mcafit._fluoRates[i][element]["rates"][
                        "{} xrays".format(linegroup)
                    ]
                    grouprate += massfrac_l * grouprate_l
                    grouprate_avg2 += grouprate_l
                    # massfrac_avg2 = max(massfrac_avg2,massfrac_l)
                    massfrac_avg2 = massfrac_l  # just the last one?
                    npresent += 1
            grouprate *= safrac
            grouprate_avg2 *= safrac
            # TODO: something wrong with mixtures of different elements (fisx vs Elements)
            if group in grouprates:
                # 1% error fisx vs. Elements?
                np.testing.assert_allclose(grouprate, grouprates[group], rtol=1e-2)

            np.testing.assert_allclose(grouprate, fitresult["rates"][group])

            if npresent == 1:
                np.testing.assert_allclose(grouprate_avg * massfrac_avg, grouprate)
            else:
                np.testing.assert_allclose(massfrac_avg, massfrac_avg2)
                np.testing.assert_allclose(grouprate_avg, grouprate_avg2)

            np.testing.assert_allclose(
                fitresult["massfractions"][group],
                fitresult["fitareas"][group] / (h.I0 * grouprate_avg),
            )

            for j in range(h.sample.nlayers):
                i = j + 1
                if element in h.mcafit._fluoRates[i]:
                    grouprate_l = h.mcafit._fluoRates[i][element]["rates"][
                        "{} xrays".format(linegroup)
                    ]
                    grouprate = grouprate_l * safrac
                    np.testing.assert_allclose(
                        fitresult["lmassfractions"][j][group],
                        fitresult["fitareas"][group] / (h.I0 * grouprate),
                    )

        # Plot
        plt.plot(fitresult["energy"], fitresult["y"], label="data")
        plt.plot(fitresult["energy"], fitresult["yfit"], label="pymca")

        spectrum = h.xrayspectrum()
        spectrum.plot(
            fluxtime=h.I0,
            histogram=True,
            ylog=False,
            decompose=False,
            backfunc=lambda x: 1,
        )

        ax = plt.gca()
        ax.set_ylim(ymin=np.nanmin(y[np.nonzero(y)]))
        ax.set_xlabel("Energy (keV)")
        ax.set_ylabel("Intensity (cts)")
        plt.legend(loc="best")
        # plt.show()

    @unittest.skipIf(
        xrfdetectors.compoundfromname.xraylib is None, "xraylib not installed"
    )
    def test_spectrum(self):
        h = xrf_setup.complex(
            energy=self.energy,
            escape=0,
            flux=1e10,
            time=1,
            scatter=np.zeros_like(self.energy),
            linear=1,
            emin=2,
            emax=7.4,
        )
        h.sample.geometry.detector.bltail = False
        h.sample.geometry.detector.bstail = False
        h.sample.geometry.detector.bstep = False
        y = h.mca()

        # Prepare fit
        h.setdata(y)
        h.configurepymca()

        # Force config: turn off background stripping (stripflag=0)
        config = h.mcafit.getConfiguration()
        config["fit"]["stripflag"] = 0
        h.mcafit.configure(config)

        # Fit
        # h.fitgui(loadfromfit=False)
        fitresult = h.fit(loadfromfit=False)

        # Get fundamental MCA and fitted MCA
        spectrum = h.xrayspectrum(method="fisx", scattering=False)
        x, ysum, ylabel = spectrum.sumspectrum(fluxtime=h.I0, histogram=True)
        ypymca = fitresult["interpol_energy"](fitresult["ymatrix"])(x)

        # TODO: Doesn't work due to peak rejection; add it to the xrayspectrum
        # np.testing.assert_allclose(ysum,ypymca,rtol=1e-2)

        # Plot
        # x,ygroup,ylabel,names = spectrum.linespectra(fluxtime=h.I0,histogram=True)
        # for name,y in zip(names,ygroup.T):
        #    if "Ce-L3" not in name: # pymca is cutting small peaks
        #        continue
        #    plt.plot(x,y,label=name)

        plt.plot(x, ysum, label="fisx", linewidth=2)
        plt.plot(x, ypymca, label="pymca", linewidth=2)
        plt.legend()
        ax = plt.gca()
        ax.set_yscale("log", basey=10)
        plt.ylim([0.001, max(ysum)])