def test_recreation(self):
    """Check that a fitted primitive can be rebuilt from pickled state.

    A primitive constructed with ``bias=1`` is fitted, its params and
    hyperparams are each round-tripped through ``pickle``, and a second
    primitive is constructed from the unpickled hyperparams and given the
    unpickled params via ``set_params``. That second primitive is expected
    to produce ``[21, 41, 61]`` for the inputs ``[10, 20, 30]``.
    """
    hyperparams_cls = MonomialPrimitive.metadata.get_hyperparams()
    fitted = MonomialPrimitive(hyperparams=hyperparams_cls(bias=1))

    train_x = container.List([1, 2, 3, 4, 5, 6], generate_metadata=True)
    train_y = container.List([2, 4, 6, 8, 10, 12], generate_metadata=True)
    self.call_primitive(
        fitted,
        'set_training_data',
        inputs=train_x,
        outputs=train_y,
    )

    fit_result = self.call_primitive(fitted, 'fit')
    self.assertEqual(fit_result.has_finished, True)
    self.assertEqual(fit_result.iterations_done, None)

    # Params must compare equal to themselves after a pickle round trip.
    fitted_params = self.call_primitive(fitted, 'get_params')
    restored_params = pickle.loads(pickle.dumps(fitted_params))
    self.assertEqual(fitted_params, restored_params)

    # The same holds for the hyperparams held by the fitted primitive.
    restored_hyperparams = pickle.loads(pickle.dumps(fitted.hyperparams))
    self.assertEqual(fitted.hyperparams, restored_hyperparams)

    # Build a second primitive purely from the unpickled state; it is
    # never given training data and never fitted in this test.
    rebuilt = MonomialPrimitive(hyperparams=restored_hyperparams)
    self.call_primitive(rebuilt, 'set_params', params=restored_params)

    query_inputs = container.List([10, 20, 30], generate_metadata=True)
    produce_result = self.call_primitive(
        rebuilt,
        'produce',
        inputs=query_inputs,
    )
    self.assertSequenceEqual(produce_result.value, [21, 41, 61])
    self.assertEqual(produce_result.has_finished, True)
    self.assertEqual(produce_result.iterations_done, None)

    # Metadata attached to the produced value: top-level length and the
    # structural type recorded for all elements.
    top_level = produce_result.value.metadata.query(())
    self.assertEqual(top_level['dimension']['length'], 3)
    per_element = produce_result.value.metadata.query((base.ALL_ELEMENTS,))
    self.assertEqual(per_element['structural_type'], float)
def test_basic(self):
    """Fit a default-hyperparams primitive and check both produce paths.

    The primitive is built from ``hyperparams_class.defaults()``, fitted on
    inputs ``[1..6]`` with outputs ``[2, 4, ..., 12]``, and is then expected
    to yield ``[20, 40, 60]`` for ``[10, 20, 30]`` both through ``produce``
    and through ``multi_produce`` restricted to the ``'produce'`` method.
    """
    hyperparams_cls = MonomialPrimitive.metadata.get_hyperparams()
    monomial = MonomialPrimitive(hyperparams=hyperparams_cls.defaults())

    train_x = container.List([1, 2, 3, 4, 5, 6], generate_metadata=True)
    train_y = container.List([2, 4, 6, 8, 10, 12], generate_metadata=True)
    self.call_primitive(
        monomial,
        'set_training_data',
        inputs=train_x,
        outputs=train_y,
    )

    fit_result = self.call_primitive(monomial, 'fit')
    self.assertEqual(fit_result.has_finished, True)
    self.assertEqual(fit_result.iterations_done, None)

    query_inputs = container.List([10, 20, 30], generate_metadata=True)

    # Single-method produce through the test helper.
    produce_result = self.call_primitive(
        monomial,
        'produce',
        inputs=query_inputs,
    )
    self.assertSequenceEqual(produce_result.value, [20, 40, 60])
    self.assertEqual(produce_result.has_finished, True)
    self.assertEqual(produce_result.iterations_done, None)

    # Metadata attached to the produced value: top-level length and the
    # structural type recorded for all elements.
    top_level = produce_result.value.metadata.query(())
    self.assertEqual(top_level['dimension']['length'], 3)
    per_element = produce_result.value.metadata.query((base.ALL_ELEMENTS,))
    self.assertEqual(per_element['structural_type'], float)

    # multi_produce is invoked directly on the primitive, with the same
    # inputs; exactly one entry, keyed by 'produce', is expected back.
    multi_result = monomial.multi_produce(
        produce_methods=('produce',),
        inputs=query_inputs,
    )
    self.assertEqual(len(multi_result.values), 1)
    self.assertSequenceEqual(multi_result.values['produce'], [20, 40, 60])
    self.assertEqual(multi_result.has_finished, True)
    self.assertEqual(multi_result.iterations_done, None)
def test_hyperparameter(self):
    """Fit a primitive built with ``bias=1`` and check its produce output.

    The hyperparams class is looked up from the primitive's metadata under
    ``primitive_code -> class_type_arguments -> Hyperparams``. After fitting
    on inputs ``[1..6]`` with outputs ``[2, 4, ..., 12]``, producing on
    ``[10, 20, 30]`` is expected to give ``[21, 41, 61]``, with the value's
    metadata bound to the value and describing three ``float`` elements.
    """

    def float_list(values):
        # Wrap ``values`` in a float list container, passing a freshly
        # built metadata dict on every call.
        # TODO: Add dimension metadata.
        return container.List[float](
            values,
            {
                'schema': metadata.CONTAINER_SCHEMA_VERSION,
                'structural_type': container.List[float],
            },
        )

    primitive_metadata = MonomialPrimitive.metadata.query()
    type_arguments = primitive_metadata['primitive_code']['class_type_arguments']
    hyperparams_cls = type_arguments['Hyperparams']
    monomial = MonomialPrimitive(hyperparams=hyperparams_cls(bias=1))

    train_x = float_list([1, 2, 3, 4, 5, 6])
    train_y = float_list([2, 4, 6, 8, 10, 12])
    self.call_primitive(
        monomial,
        'set_training_data',
        inputs=train_x,
        outputs=train_y,
    )

    fit_result = self.call_primitive(monomial, 'fit')
    self.assertEqual(fit_result.has_finished, True)
    self.assertEqual(fit_result.iterations_done, None)

    query_inputs = float_list([10, 20, 30])
    produce_result = self.call_primitive(
        monomial,
        'produce',
        inputs=query_inputs,
    )
    self.assertSequenceEqual(produce_result.value, [21, 41, 61])
    self.assertEqual(produce_result.has_finished, True)
    self.assertEqual(produce_result.iterations_done, None)

    # The metadata object must point back at the very value it describes.
    self.assertIs(
        produce_result.value.metadata.for_value,
        produce_result.value,
    )

    # Top-level length and the structural type recorded for all elements.
    top_level = produce_result.value.metadata.query(())
    self.assertEqual(top_level['dimension']['length'], 3)
    per_element = produce_result.value.metadata.query(
        (metadata.ALL_ELEMENTS,))
    self.assertEqual(per_element['structural_type'], float)