Example no. 1
0
    def test_can_accept(self):
        """Check ``IncrementPrimitive.can_accept`` while input metadata is built up.

        Starting from bare ndarray metadata, dimension information and then
        the element structural type are added one step at a time. Every
        incomplete or non-numeric stage must be rejected. The final stage,
        two-dimensional metadata whose elements are ``float``, is the
        accepted case.
        """

        def accepts(inputs_metadata):
            # Only the 'produce' method with its 'inputs' argument is exercised.
            return IncrementPrimitive.can_accept(
                method_name='produce',
                arguments={
                    'inputs': inputs_metadata,
                },
            )

        inputs_metadata = metadata.DataMetadata({
            'schema': metadata.CONTAINER_SCHEMA_VERSION,
            'structural_type': container.ndarray,
        })

        # No dimension information at all.
        self.assertFalse(accepts(inputs_metadata))

        # ``update`` returns the updated metadata, so the result is reassigned.
        # If it were discarded, every check below would silently re-test the
        # initial, dimension-less metadata.
        inputs_metadata = inputs_metadata.update((), {
            'dimension': {
                'length': 2,
            },
        })

        # Only the first dimension is described.
        self.assertFalse(accepts(inputs_metadata))

        inputs_metadata = inputs_metadata.update((metadata.ALL_ELEMENTS, ), {
            'dimension': {
                'length': 2,
            },
        })

        # Two dimensions, but the element structural type is still unknown.
        self.assertFalse(accepts(inputs_metadata))

        inputs_metadata = inputs_metadata.update(
            (metadata.ALL_ELEMENTS, metadata.ALL_ELEMENTS), {
                'structural_type': str,
            })

        # Non-numeric elements cannot be incremented.
        self.assertFalse(accepts(inputs_metadata))

        inputs_metadata = inputs_metadata.update(
            (metadata.ALL_ELEMENTS, metadata.ALL_ELEMENTS), {
                'structural_type': float,
            })

        # Fully described 2-D metadata with float elements is the accepted case.
        # NOTE(review): confirm against IncrementPrimitive.can_accept's rules.
        self.assertTrue(accepts(inputs_metadata))
Example no. 2
0
class Hyperparams(hyperparams.Hyperparams):
    """Hyperparameters whose single value is a primitive instance.

    ``primitive`` defaults to an ``IncrementPrimitive`` configured with the
    defaults of ``IncrementPrimitiveHyperparams``.
    """

    # Control parameter holding a primitive object rather than a plain value.
    # The default instance is constructed once, when this class body runs.
    primitive = hyperparams.Hyperparameter[base.PrimitiveBase](
        description=(
            'The primitive instance to be passed to '
            'PrimitiveHyperparamPrimitive'
        ),
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter',
        ],
        default=IncrementPrimitive(
            hyperparams=IncrementPrimitiveHyperparams.defaults(),
        ),
    )
Example no. 3
0
    def test_hyperparameter(self):
        """Run ``produce`` with ``amount=2`` and verify values and metadata.

        The hyperparameters class is read from the primitive's own metadata
        (``primitive_code`` -> ``class_type_arguments`` -> ``Hyperparams``)
        and instantiated with ``amount=2``. Each element of a 2x4 integer
        ndarray must come back increased by 2. The output metadata must be
        bound to the output and must describe the same dimensions and
        element type.
        """
        primitive_description = IncrementPrimitive.metadata.query()
        hyperparams_cls = primitive_description['primitive_code'][
            'class_type_arguments']['Hyperparams']
        primitive = IncrementPrimitive(hyperparams=hyperparams_cls(amount=2))

        row_major_values = [[1, 2, 3, 4], [5, 6, 7, 8]]
        inputs = container.ndarray(
            row_major_values,
            {
                'schema': metadata.CONTAINER_SCHEMA_VERSION,
                'structural_type': container.ndarray,
                'dimension': {'length': 2},
            },
        )

        # Describe the second dimension and the element type of the input.
        column_selector = (metadata.ALL_ELEMENTS, )
        element_selector = (metadata.ALL_ELEMENTS, metadata.ALL_ELEMENTS)
        inputs.metadata = inputs.metadata.update(
            column_selector, {'dimension': {'length': 4}})
        inputs.metadata = inputs.metadata.update(
            element_selector, {'structural_type': inputs.dtype.type})

        result = self.call_primitive(primitive, 'produce', inputs=inputs)
        outputs = result.value

        expected = container.ndarray([[3, 4, 5, 6], [7, 8, 9, 10]])
        self.assertTrue(numpy.array_equal(outputs, expected))
        self.assertEqual(result.has_finished, True)
        self.assertEqual(result.iterations_done, None)

        output_metadata = outputs.metadata
        self.assertIs(output_metadata.for_value, outputs)
        self.assertEqual(output_metadata.query(())['dimension']['length'], 2)
        self.assertEqual(
            output_metadata.query(column_selector)['dimension']['length'], 4)
        self.assertEqual(
            output_metadata.query(element_selector)['structural_type'],
            numpy.int64)
Example no. 4
0
    def test_basic(self):
        """Run ``produce`` with default hyperparameters on a small DataFrame.

        Every value of a 2x4 integer DataFrame must come back incremented by
        one, as a float. The output metadata must keep the 2x4 dimensions,
        report ``numpy.float64`` for each of the four columns, and preserve
        the custom top-level ``foo`` entry given on input.
        """
        primitive = IncrementPrimitive(
            hyperparams=IncrementPrimitive.metadata.get_hyperparams().defaults())

        inputs = container.DataFrame(
            [[1, 2, 3, 4], [5, 6, 7, 8]],
            {'foo': 'bar'},  # Custom metadata expected to survive produce.
            generate_metadata=True,
        )

        result = self.call_primitive(primitive, 'produce', inputs=inputs)
        outputs = result.value

        expected = container.DataFrame(
            [[2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
            generate_metadata=True,
        )
        self.assertTrue(outputs.equals(expected))
        self.assertEqual(result.has_finished, True)
        self.assertEqual(result.iterations_done, None)

        output_metadata = outputs.metadata
        self.assertEqual(output_metadata.query(())['dimension']['length'], 2)
        self.assertEqual(
            output_metadata.query((base.ALL_ELEMENTS, ))['dimension']['length'],
            4)
        # Each of the four columns must be typed as float64.
        for column_index in range(4):
            self.assertEqual(
                output_metadata.query(
                    (base.ALL_ELEMENTS, column_index))['structural_type'],
                numpy.float64)
        self.assertEqual(output_metadata.query(()).get('foo', None), 'bar')